Tracking objects in video is a thoroughly studied problem in computer vision that has important applications in industries like sports, retail and security. There are several possible approaches to this problem, but a popular one that’s both simple to implement and effective in practice is called tracking-by-detection.
The tracking-by-detection paradigm relies heavily on high quality object detectors. This means it can leverage advances in deep learning that have dramatically improved the performance of these models.
In this post, we’ll walk through an implementation of a simplified tracking-by-detection algorithm that uses an off-the-shelf detector available for PyTorch. If you want to play with the code, check out the algorithm or the visualization on GitHub.
How it works
Here’s the algorithm. For each frame:
- Run the detector to find the objects in the image.
- Extract features for the objects you care about.
- Compute the pairwise cost between each object from the previous frame and each object in the current frame.
- Assign matches between the two frames in a way that minimizes the overall cost.
Here’s an implementation in Python (or check out the repo):
import json
import fire
import imageio
import numpy as np
import torch
from scipy.optimize import linear_sum_assignment
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.transforms.functional import to_tensor
def track(
    video_path: str,
    output_path: str = "out.json",
    score_threshold: float = 0.5,
    class_index: int = 1,  # Track people by default
) -> None:
    """Track the objects for a specific class in a given video"""
    # Initialization
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = fasterrcnn_resnet50_fpn(pretrained=True).eval().to(device)
    reader = imageio.get_reader(video_path)
    # Tracking loop
    active_tracklets = []
    finished_tracklets = []
    prev_boxes = []
    for i, frame in enumerate(reader):
        height, width = frame.shape[:2]
        # Detection
        x = to_tensor(frame).to(device)
        result = model(x[None])[0]
        # Feature extraction: (x1, y1, x2, y2) in image coordinates
        # where class == class_index and score > score_threshold
        mask = torch.logical_and(
            result["labels"] == class_index, result["scores"] > score_threshold
        )
        boxes = result["boxes"][mask].data.cpu().numpy() / np.array(
            [width, height, width, height]
        )
        prev_indices = []
        boxes_indices = []
        if len(boxes) > 0 and len(prev_boxes) > 0:
            # Pairwise cost: euclidean distance between boxes
            cost = np.linalg.norm(prev_boxes[:, None] - boxes[None], axis=-1)
            # Bipartite matching
            prev_indices, boxes_indices = linear_sum_assignment(cost)
        # Add matches to active tracklets
        for prev_idx, box_idx in zip(prev_indices, boxes_indices):
            active_tracklets[prev_idx]["boxes"].append(
                np.round(boxes[box_idx], 3).tolist()
            )
        # Finalize lost tracklets
        lost_indices = set(range(len(active_tracklets))) - set(prev_indices)
        for lost_idx in sorted(lost_indices, reverse=True):
            finished_tracklets.append(active_tracklets.pop(lost_idx))
        # Activate new tracklets
        new_indices = set(range(len(boxes))) - set(boxes_indices)
        for new_idx in new_indices:
            active_tracklets.append(
                {"start": i, "boxes": [np.round(boxes[new_idx], 3).tolist()]}
            )
        # "Predict" next frame for comparison
        prev_boxes = np.array([tracklet["boxes"][-1] for tracklet in active_tracklets])
    with open(output_path, "w") as f:
        f.write(
            json.dumps(
                {
                    "fps": reader.get_meta_data()["fps"],
                    "tracklets": finished_tracklets + active_tracklets,
                }
            )
        )
if __name__ == "__main__":
    fire.Fire(track)
Let’s walk through the important bits here.
Object detection
The torchvision package provides pre-trained weights and implementations for several computer vision models, including Faster R-CNN. It’s easy to run these models on arbitrary images:
device = "cuda" if torch.cuda.is_available() else "cpu"
model = fasterrcnn_resnet50_fpn(pretrained=True).eval().to(device)
...
x = to_tensor(frame).to(device)
result = model(x[None])[0]
This snippet runs inference on a GPU if one is available and falls back to the CPU otherwise. It also uses the to_tensor() helper from torchvision to convert an RGB image of 8-bit integers with shape (height, width, channels) into a (channels, height, width) tensor of 32-bit floats between 0.0 and 1.0.
The result you get back from calling the model object will be a dictionary with PyTorch tensors for boxes, labels (in the form of class indices from the COCO dataset) and scores (confidence values for each detection). I’ve got another example of using torchvision for inference at this image classification quick start.
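To make the conversion concrete, here’s a rough NumPy sketch of what to_tensor() does to a frame. The helper name to_chw_float is made up for this illustration; it is not part of torchvision, and it only mirrors the behavior described above for 8-bit RGB input.

```python
import numpy as np


def to_chw_float(frame: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for to_tensor(), for illustration only."""
    # (height, width, channels) uint8 -> (channels, height, width) float32 in [0, 1]
    return frame.transpose(2, 0, 1).astype(np.float32) / 255.0


# A fake 640x480 RGB frame that is pure red
frame = np.zeros((480, 640, 3), dtype=np.uint8)
frame[..., 0] = 255

x = to_chw_float(frame)
print(x.shape, x.dtype)  # (3, 480, 640) float32
```

The model expects a batch of images, which is why the real code indexes with x[None] to add a leading batch dimension before calling it.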
Feature extraction
Next, we need to get features for each object. And good news: the coordinates of a bounding box are useful features for object tracking! We just want to make sure we’re only including objects that have the correct class and a high confidence score. For this we access the result dictionary.
mask = torch.logical_and(
    result["labels"] == class_index, result["scores"] > score_threshold
)
boxes = result["boxes"][mask].data.cpu().numpy() / np.array(
    [width, height, width, height]
)
Here we’re creating a boolean tensor called mask that selects objects whose label equals the class index we set (defaults to 1 for the person class) and whose confidence is high enough (defaults to 0.5). Then we pull those boxes off the GPU (if we’re using it). We also divide by the width and height of the image so the coordinates are expressed relative to the image size rather than in pixels. This is a useful default because it prevents you from needing to worry about the size of the image, which is somewhat arbitrary.
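If you want to see the filtering and normalization without loading a model, here’s a small sketch that runs the same two steps in NumPy. The result dictionary below is fabricated for the example (the labels, scores and boxes are made up); it only borrows the key names from the detector output.

```python
import numpy as np

# Fake detector output with the same keys as the real result dictionary
result = {
    "boxes": np.array(
        [
            [64.0, 48.0, 320.0, 240.0],    # person, confident
            [0.0, 0.0, 640.0, 480.0],      # some other class
            [100.0, 100.0, 200.0, 200.0],  # person, low confidence
        ]
    ),
    "labels": np.array([1, 3, 1]),
    "scores": np.array([0.9, 0.95, 0.2]),
}
width, height = 640, 480
class_index, score_threshold = 1, 0.5

# Keep only detections with the right class and a high enough score
mask = np.logical_and(
    result["labels"] == class_index, result["scores"] > score_threshold
)
# Divide (x1, y1, x2, y2) by (width, height, width, height)
boxes = result["boxes"][mask] / np.array([width, height, width, height])
print(boxes)  # [[0.1 0.1 0.5 0.5]]
```

Only the first detection survives: the second has the wrong label and the third falls below the score threshold.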
Pairwise cost
Once we have object features, we can compute pairwise costs with objects from the previous frame (assuming we’re beyond the first frame).
cost = np.linalg.norm(prev_boxes[:, None] - boxes[None], axis=-1)
This is a one-liner for computing the Euclidean distance between pairs of boxes. If prev_boxes is (7, 4) and boxes is (8, 4) then the resulting cost matrix will be (7, 8). Euclidean distance is convenient, but the NumPy axes trick works for arbitrary functions and it works in PyTorch directly (read more about it here). Intersection over union is probably a better metric for comparing box coordinates and can be vectorized (but is more verbose than Euclidean distance).
Another popular approach here is to use appearance features from the objects themselves. This makes it easier to recover when tracklets (a tracklet is a sequence of boxes belonging to the same object) are lost.
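For reference, here’s a sketch of what a vectorized intersection-over-union cost could look like, using the same broadcasting trick as the Euclidean one-liner. This is not what the implementation in this post uses; iou_cost is a hypothetical helper, and it assumes boxes are in (x1, y1, x2, y2) format. It returns 1 - IoU so that lower values still mean better matches.

```python
import numpy as np


def iou_cost(prev_boxes: np.ndarray, boxes: np.ndarray) -> np.ndarray:
    """Pairwise 1 - IoU. prev_boxes: (n, 4), boxes: (m, 4) -> (n, m)."""
    a = prev_boxes[:, None]  # (n, 1, 4)
    b = boxes[None]  # (1, m, 4)
    # Corners of the intersection rectangle for every pair
    x1 = np.maximum(a[..., 0], b[..., 0])
    y1 = np.maximum(a[..., 1], b[..., 1])
    x2 = np.minimum(a[..., 2], b[..., 2])
    y2 = np.minimum(a[..., 3], b[..., 3])
    # Negative widths or heights mean the boxes don't overlap
    intersection = np.clip(x2 - x1, 0, None) * np.clip(y2 - y1, 0, None)
    area_a = (a[..., 2] - a[..., 0]) * (a[..., 3] - a[..., 1])
    area_b = (b[..., 2] - b[..., 0]) * (b[..., 3] - b[..., 1])
    union = area_a + area_b - intersection
    # Guard against division by zero for degenerate (zero-area) boxes
    return 1.0 - intersection / np.maximum(union, 1e-9)


prev_boxes = np.array([[0.0, 0.0, 0.5, 0.5], [0.5, 0.5, 1.0, 1.0]])
boxes = np.array(
    [[0.0, 0.0, 0.5, 0.5], [0.25, 0.0, 0.75, 0.5], [0.6, 0.6, 0.9, 0.9]]
)
cost = iou_cost(prev_boxes, boxes)
print(cost.shape)  # (2, 3)
```

Identical boxes get a cost of 0 and boxes that don’t overlap at all get a cost of 1, so the resulting matrix can be passed to the matching step in the same way as the Euclidean one.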
Matching
Given a cost matrix, we need a way to figure out a set of assignments that minimizes cost. Typically, we also want to constrain the solution to make sure no object gets more than one assignment. The Munkres assignment algorithm (also known as the Hungarian algorithm) solves exactly this problem, and SciPy provides a solver for it in the linear_sum_assignment function.
prev_indices, boxes_indices = linear_sum_assignment(cost)
After this call, the nth element of prev_indices is the index of the previous box that is matched with the current box indexed by the nth element of boxes_indices. This means that if we zip prev_indices and boxes_indices, we can access the matching indices from both sets of objects together.
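Here’s a tiny worked example with a hand-written cost matrix, just to show what the two index arrays look like. The numbers are made up; the rows stand for two boxes from the previous frame and the columns for three boxes in the current frame.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

# Rows: 2 previous boxes, columns: 3 current boxes
cost = np.array(
    [
        [0.9, 0.1, 0.8],
        [0.2, 0.7, 0.6],
    ]
)
prev_indices, boxes_indices = linear_sum_assignment(cost)
for prev_idx, box_idx in zip(prev_indices, boxes_indices):
    print(f"previous box {prev_idx} -> current box {box_idx}")
# previous box 0 -> current box 1
# previous box 1 -> current box 0

# Columns that never show up in boxes_indices had no match
unmatched = set(range(cost.shape[1])) - set(boxes_indices)
print(unmatched)  # {2}
```

Because the matrix isn’t square, one current box is left without a partner. That leftover is exactly what the bookkeeping step turns into a new tracklet.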
Bookkeeping
Finally, there’s some bookkeeping to do on every step, which looks different depending on the data structure you use for your tracklets. The three cases that need to be handled are:
1. There’s a match between an existing object and a new object.
2. There’s an existing object without a match (this tracklet will be considered lost).
3. There’s a new object without a match (this starts a new tracklet).
For case (1), we just append the new object to the array of boxes in the matching tracklet. For case (2), we move the tracklet from the list of active tracklets to the list of lost or finished tracklets. For case (3), we create a new tracklet where start is the frame we’re currently on.
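The three cases can be sketched as a standalone function in plain Python. update_tracklets is a hypothetical helper written for this walkthrough (the full implementation above does the same work inline), and it assumes each tracklet is a dictionary with a "start" frame and a list of "boxes".

```python
from typing import Any, Dict, List, Tuple

Tracklet = Dict[str, Any]


def update_tracklets(
    active: List[Tracklet],
    finished: List[Tracklet],
    boxes: List[List[float]],
    matches: List[Tuple[int, int]],
    frame_index: int,
) -> None:
    """Apply the three bookkeeping cases in place (illustrative helper)."""
    matched_prev = {prev_idx for prev_idx, _ in matches}
    matched_new = {box_idx for _, box_idx in matches}
    # Case 1: a match extends the existing tracklet
    for prev_idx, box_idx in matches:
        active[prev_idx]["boxes"].append(boxes[box_idx])
    # Case 2: unmatched tracklets are lost. Pop from the highest index
    # down so the remaining indices stay valid.
    lost = set(range(len(active))) - matched_prev
    for lost_idx in sorted(lost, reverse=True):
        finished.append(active.pop(lost_idx))
    # Case 3: unmatched boxes start new tracklets on the current frame
    for new_idx in sorted(set(range(len(boxes))) - matched_new):
        active.append({"start": frame_index, "boxes": [boxes[new_idx]]})


active = [
    {"start": 0, "boxes": [[0.1, 0.1, 0.2, 0.2]]},
    {"start": 0, "boxes": [[0.5, 0.5, 0.6, 0.6]]},
]
finished: List[Tracklet] = []
boxes = [[0.11, 0.1, 0.21, 0.2], [0.8, 0.8, 0.9, 0.9]]
update_tracklets(active, finished, boxes, matches=[(0, 0)], frame_index=1)
print(len(active), len(finished))  # 2 1
```

Popping lost tracklets only after the matches are appended matters: the matched indices refer to positions in the list as it was at the start of the frame.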
The format of the tracklet data structure in this implementation is simple: we just keep lists of active and finished tracklets. Each tracklet is a dictionary that looks like the following:
{
    "start": <the frame where the tracklet starts>,
    "boxes": <a (n_frames, 4) array of box coordinates>
}
Validation
Validating the implementation of an algorithm like object tracking is tough, especially when you don’t have another implementation you can compare against or ground truth data. Unit tests definitely help and are a good idea for anything you plan on using for a while.
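As one example of the kind of unit test that helps, here’s a sketch that checks a simple property: if the current frame contains exactly the same boxes as the previous frame in a different order, the matching should pair each box with itself. The match helper and the test name are made up for this illustration; they wrap the same cost and assignment calls used above.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def match(prev_boxes: np.ndarray, boxes: np.ndarray):
    """Euclidean cost followed by bipartite matching (illustrative helper)."""
    cost = np.linalg.norm(prev_boxes[:, None] - boxes[None], axis=-1)
    return linear_sum_assignment(cost)


def test_match_recovers_shuffled_boxes() -> None:
    rng = np.random.default_rng(0)
    prev_boxes = rng.random((5, 4))
    # Same boxes in a different order
    boxes = prev_boxes[rng.permutation(5)]
    prev_indices, boxes_indices = match(prev_boxes, boxes)
    # Every matched pair should be the same box
    assert np.allclose(prev_boxes[prev_indices], boxes[boxes_indices])


test_match_recovers_shuffled_boxes()
```

A test like this won’t tell you whether the tracker is good, but it will catch indexing mistakes such as swapping the two index arrays.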
Another thing you should be doing as much as possible is visualizing your results. This is a little tricky when the input to the algorithm is video. Matplotlib does have some animation tooling, but it’s not very intuitive to use and the result isn’t interactive. Fortunately, you can go a long way with basic knowledge of JavaScript and React. To make sure this algorithm was working as expected, I put together a simple React app that plots tracklets on top of a video.
You can check out the app here with my test video in it. You can also load your own video and output. Check out the code for the visualization here.
Conclusion
And that’s all there is to it. There are a lot of ways to make this algorithm more sophisticated, but if you can afford to run object detection on every frame, you can go a long way by just improving your detections (mostly by adding high quality labeled data).
