image_to_pixle_params_yoloSAM/ultralytics-main/ultralytics/solutions/speed_estimation.py

# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from collections import deque
from math import sqrt
from typing import Any

from ultralytics.solutions.solutions import BaseSolution, SolutionAnnotator, SolutionResults
from ultralytics.utils.plotting import colors


class SpeedEstimator(BaseSolution):
"""
    A class to estimate the speed of objects in a real-time video stream based on their tracks.

    This class extends the BaseSolution class and provides functionality for estimating object speeds using
    tracking data in video streams. Speed is calculated based on pixel displacement over time and converted
    to real-world units using a configurable meters-per-pixel scale factor.

    Attributes:
        fps (float): Video frame rate used for time calculations.
        frame_count (int): Global frame counter for tracking temporal information.
        trk_frame_ids (dict): Maps track IDs to their first frame index.
        spd (dict): Final speed per object in km/h once locked.
        trk_hist (dict): Maps track IDs to a deque of position history.
        locked_ids (set): Track IDs whose speed has been finalized.
        max_hist (int): Required frame history before computing speed.
        meter_per_pixel (float): Real-world meters represented by one pixel for scene scale conversion.
        max_speed (int): Maximum allowed object speed in km/h; values above this are capped.

    Methods:
        process: Process input frames to estimate object speeds based on tracking data.
        store_tracking_history: Store the tracking history for an object.
        extract_tracks: Extract tracks from the current frame.
        display_output: Display the output with annotations.

    Examples:
        Initialize speed estimator and process a frame
        >>> import cv2
        >>> estimator = SpeedEstimator(meter_per_pixel=0.04, max_speed=120)
        >>> frame = cv2.imread("frame.jpg")
        >>> results = estimator.process(frame)
        >>> cv2.imshow("Speed Estimation", results.plot_im)
    """

def __init__(self, **kwargs: Any) -> None:
"""
        Initialize the SpeedEstimator object with speed estimation parameters and data structures.

        Args:
            **kwargs (Any): Additional keyword arguments passed to the parent class.
        """
super().__init__(**kwargs)
self.fps = self.CFG["fps"] # Video frame rate for time calculations
self.frame_count = 0 # Global frame counter
self.trk_frame_ids = {} # Track ID → first frame index
self.spd = {} # Final speed per object (km/h), once locked
self.trk_hist = {} # Track ID → deque of (time, position)
self.locked_ids = set() # Track IDs whose speed has been finalized
self.max_hist = self.CFG["max_hist"] # Required frame history before computing speed
self.meter_per_pixel = self.CFG["meter_per_pixel"] # Scene scale, depends on camera details
        self.max_speed = self.CFG["max_speed"]  # Maximum allowed speed (km/h); computed speeds are capped here
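
    # Speed formula used in process(): speed_kmh = (pixel_distance * meter_per_pixel / dt) * 3.6.
    # Illustrative numbers only (hypothetical scene): a 125 px displacement with meter_per_pixel=0.04 over
    # dt = 1.0 s covers 5 m, i.e. 5 m/s * 3.6 = 18 km/h, which is then capped at max_speed.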
def process(self, im0) -> SolutionResults:
"""
        Process an input frame to estimate object speeds based on tracking data.

        Args:
            im0 (np.ndarray): Input image for processing with shape (H, W, C) for RGB images.

        Returns:
            (SolutionResults): Contains processed image `plot_im` and `total_tracks` (number of tracked objects).

        Examples:
            Process a frame for speed estimation
            >>> import numpy as np
            >>> estimator = SpeedEstimator()
            >>> image = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
            >>> results = estimator.process(image)
        """
self.frame_count += 1
self.extract_tracks(im0)
annotator = SolutionAnnotator(im0, line_width=self.line_width)
for box, track_id, _, _ in zip(self.boxes, self.track_ids, self.clss, self.confs):
self.store_tracking_history(track_id, box)
if track_id not in self.trk_hist: # Initialize history if new track found
self.trk_hist[track_id] = deque(maxlen=self.max_hist)
self.trk_frame_ids[track_id] = self.frame_count
if track_id not in self.locked_ids: # Update history until speed is locked
trk_hist = self.trk_hist[track_id]
trk_hist.append(self.track_line[-1])
# Compute and lock speed once enough history is collected
if len(trk_hist) == self.max_hist:
p0, p1 = trk_hist[0], trk_hist[-1] # First and last points of track
dt = (self.frame_count - self.trk_frame_ids[track_id]) / self.fps # Time in seconds
if dt > 0:
dx, dy = p1[0] - p0[0], p1[1] - p0[1] # Pixel displacement
pixel_distance = sqrt(dx * dx + dy * dy) # Calculate pixel distance
meters = pixel_distance * self.meter_per_pixel # Convert to meters
self.spd[track_id] = int(
min((meters / dt) * 3.6, self.max_speed)
) # Convert to km/h and store final speed
self.locked_ids.add(track_id) # Prevent further updates
self.trk_hist.pop(track_id, None) # Free memory
self.trk_frame_ids.pop(track_id, None) # Remove frame start reference
if track_id in self.spd:
speed_label = f"{self.spd[track_id]} km/h"
                annotator.box_label(box, label=speed_label, color=colors(track_id, True))  # Draw box with speed label
plot_im = annotator.result()
self.display_output(plot_im) # Display output with base class function
# Return results with processed image and tracking summary
return SolutionResults(plot_im=plot_im, total_tracks=len(self.track_ids))
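

# Minimal usage sketch: run the estimator on a local video stream and display the annotated frames.
# "traffic.mp4" is a placeholder path, and the keyword arguments follow the docstring example above;
# the default solutions configuration (including the detection model) is assumed to be available.
if __name__ == "__main__":
    import cv2

    estimator = SpeedEstimator(meter_per_pixel=0.04, max_speed=120)  # hypothetical scene scale and speed cap
    cap = cv2.VideoCapture("traffic.mp4")  # placeholder video path
    while cap.isOpened():
        success, frame = cap.read()
        if not success:
            break
        results = estimator.process(frame)  # SolutionResults with the annotated frame in `plot_im`
        cv2.imshow("Speed Estimation", results.plot_im)
        if cv2.waitKey(1) & 0xFF == ord("q"):  # press "q" to quit early
            break
    cap.release()
    cv2.destroyAllWindows()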