# VitTrack: interactive video object tracking demo (Hugging Face Space).
import ast
import os
import tempfile

import cv2 as cv
import gradio as gr
import numpy as np
from huggingface_hub import hf_hub_download

from vittrack import VitTrack
# Download the VitTrack ONNX model once at startup.
MODEL_PATH = hf_hub_download(
    repo_id="opencv/object_tracking_vittrack",
    filename="object_tracking_vittrack_2023sep.onnx",
)

# Inference runs on CPU through the plain OpenCV DNN backend.
backend_id = cv.dnn.DNN_BACKEND_OPENCV
target_id = cv.dnn.DNN_TARGET_CPU

car_on_road_video = "examples/car.mp4"
car_in_desert_video = "examples/desert_car.mp4"

# Mutable app-wide state shared by all UI callbacks.
state = {
    "points": [],         # up to 4 clicked (x, y) points on the first frame
    "bbox": None,         # (x, y, w, h) rect derived from those points
    "video_path": None,   # path of the currently loaded video
    "first_frame": None,  # BGR copy of that video's first frame
}

# Pre-selected bounding boxes (as literal strings) for the example videos.
bbox_dict = {
    "car.mp4": "(152, 356, 332, 104)",
    "desert_car.mp4": "(758, 452, 119, 65)",
}
def load_first_frame(video_path):
    """Read the first frame of *video_path*, cache it, and return it as RGB.

    Returns None when the video yields no frame.
    """
    state["video_path"] = video_path
    capture = cv.VideoCapture(video_path)
    ok, frame = capture.read()
    capture.release()
    if not ok:
        return None
    # Keep a BGR copy for later overlay drawing; Gradio wants RGB back.
    state["first_frame"] = frame.copy()
    return cv.cvtColor(frame, cv.COLOR_BGR2RGB)
def select_point(img, evt: gr.SelectData):
    """Record one click on the first frame (max 4) and redraw the overlay.

    Draws every stored point, a connecting polyline once there are two or
    more, and the bounding rectangle of all four once selection is complete.
    Returns the annotated frame as RGB, or None if no frame is loaded.
    """
    if state["first_frame"] is None:
        return None

    click = (int(evt.index[0]), int(evt.index[1]))
    if len(state["points"]) < 4:
        state["points"].append(click)

    vis = state["first_frame"].copy()
    for point in state["points"]:
        cv.circle(vis, point, 5, (0, 255, 0), -1)

    count = len(state["points"])
    if count > 1:
        poly = np.array(state["points"], dtype=np.int32)
        cv.polylines(vis, [poly], isClosed=False, color=(255, 255, 0), thickness=2)

    # Four points complete the selection: derive and draw the tracking bbox.
    if count == 4:
        poly = np.array(state["points"], dtype=np.int32)
        rx, ry, rw, rh = cv.boundingRect(poly)
        state["bbox"] = (rx, ry, rw, rh)
        cv.rectangle(vis, (rx, ry), (rx + rw, ry + rh), (0, 0, 255), 2)

    return cv.cvtColor(vis, cv.COLOR_BGR2RGB)
def clear_points():
    """Drop the clicked points and bbox; return the clean first frame (RGB)."""
    state["points"].clear()
    state["bbox"] = None
    frame = state["first_frame"]
    return None if frame is None else cv.cvtColor(frame, cv.COLOR_BGR2RGB)
def clear_all():
    """Wipe all shared state and blank the video, frame, and result outputs."""
    state["points"].clear()
    state.update(bbox=None, video_path=None, first_frame=None)
    # One None per cleared UI component: video_in, first_frame, output_video.
    return None, None, None
