# Author: Abhishek Gola
# Added samples (commit 339a69e)
import cv2 as cv
import numpy as np
import gradio as gr
from vittrack import VitTrack
from huggingface_hub import hf_hub_download
import os
import tempfile
# Fetch the VitTrack ONNX weights from the Hugging Face Hub once at import
# time; hf_hub_download caches the file locally and returns its path.
MODEL_PATH = hf_hub_download(
    repo_id="opencv/object_tracking_vittrack",
    filename="object_tracking_vittrack_2023sep.onnx"
)

# Run inference on CPU via OpenCV's built-in DNN backend.
backend_id = cv.dnn.DNN_BACKEND_OPENCV
target_id = cv.dnn.DNN_TARGET_CPU

# Bundled example clips shown in the Gradio Examples table.
car_on_road_video = "examples/car.mp4"
car_in_desert_video = "examples/desert_car.mp4"

# Global state shared by all Gradio callbacks:
#   points      - up to 4 (x, y) clicks made on the first frame
#   bbox        - (x, y, w, h) rect derived from the 4 points
#   video_path  - path of the currently loaded video
#   first_frame - BGR copy of frame 0, used as the drawing canvas
state = {
    "points": [],
    "bbox": None,
    "video_path": None,
    "first_frame": None
}

# Hand-picked initial bounding boxes for the example videos, keyed by file
# name. NOTE: stored as "(x, y, w, h)" strings and parsed in
# example_pipeline.
bbox_dict = {
    "car.mp4": "(152, 356, 332, 104)",
    "desert_car.mp4": "(758, 452, 119, 65)",
}
def load_first_frame(video_path):
    """Load a video, cache its first frame, and reset any prior selection.

    Args:
        video_path: Path to the uploaded video, or None when the video
            component is cleared.

    Returns:
        The first frame converted BGR->RGB for Gradio display, or None if
        there is no video or it cannot be read.
    """
    # A new (or cleared) video invalidates the points/bbox chosen on the
    # previous one — the original left them dangling.
    state["points"].clear()
    state["bbox"] = None
    state["first_frame"] = None
    state["video_path"] = video_path
    if video_path is None:
        # Component was cleared; nothing to show.
        return None
    cap = cv.VideoCapture(video_path)
    try:
        has_frame, frame = cap.read()
    finally:
        cap.release()
    if not has_frame:
        return None
    state["first_frame"] = frame.copy()
    return cv.cvtColor(frame, cv.COLOR_BGR2RGB)
def select_point(img, evt: gr.SelectData):
    """Record a click on the first frame; after 4 clicks derive a bbox.

    Up to four clicks are collected. Each redraw shows the clicked points,
    the open polyline joining them, and — once all four exist — the red
    axis-aligned bounding rectangle that is stored for the tracker.
    """
    frame = state["first_frame"]
    if frame is None:
        return None

    click = (int(evt.index[0]), int(evt.index[1]))
    pts = state["points"]
    if len(pts) < 4:
        pts.append(click)

    canvas = frame.copy()

    # Mark every selected point.
    for px, py in pts:
        cv.circle(canvas, (px, py), 5, (0, 255, 0), -1)

    # Join the points with an open polyline once there are at least two.
    if len(pts) >= 2:
        poly = np.array(pts, dtype=np.int32)
        cv.polylines(canvas, [poly], isClosed=False, color=(255, 255, 0), thickness=2)

    # With all four points placed, compute and show the tracking rectangle.
    if len(pts) == 4:
        rx, ry, rw, rh = cv.boundingRect(np.array(pts, dtype=np.int32))
        state["bbox"] = (rx, ry, rw, rh)
        cv.rectangle(canvas, (rx, ry), (rx + rw, ry + rh), (0, 0, 255), 2)

    return cv.cvtColor(canvas, cv.COLOR_BGR2RGB)
def clear_points():
    """Drop the clicked points and bbox; redisplay the pristine first frame."""
    state["points"].clear()
    state["bbox"] = None
    frame = state["first_frame"]
    # No frame loaded yet -> blank the image component instead.
    return None if frame is None else cv.cvtColor(frame, cv.COLOR_BGR2RGB)
def clear_all():
    """Wipe every piece of session state and blank the three UI outputs."""
    state["points"].clear()
    state["bbox"] = None
    state["video_path"] = None
    state["first_frame"] = None
    # (video_in, first_frame, output_video) are all reset to empty.
    return None, None, None
