# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

import torch

from ultralytics.models.yolo.detect.predict import DetectionPredictor
from ultralytics.utils import ops


class NASPredictor(DetectionPredictor):
    """
    Ultralytics YOLO NAS Predictor for object detection.

    This class extends the DetectionPredictor from Ultralytics engine and is responsible for post-processing the raw
    predictions generated by the YOLO NAS models. It rearranges the NAS outputs into a single tensor and hands it to
    DetectionPredictor's post-processing, which applies operations like non-maximum suppression and scaling the
    bounding boxes to fit the original image dimensions.

    Attributes:
        args (Namespace): Namespace containing various configurations for post-processing including confidence
            threshold, IoU threshold, agnostic NMS flag, maximum detections, and class filtering options.
        model (torch.nn.Module): The YOLO NAS model used for inference.
        batch (list): Batch of inputs for processing.

    Examples:
        >>> from ultralytics import NAS
        >>> model = NAS("yolo_nas_s")
        >>> predictor = model.predictor
        >>> # Assume that raw_preds, img, orig_imgs are available
        >>> results = predictor.postprocess(raw_preds, img, orig_imgs)

    Notes:
        Typically, this class is not instantiated directly. It is used internally within the NAS class.
    """

    def postprocess(self, preds_in, img, orig_imgs):
        """
        Postprocess NAS model predictions to generate final detection results.

        This method takes raw predictions from a YOLO NAS model, converts bounding box formats, and applies
        post-processing operations to generate the final detection results compatible with Ultralytics result
        visualization and analysis tools.

        Args:
            preds_in (list): Raw predictions from the NAS model. Only ``preds_in[0]`` is used: ``preds_in[0][0]``
                holds the bounding boxes in xyxy format and ``preds_in[0][1]`` holds the class scores. Both must be
                3-D tensors (the ``permute(0, 2, 1)`` below requires it), presumably shaped (B, N, 4) and
                (B, N, num_classes) respectively — TODO confirm against the NAS model output.
            img (torch.Tensor): Input image tensor that was fed to the model, with shape (B, C, H, W).
            orig_imgs (list | torch.Tensor | np.ndarray): Original images before preprocessing, used for scaling
                coordinates back to original dimensions.

        Returns:
            (list): List of Results objects containing the processed predictions for each image in the batch.

        Examples:
            >>> from ultralytics import NAS
            >>> predictor = NAS("yolo_nas_s").predictor
            >>> # Assume that raw_preds, img, orig_imgs are available
            >>> results = predictor.postprocess(raw_preds, img, orig_imgs)
        """
        raw = preds_in[0]  # boxes and class scores live in the first element of the model output
        boxes = ops.xyxy2xywh(raw[0])  # convert bounding boxes from xyxy to xywh format
        # Concatenate boxes with class scores along the last dim, then swap dims 1 and 2
        # (presumably (B, N, 4 + nc) -> (B, 4 + nc, N)) so the parent's postprocess receives one combined tensor.
        preds = torch.cat((boxes, raw[1]), -1).permute(0, 2, 1)
        return super().postprocess(preds, img, orig_imgs)