118 lines
5.7 KiB
Python
118 lines
5.7 KiB
Python
|
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||
|
|
||
|
from collections import deque
|
||
|
from math import sqrt
|
||
|
from typing import Any
|
||
|
|
||
|
from ultralytics.solutions.solutions import BaseSolution, SolutionAnnotator, SolutionResults
|
||
|
from ultralytics.utils.plotting import colors
|
||
|
|
||
|
|
||
|
class SpeedEstimator(BaseSolution):
|
||
|
"""
|
||
|
A class to estimate the speed of objects in a real-time video stream based on their tracks.
|
||
|
|
||
|
This class extends the BaseSolution class and provides functionality for estimating object speeds using
|
||
|
tracking data in video streams. Speed is calculated based on pixel displacement over time and converted
|
||
|
to real-world units using a configurable meters-per-pixel scale factor.
|
||
|
|
||
|
Attributes:
|
||
|
fps (float): Video frame rate for time calculations.
|
||
|
frame_count (int): Global frame counter for tracking temporal information.
|
||
|
trk_frame_ids (dict): Maps track IDs to their first frame index.
|
||
|
spd (dict): Final speed per object in km/h once locked.
|
||
|
trk_hist (dict): Maps track IDs to deque of position history.
|
||
|
locked_ids (set): Track IDs whose speed has been finalized.
|
||
|
max_hist (int): Required frame history before computing speed.
|
||
|
meter_per_pixel (float): Real-world meters represented by one pixel for scene scale conversion.
|
||
|
max_speed (int): Maximum allowed object speed; values above this will be capped.
|
||
|
|
||
|
Methods:
|
||
|
process: Process input frames to estimate object speeds based on tracking data.
|
||
|
store_tracking_history: Store the tracking history for an object.
|
||
|
extract_tracks: Extract tracks from the current frame.
|
||
|
display_output: Display the output with annotations.
|
||
|
|
||
|
Examples:
|
||
|
Initialize speed estimator and process a frame
|
||
|
>>> estimator = SpeedEstimator(meter_per_pixel=0.04, max_speed=120)
|
||
|
>>> frame = cv2.imread("frame.jpg")
|
||
|
>>> results = estimator.process(frame)
|
||
|
>>> cv2.imshow("Speed Estimation", results.plot_im)
|
||
|
"""
|
||
|
|
||
|
def __init__(self, **kwargs: Any) -> None:
|
||
|
"""
|
||
|
Initialize the SpeedEstimator object with speed estimation parameters and data structures.
|
||
|
|
||
|
Args:
|
||
|
**kwargs (Any): Additional keyword arguments passed to the parent class.
|
||
|
"""
|
||
|
super().__init__(**kwargs)
|
||
|
|
||
|
self.fps = self.CFG["fps"] # Video frame rate for time calculations
|
||
|
self.frame_count = 0 # Global frame counter
|
||
|
self.trk_frame_ids = {} # Track ID → first frame index
|
||
|
self.spd = {} # Final speed per object (km/h), once locked
|
||
|
self.trk_hist = {} # Track ID → deque of (time, position)
|
||
|
self.locked_ids = set() # Track IDs whose speed has been finalized
|
||
|
self.max_hist = self.CFG["max_hist"] # Required frame history before computing speed
|
||
|
self.meter_per_pixel = self.CFG["meter_per_pixel"] # Scene scale, depends on camera details
|
||
|
self.max_speed = self.CFG["max_speed"] # Maximum speed adjustment
|
||
|
|
||
|
def process(self, im0) -> SolutionResults:
|
||
|
"""
|
||
|
Process an input frame to estimate object speeds based on tracking data.
|
||
|
|
||
|
Args:
|
||
|
im0 (np.ndarray): Input image for processing with shape (H, W, C) for RGB images.
|
||
|
|
||
|
Returns:
|
||
|
(SolutionResults): Contains processed image `plot_im` and `total_tracks` (number of tracked objects).
|
||
|
|
||
|
Examples:
|
||
|
Process a frame for speed estimation
|
||
|
>>> estimator = SpeedEstimator()
|
||
|
>>> image = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
|
||
|
>>> results = estimator.process(image)
|
||
|
"""
|
||
|
self.frame_count += 1
|
||
|
self.extract_tracks(im0)
|
||
|
annotator = SolutionAnnotator(im0, line_width=self.line_width)
|
||
|
|
||
|
for box, track_id, _, _ in zip(self.boxes, self.track_ids, self.clss, self.confs):
|
||
|
self.store_tracking_history(track_id, box)
|
||
|
|
||
|
if track_id not in self.trk_hist: # Initialize history if new track found
|
||
|
self.trk_hist[track_id] = deque(maxlen=self.max_hist)
|
||
|
self.trk_frame_ids[track_id] = self.frame_count
|
||
|
|
||
|
if track_id not in self.locked_ids: # Update history until speed is locked
|
||
|
trk_hist = self.trk_hist[track_id]
|
||
|
trk_hist.append(self.track_line[-1])
|
||
|
|
||
|
# Compute and lock speed once enough history is collected
|
||
|
if len(trk_hist) == self.max_hist:
|
||
|
p0, p1 = trk_hist[0], trk_hist[-1] # First and last points of track
|
||
|
dt = (self.frame_count - self.trk_frame_ids[track_id]) / self.fps # Time in seconds
|
||
|
if dt > 0:
|
||
|
dx, dy = p1[0] - p0[0], p1[1] - p0[1] # Pixel displacement
|
||
|
pixel_distance = sqrt(dx * dx + dy * dy) # Calculate pixel distance
|
||
|
meters = pixel_distance * self.meter_per_pixel # Convert to meters
|
||
|
self.spd[track_id] = int(
|
||
|
min((meters / dt) * 3.6, self.max_speed)
|
||
|
) # Convert to km/h and store final speed
|
||
|
self.locked_ids.add(track_id) # Prevent further updates
|
||
|
self.trk_hist.pop(track_id, None) # Free memory
|
||
|
self.trk_frame_ids.pop(track_id, None) # Remove frame start reference
|
||
|
|
||
|
if track_id in self.spd:
|
||
|
speed_label = f"{self.spd[track_id]} km/h"
|
||
|
annotator.box_label(box, label=speed_label, color=colors(track_id, True)) # Draw bounding box
|
||
|
|
||
|
plot_im = annotator.result()
|
||
|
self.display_output(plot_im) # Display output with base class function
|
||
|
|
||
|
# Return results with processed image and tracking summary
|
||
|
return SolutionResults(plot_im=plot_im, total_tracks=len(self.track_ids))
|