MVP for YOLOv8 Nano with webcam on CPU
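The script below grabs frames from a webcam with OpenCV, runs YOLOv8 Nano detection and tracking on each frame on the CPU via the ultralytics package, and shows the annotated frames (bounding boxes plus a short motion trail per track) in a window; press "q" to quit.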
from collections import defaultdict

import cv2
import numpy as np
from ultralytics import YOLO
from ultralytics.utils.plotting import Annotator, colors

# Based on the "YOLOv8 Multi-Object Tracking in Videos" notebook of https://github.com/ultralytics/ultralytics
# Install dependencies in the command line (CMD):
# pip install ultralytics
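# opencv-python and numpy are expected to come in as ultralytics dependencies;
# if the imports above fail, install them explicitly: pip install opencv-python numpy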
width = 1280  # Width and height of frames to process, not the displayed output resolution
height = 720
capture_id = 1  # Webcam index; 0 is usually the built-in/default camera
print("Loading model...") | |
model = YOLO("yolov8n.pt") # Choose model, will be downloaded automatically | |
names = model.model.names | |
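# names maps class indices to labels (COCO classes for the yolov8n weights, e.g. 0 -> "person")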
# Set capture to webcam
print("Opening video...")
cap = cv2.VideoCapture(capture_id, cv2.CAP_DSHOW)
cap.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc(*"MJPG"))
cap.set(cv2.CAP_PROP_FRAME_WIDTH, width)
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, height)
cap.set(cv2.CAP_PROP_FPS, 30)
cap.set(cv2.CAP_PROP_AUTO_EXPOSURE, 3)  # Auto exposure
assert cap.isOpened(), "Error opening webcam"
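# Note: cv2.CAP_DSHOW selects the Windows DirectShow backend; the cap.set() calls above are
# requests the camera driver may ignore, which is why the actual values are read back below.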
w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS))
print(f"Video resolution: {w}x{h}, FPS: {fps}")

print("Processing video...")
track_history = defaultdict(list)

while cap.isOpened():
    success, frame = cap.read()
    if success:  # Only if the frame is read successfully
        # Inference: detection + tracking on the current frame; persist=True keeps the tracker
        # state between frames so track IDs stay consistent
        results = model.track(frame, persist=True, verbose=False)
        boxes = results[0].boxes.xyxy.cpu()  # Predicted bounding boxes

        if results[0].boxes.id is not None:  # If any objects are detected and assigned track IDs
            classes = results[0].boxes.cls.cpu().tolist()  # Predicted classes
            track_ids = results[0].boxes.id.int().cpu().tolist()  # Predicted track IDs
            # confs = results[0].boxes.conf.float().cpu().tolist()  # Predicted confidence scores

            annotator = Annotator(frame, line_width=2)
            for box, cls_id, track_id in zip(boxes, classes, track_ids):  # For each detected object
                colors_id = colors(int(cls_id), True)
                annotator.box_label(box, color=colors_id, label=names[int(cls_id)])

                # Store tracking history: the centre of the bounding box
                track = track_history[track_id]
                track.append((int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2)))
                if len(track) > 30:  # Keep only the last 30 points in the track
                    track.pop(0)

                # Plot the track: a dot at the current centre and a polyline through the history
                points = np.array(track, dtype=np.int32).reshape((-1, 1, 2))
                cv2.circle(frame, track[-1], 7, colors_id, -1)
                cv2.polylines(frame, [points], isClosed=False, color=colors_id, thickness=2)

        # Display the resulting frame, resized to 1920x1080
        frame = cv2.resize(frame, (1920, 1080))
        cv2.imshow("Frame", frame)
        if cv2.waitKey(1) & 0xFF == ord("q"):  # Press "q" to quit
            break
    else:
        break

cap.release()
cv2.destroyAllWindows()
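Portability note: cv2.CAP_DSHOW only works on Windows. A minimal sketch of a cross-platform capture setup (the platform check and CAP_ANY fallback are additions, not part of the original gist):

import sys

import cv2

capture_id = 1  # Same webcam index as above

# Force DirectShow on Windows, otherwise let OpenCV pick the default backend
backend = cv2.CAP_DSHOW if sys.platform == "win32" else cv2.CAP_ANY
cap = cv2.VideoCapture(capture_id, backend)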