image_to_pixle_params_yoloSAM/ultralytics-main/test.py

107 lines
4.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import traceback
from pathlib import Path

import torch
from ultralytics import YOLO
def detect_images():
    """Detect cars in every image of the input folder and record the largest box per image.

    Runs a YOLO model over ``path_config['input_folder']`` (only COCO class 2,
    used here as "car"), saves annotated images under ``<output_folder>/exp*``
    and writes ``<output_folder>/car_boxes.txt`` with one line per image that
    has at least one car detection::

        <image_name> <x1> <y1> <x2> <y2>

    where (x1, y1) is the top-left and (x2, y2) the bottom-right corner of the
    largest-area car box, truncated to integers. Images without any car are
    not listed.

    Any exception is reported on stdout with its traceback and troubleshooting
    hints; the function does not re-raise and always returns ``None``.
    """
    # ======================= Configuration =======================
    # Model configuration
    model_config = {
        'model_path': 'yolo11x.pt',  # local weights path
        # URL used to fetch the weights when model_path does not exist ('' disables download)
        'download_url': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo11x.pt',
    }
    # Path configuration
    path_config = {
        'input_folder': 'input',
        'output_folder': 'output',
        'auto_create_dir': True,  # create the output directory if it is missing
    }
    # Inference parameters
    predict_config = {
        'conf_thres': 0.25,  # confidence threshold
        'iou_thres': 0.45,   # NMS IoU threshold
        'imgsz': 640,        # inference resolution
        'line_width': 2,     # box line width in the saved images
        'device': 'cuda:0' if torch.cuda.is_available() else 'cpu',  # auto-select device
    }
    # ===================== End of configuration ==================
    try:
        # Validate the input directory.
        if not Path(path_config['input_folder']).exists():
            raise FileNotFoundError(f"输入目录不存在: {path_config['input_folder']}")
        # Create the output directory if requested.
        if path_config['auto_create_dir']:
            Path(path_config['output_folder']).mkdir(parents=True, exist_ok=True)

        # Fetch the weights if they are not available locally.
        model_path = Path(model_config['model_path'])
        if not model_path.exists():
            if model_config['download_url']:
                print("开始下载模型...")
                # The YOLO model object has no download() method; fetch the file
                # directly so it lands exactly at model_path, even for a custom URL.
                model_path.parent.mkdir(parents=True, exist_ok=True)
                torch.hub.download_url_to_file(model_config['download_url'], str(model_path))
            else:
                raise FileNotFoundError(f"模型文件不存在: {model_config['model_path']}")

        # Load the model; the device is passed to predict() below so the
        # configured device is the one actually used for inference.
        model = YOLO(model_config['model_path'])
        print(f"✅ 模型加载成功 | 设备: {predict_config['device'].upper()}")

        # Run inference over the whole input folder.
        results = model.predict(
            source=path_config['input_folder'],
            project=path_config['output_folder'],
            name="exp",
            save=True,
            device=predict_config['device'],
            conf=predict_config['conf_thres'],
            iou=predict_config['iou_thres'],
            imgsz=predict_config['imgsz'],
            line_width=predict_config['line_width'],
            show_labels=True,
            show_conf=True,
            classes=[2],  # detect only class 2 ("car")
        )

        # Keep only the largest-area car box of each image.
        output_txt = os.path.join(path_config['output_folder'], 'car_boxes.txt')
        # utf-8 so non-ASCII image names are written correctly on any platform locale.
        with open(output_txt, 'w', encoding='utf-8') as f:
            for result in results:
                boxes = result.boxes.xyxy.cpu().numpy()  # [N, 4] rows of (x1, y1, x2, y2)
                if len(boxes) == 0:
                    continue  # no car detected in this image
                areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
                # argmax returns the first index of the maximum (same tie-break as list.index(max())).
                x1, y1, x2, y2 = boxes[int(areas.argmax())]
                img_name = os.path.basename(result.path)
                # Top-left then bottom-right corner, truncated to int.
                f.write(f"{img_name} {int(x1)} {int(y1)} {int(x2)} {int(y2)}\n")
        print(f"已保存car框坐标到: {output_txt}")

        # Summary.
        success_count = len(results)
        save_dir = Path(results[0].save_dir) if success_count > 0 else None
        print(f"\n🔍 推理完成 | 处理图片: {success_count}")
        print(f"📁 结果目录: {save_dir.resolve() if save_dir else ''}")

        # Optionally display the first result; on a headless machine this can
        # fail, which must not be reported as a failed run.
        if success_count > 0:
            try:
                results[0].show()
            except Exception as show_err:
                print(f"⚠️ 无法显示结果图片: {show_err}")
    except Exception as e:
        # Top-level script boundary: report everything, including the traceback.
        print(f"\n❌ 发生错误: {type(e).__name__}: {e}")
        traceback.print_exc()
        print("问题排查建议:")
        print("1. 检查模型文件路径是否正确")
        print("2. 确认图片目录包含支持的格式jpg/png等")
        print("3. 查看CUDA是否可用如需GPU加速")
        print("4. 确保输出目录有写入权限")
if __name__ == "__main__":
detect_images()