def track_video():
    """Run VitTrack over the stored video and write an annotated copy.

    Uses the globally selected ``state["bbox"]`` on the first frame to
    initialize the tracker, then overlays FPS and the tracked box (or a
    "Target lost!" banner when confidence drops below 0.3) on every
    subsequent frame.

    Returns:
        Path of the annotated output video, or None when no video/bbox has
        been selected or the video cannot be read.
    """
    if state["video_path"] is None or state["bbox"] is None:
        return None

    model = VitTrack(
        model_path=MODEL_PATH,
        backend_id=backend_id,
        target_id=target_id,
    )

    cap = cv.VideoCapture(state["video_path"])
    # BUGFIX: some containers report 0 fps, which produces a broken writer.
    fps = cap.get(cv.CAP_PROP_FPS) or 30.0
    w = int(cap.get(cv.CAP_PROP_FRAME_WIDTH))
    h = int(cap.get(cv.CAP_PROP_FRAME_HEIGHT))

    # BUGFIX: verify the first frame was actually read before init —
    # the original unpacked blindly and crashed on unreadable videos.
    has_frame, first_frame = cap.read()
    if not has_frame:
        cap.release()
        return None
    model.init(first_frame, state["bbox"])

    out_path = os.path.join(tempfile.gettempdir(), "vittrack_output.mp4")
    writer = cv.VideoWriter(
        out_path,
        cv.VideoWriter_fourcc(*"mp4v"),
        fps,
        (w, h),
    )

    tm = cv.TickMeter()
    try:
        while True:
            has_frame, frame = cap.read()
            if not has_frame:
                break
            tm.start()
            isLocated, bbox, score = model.infer(frame)
            tm.stop()

            vis = frame.copy()
            # Overlay per-frame inference FPS.
            cv.putText(vis, f"FPS:{tm.getFPS():.2f}", (w // 4, 30),
                       cv.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)

            if isLocated and score >= 0.3:
                x, y, bw, bh = bbox
                cv.rectangle(vis, (x, y), (x + bw, y + bh), (0, 255, 0), 2)
                cv.putText(vis, f"{score:.2f}", (x, y - 10),
                           cv.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
            else:
                cv.putText(vis, "Target lost!",
                           (w // 2, h // 4),
                           cv.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 3)

            writer.write(vis)
            tm.reset()
    finally:
        # BUGFIX: release resources even if inference raises mid-video.
        cap.release()
        writer.release()
    return out_path
def example_pipeline(video_path):
    """Run tracking on a bundled example video using its canned bbox.

    Resets all state, looks up the pre-selected bounding box by the video's
    base filename, and returns the annotated output path from track_video().
    """
    clear_all()
    # BUGFIX: basename is separator-agnostic, unlike split('/') on Windows.
    filename = os.path.basename(video_path)
    state["video_path"] = video_path
    # BUGFIX: ast.literal_eval parses the tuple literal safely — never use
    # eval() on data strings.
    state["bbox"] = ast.literal_eval(bbox_dict[filename])
    return track_video()
# Build the Gradio UI: upload/first-frame/result row, action buttons,
# clickable examples, and the event wiring between them.
with gr.Blocks(css='''.example * {
    font-style: italic;
    font-size: 18px !important;
    color: #0ea5e9 !important;
}''') as demo:
    gr.Markdown("## VitTrack: Interactive Video Object Tracking")
    # BUGFIX: the bullet emojis were mojibake (UTF-8 bytes misdecoded as
    # Thai text); restored to real emoji.
    gr.Markdown(
        """
**How to use this tool:**
1. **Upload a video** file (e.g., `.mp4` or `.avi`).
2. The **first frame** of the video will appear.
3. **Click exactly 4 points** on the object you want to track. These points should outline the object as closely as possible.
4. A **bounding box** will be drawn around the selected region automatically.
5. Click the **Track** button to start object tracking across the entire video.
6. The output video with tracking overlay will appear below.

You can also use:
- 🧹 **Clear Points** to reset the 4-point selection on the first frame.
- 🔄 **Clear All** to reset the uploaded video, frame, and selections.
"""
    )

    with gr.Row():
        video_in = gr.Video(label="Upload Video")
        first_frame = gr.Image(label="First Frame", interactive=True)
        output_video = gr.Video(label="Tracking Result")

    with gr.Row():
        track_btn = gr.Button("Track", variant="primary")
        clear_pts_btn = gr.Button("Clear Points")
        clear_all_btn = gr.Button("Clear All")

    gr.Markdown("Click any row to load an example.", elem_classes=["example"])
    examples = [
        [car_on_road_video],
        [car_in_desert_video],
    ]
    gr.Examples(
        examples=examples,
        inputs=[video_in],
        outputs=[output_video],
        fn=example_pipeline,
        cache_examples=False,
        run_on_click=True,
    )
    gr.Markdown("Example videos credit: https://pixabay.com/")

    # Event wiring.
    video_in.change(fn=load_first_frame, inputs=video_in, outputs=first_frame)
    first_frame.select(fn=select_point, inputs=first_frame, outputs=first_frame)
    clear_pts_btn.click(fn=clear_points, outputs=first_frame)
    clear_all_btn.click(fn=clear_all, outputs=[video_in, first_frame, output_video])
    track_btn.click(fn=track_video, outputs=output_video)

if __name__ == "__main__":
    demo.launch()