107 lines
4.1 KiB
Python
107 lines
4.1 KiB
Python
|
from ultralytics import YOLO
|
|||
|
import os
|
|||
|
import torch
|
|||
|
from pathlib import Path
|
|||
|
|
|||
|
def detect_images():
    """Detect cars in every image of a folder and log the largest car box per image.

    All settings live in the three config dicts at the top of this function:
    the model weights (fetched from ``download_url`` when the local file is
    missing), the input/output folders, and the inference parameters.

    Side effects:
        * Annotated images are saved by Ultralytics in a run directory built
          from ``project=<output_folder>`` and ``name="exp"``; the actual
          directory is printed at the end.
        * ``<output_folder>/car_boxes.txt`` is (over)written with UTF-8 lines of
          the form ``<image_name> x1 y1 x2 y2``. (x1, y1) is the top-left and
          (x2, y2) the bottom-right corner of the largest-area car box, each
          truncated with ``int()``. Images without any car get no line.
        * The first result is displayed with ``Results.show()``.

    Any exception is caught. It is reported together with its traceback and
    some troubleshooting hints, and the function then returns ``None``.
    """
    # ======================= Configuration =======================

    # Model configuration
    model_config = {
        'model_path': 'yolo11x.pt',  # local path of the weights file
        # URL used to fetch the weights when model_path does not exist
        # (set to an empty string to disable downloading)
        'download_url': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo11x.pt',
    }

    # Path configuration
    path_config = {
        'input_folder': 'input',
        'output_folder': 'output',
        'auto_create_dir': True,  # create the output directory automatically
    }

    # Inference parameters
    predict_config = {
        'conf_thres': 0.25,  # confidence threshold
        'iou_thres': 0.45,  # IoU threshold
        'imgsz': 640,  # inference image size
        'line_width': 2,  # box line width in the saved images
        'device': 'cuda:0' if torch.cuda.is_available() else 'cpu',  # prefer the GPU when available
    }

    # ===================== End of configuration =====================

    try:
        # Validate the input directory.
        if not Path(path_config['input_folder']).exists():
            raise FileNotFoundError(f"输入目录不存在: {path_config['input_folder']}")

        # Create the output directory if requested.
        if path_config['auto_create_dir']:
            Path(path_config['output_folder']).mkdir(parents=True, exist_ok=True)

        # Make sure the weights file exists, downloading it when a URL is configured.
        model_path = Path(model_config['model_path'])
        if not model_path.exists():
            if model_config['download_url']:
                print("开始下载模型...")
                model_path.parent.mkdir(parents=True, exist_ok=True)
                # The YOLO model object has no .download() method, so fetch
                # the weights file directly.
                torch.hub.download_url_to_file(
                    model_config['download_url'], str(model_path), progress=True
                )
            else:
                raise FileNotFoundError(f"模型文件不存在: {model_config['model_path']}")

        # Load the model. The configured device is passed explicitly to
        # predict() below, so the config value is what inference runs on.
        model = YOLO(model_config['model_path'])
        print(f"✅ 模型加载成功 | 设备: {predict_config['device'].upper()}")

        # Run inference. stream=True yields results one at a time, so the
        # results (and their images) for the whole folder do not pile up in RAM.
        results = model.predict(
            source=path_config['input_folder'],
            project=path_config['output_folder'],
            name="exp",
            save=True,
            conf=predict_config['conf_thres'],
            iou=predict_config['iou_thres'],
            imgsz=predict_config['imgsz'],
            line_width=predict_config['line_width'],
            show_labels=True,
            show_conf=True,
            device=predict_config['device'],
            classes=[2],  # detect only the car class (index 2)
            stream=True,
        )

        # For each image, keep only the largest-area car box.
        success_count = 0
        first_result = None  # kept for the final summary and preview
        lines = []
        for result in results:
            success_count += 1
            if first_result is None:
                first_result = result
            if result.boxes is None or len(result.boxes) == 0:
                continue  # no car detected: nothing is written for this image
            boxes = result.boxes.xyxy.cpu().numpy()  # [N, 4] as x1, y1, x2, y2
            areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
            x1, y1, x2, y2 = boxes[int(areas.argmax())]
            img_name = Path(result.path).name
            # top-left corner, then bottom-right corner
            lines.append(f"{img_name} {int(x1)} {int(y1)} {int(x2)} {int(y2)}\n")

        # Write only after inference has finished, so a failure midway does
        # not leave a partial file behind. UTF-8 keeps non-ASCII names intact.
        output_txt = Path(path_config['output_folder']) / 'car_boxes.txt'
        with open(output_txt, 'w', encoding='utf-8') as f:
            f.writelines(lines)
        print(f"已保存car框坐标到: {output_txt}")

        # Summary
        save_dir = Path(first_result.save_dir) if first_result is not None else None
        print(f"\n🔍 推理完成 | 处理图片: {success_count} 张")
        print(f"📁 结果目录: {save_dir.resolve() if save_dir else '无'}")

        # Display the first result (optional).
        if first_result is not None:
            first_result.show()

    except Exception as e:
        # Top-level boundary of the script: report the error instead of crashing.
        import traceback

        print(f"\n❌ 发生错误: {str(e)}")
        traceback.print_exc()  # keep the full stack trace for debugging
        print("问题排查建议:")
        print("1. 检查模型文件路径是否正确")
        print("2. 确认图片目录包含支持的格式(jpg/png等)")
        print("3. 查看CUDA是否可用(如需GPU加速)")
        print("4. 确保输出目录有写入权限")
|
|||
|
|
|||
|
# Script entry point: run the detection pipeline only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    detect_images()
|