def track_video():
    """Run VitTrack over the selected video and write an annotated copy.

    Reads the input path and the user-selected initial bounding box from
    the module-level ``state``.

    Returns:
        Path of the rendered .mp4 in the temp directory, or None when no
        video/bbox has been selected or the video yields no frames.
    """
    if state["video_path"] is None or state["bbox"] is None:
        return None

    model = VitTrack(
        model_path=MODEL_PATH,
        backend_id=backend_id,
        target_id=target_id
    )

    cap = cv.VideoCapture(state["video_path"])
    fps = cap.get(cv.CAP_PROP_FPS)
    w = int(cap.get(cv.CAP_PROP_FRAME_WIDTH))
    h = int(cap.get(cv.CAP_PROP_FRAME_HEIGHT))

    # Write the result into a fixed file in the OS temp directory.
    out_path = os.path.join(tempfile.gettempdir(), "vittrack_output.mp4")
    writer = cv.VideoWriter(
        out_path,
        cv.VideoWriter_fourcc(*"mp4v"),
        fps,
        (w, h)
    )

    try:
        # Initialise the tracker on the first frame; tracking (and the
        # output video) starts from the second frame.
        has_frame, first_frame = cap.read()
        if not has_frame:
            # Original crashed in model.init on an unreadable video.
            return None
        model.init(first_frame, state["bbox"])

        tm = cv.TickMeter()
        while True:
            has_frame, frame = cap.read()
            if not has_frame:
                break
            tm.start()
            isLocated, bbox, score = model.infer(frame)
            tm.stop()

            vis = frame.copy()
            # Overlay the per-frame inference FPS.
            cv.putText(vis, f"FPS:{tm.getFPS():.2f}", (w // 4, 30),
                       cv.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)

            # Draw the track only when the tracker is confident enough.
            if isLocated and score >= 0.3:
                x, y, w_, h_ = bbox
                cv.rectangle(vis, (x, y), (x + w_, y + h_), (0, 255, 0), 2)
                cv.putText(vis, f"{score:.2f}", (x, y - 10),
                           cv.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
            else:
                cv.putText(vis, "Target lost!",
                           (w // 2, h // 4),
                           cv.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 3)

            writer.write(vis)
            tm.reset()
    finally:
        # Release resources even if inference raises mid-video.
        cap.release()
        writer.release()

    return out_path
def example_pipeline(video_path):
    """Track a bundled example clip using its known initial bounding box.

    Args:
        video_path: Path of the example video chosen from the Examples
            table.

    Returns:
        Path to the tracked output video (see ``track_video``), or None on
        failure.
    """
    import ast  # local import: only this example path needs it

    clear_all()
    # os.path.basename is portable; the original split('/') breaks on
    # Windows-style paths.
    filename = os.path.basename(video_path)
    state["video_path"] = video_path
    # literal_eval safely parses the "(x, y, w, h)" string; eval() would
    # execute arbitrary code if the dict were ever populated externally.
    state["bbox"] = ast.literal_eval(bbox_dict[filename])
    return track_video()
# Build the Gradio UI. The CSS class "example" styles the caption above the
# Examples table.
with gr.Blocks(css='''.example * {
    font-style: italic;
    font-size: 18px !important;
    color: #0ea5e9 !important;
}''') as demo:
    gr.Markdown("## VitTrack: Interactive Video Object Tracking")
    # Fixed mojibake in the two bullet emojis below (UTF-8 bytes of the
    # broom/arrows emoji had been decoded with a Thai codepage).
    gr.Markdown(
        """
**How to use this tool:**
1. **Upload a video** file (e.g., `.mp4` or `.avi`).
2. The **first frame** of the video will appear.
3. **Click exactly 4 points** on the object you want to track. These points should outline the object as closely as possible.
4. A **bounding box** will be drawn around the selected region automatically.
5. Click the **Track** button to start object tracking across the entire video.
6. The output video with tracking overlay will appear below.
You can also use:
- 🧹 **Clear Points** to reset the 4-point selection on the first frame.
- 🔄 **Clear All** to reset the uploaded video, frame, and selections.
"""
    )

    # Input video, the clickable first frame, and the tracked result.
    with gr.Row():
        video_in = gr.Video(label="Upload Video")
        first_frame = gr.Image(label="First Frame", interactive=True)
        output_video = gr.Video(label="Tracking Result")

    with gr.Row():
        track_btn = gr.Button("Track", variant="primary")
        clear_pts_btn = gr.Button("Clear Points")
        clear_all_btn = gr.Button("Clear All")

    gr.Markdown("Click any row to load an example.", elem_classes=["example"])
    examples = [
        [car_on_road_video],
        [car_in_desert_video],
    ]
    # run_on_click makes a row click execute example_pipeline directly.
    gr.Examples(
        examples=examples,
        inputs=[video_in],
        outputs=[output_video],
        fn=example_pipeline,
        cache_examples=False,
        run_on_click=True
    )
    gr.Markdown("Example videos credit: https://pixabay.com/")

    # Event wiring: upload -> show first frame; click -> collect points;
    # buttons -> reset / run the tracker.
    video_in.change(fn=load_first_frame, inputs=video_in, outputs=first_frame)
    first_frame.select(fn=select_point, inputs=first_frame, outputs=first_frame)
    clear_pts_btn.click(fn=clear_points, outputs=first_frame)
    clear_all_btn.click(fn=clear_all, outputs=[video_in, first_frame, output_video])
    track_btn.click(fn=track_video, outputs=output_video)
# Start the Gradio server only when run as a script, not on import.
if __name__ == "__main__":
    demo.launch()