# File-viewer header (kept as a comment): Python source, 107 lines, 4.1 KiB.
import os
from pathlib import Path

import torch
from ultralytics import YOLO
def _ensure_model_file(model_path, download_url):
    """Make sure the weights file exists locally, downloading it when missing.

    Args:
        model_path: Local path where the weights file is expected.
        download_url: URL to fetch the weights from when the file is absent;
            a falsy value disables downloading.

    Raises:
        FileNotFoundError: If the file is missing and no download URL is set.
    """
    if Path(model_path).exists():
        return
    if not download_url:
        raise FileNotFoundError(f"模型文件不存在: {model_path}")
    print("开始下载模型...")
    # Fetch the weights file directly to the configured local path.
    torch.hub.download_url_to_file(download_url, str(model_path))


def _largest_box(boxes):
    """Return the largest-area row of an (N, 4) xyxy array, or None if empty."""
    if len(boxes) == 0:
        return None
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    # argmax returns the first maximum, same tie-break as list.index(max(...)).
    return boxes[int(areas.argmax())]


def _write_largest_car_boxes(results, output_txt):
    """Write one "<image> x1 y1 x2 y2" line per image that has a detection.

    Only the largest box of each image is kept; images without any box are
    skipped. Coordinates are truncated to integers.
    """
    # Explicit UTF-8 so non-ASCII image names are written consistently
    # regardless of the platform's default encoding.
    with open(output_txt, 'w', encoding='utf-8') as f:
        for result in results:
            box = _largest_box(result.boxes.xyxy.cpu().numpy())
            if box is None:
                continue
            x1, y1, x2, y2 = (int(v) for v in box)
            img_name = os.path.basename(result.path)
            f.write(f"{img_name} {x1} {y1} {x2} {y2}\n")


def detect_images():
    """Run YOLO car detection on a folder and save the largest box per image.

    Steps: validate the input folder, optionally create the output folder,
    make sure the model weights exist (downloading them if configured), run
    prediction restricted to class id 2, write the largest box of every image
    to ``<output_folder>/car_boxes.txt``, print a summary, and show the first
    annotated result.

    All settings live in the configuration dictionaries at the top of the
    function. Any exception is caught, reported on stdout together with
    troubleshooting hints, and not re-raised.

    Returns:
        None.
    """
    # ======================= Configuration =======================
    # Model settings
    model_config = {
        'model_path': 'yolo11x.pt',  # local weights path
        # Used to fetch the weights when the local file is missing.
        'download_url': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo11x.pt',
    }

    # Path settings
    path_config = {
        'input_folder': 'input',
        'output_folder': 'output',
        'auto_create_dir': True,  # create the output folder automatically
    }

    # Inference parameters
    predict_config = {
        'conf_thres': 0.25,  # confidence threshold
        'iou_thres': 0.45,   # IoU threshold
        'imgsz': 640,        # input resolution
        'line_width': 2,     # bounding-box line width
        # Pick the GPU when available, otherwise fall back to the CPU.
        'device': 'cuda:0' if torch.cuda.is_available() else 'cpu',
    }
    # ==================== End of configuration ====================

    try:
        # Validate the input folder.
        if not Path(path_config['input_folder']).exists():
            raise FileNotFoundError(f"输入目录不存在: {path_config['input_folder']}")

        # Create the output folder when requested.
        if path_config['auto_create_dir']:
            Path(path_config['output_folder']).mkdir(parents=True, exist_ok=True)

        # Make sure the weights are present before loading them.
        _ensure_model_file(model_config['model_path'], model_config['download_url'])

        # Initialise the model on the selected device.
        model = YOLO(model_config['model_path']).to(predict_config['device'])
        print(f"✅ 模型加载成功 | 设备: {predict_config['device'].upper()}")

        # Run inference.
        results = model.predict(
            source=path_config['input_folder'],
            project=path_config['output_folder'],
            name="exp",
            save=True,
            conf=predict_config['conf_thres'],
            iou=predict_config['iou_thres'],
            imgsz=predict_config['imgsz'],
            line_width=predict_config['line_width'],
            show_labels=True,
            show_conf=True,
            classes=[2],  # detect only the car class
        )

        # Keep only the largest car box of each image.
        output_txt = os.path.join(path_config['output_folder'], 'car_boxes.txt')
        _write_largest_car_boxes(results, output_txt)
        print(f"已保存car框坐标到: {output_txt}")

        # Summary.
        success_count = len(results)
        save_dir = Path(results[0].save_dir) if success_count > 0 else None
        print(f"\n🔍 推理完成 | 处理图片: {success_count} 张")
        print(f"📁 结果目录: {save_dir.resolve() if save_dir else '无'}")

        # Show the first result (optional).
        if success_count > 0:
            results[0].show()

    except Exception as e:
        # Top-level boundary of the script: report and give hints, don't crash.
        print(f"\n❌ 发生错误: {str(e)}")
        print("问题排查建议:")
        print("1. 检查模型文件路径是否正确")
        print("2. 确认图片目录包含支持的格式(jpg/png等)")
        print("3. 查看CUDA是否可用(如需GPU加速)")
        print("4. 确保输出目录有写入权限")
# Run the detection pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    detect_images()