# car_viewpoint/inference.py
#
# Batch inference: classify car images by viewpoint and copy each image into
# a per-label output folder.

import os
import torch
import numpy as np
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.nn import functional as F
from tqdm import tqdm
import shutil
from torch.serialization import add_safe_globals
from efficientnet_pytorch.model import EfficientNet as EfficientNetClass
# Allow-list the EfficientNet class for torch's restricted ("weights_only") unpickler.
# NOTE(review): this only takes effect when torch.load runs with weights_only=True;
# the model checkpoint in this file is loaded with weights_only=False.
add_safe_globals([EfficientNetClass])

# ---------- Configuration ----------
model_path = "model_save/best_model.pth"  # full pickled model, passed to torch.load
image_root = "test_images"  # source images (sub-directories are scanned too)
output_root = "classified_images"  # one sub-folder per predicted label is created here
HW = 260  # images are resized to HW x HW before inference

device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Class index -> viewpoint label. The order is assumed to match the model's
# output layer (TODO confirm against the training label mapping).
idx_to_label = dict(
    enumerate(
        (
            "front",
            "rear",
            "side",
            "front-left",
            "front-right",
            "rear-left",
            "rear-right",
            "others",
        )
    )
)
num_classes = len(idx_to_label)

# Minimum softmax probability required to accept an "others" prediction;
# below it, the runner-up class is used instead.
τ_others = 0.75
# ---------- Image preprocessing ----------
# Resize to HW x HW, normalize with the standard ImageNet channel mean/std
# (presumably the same statistics used during training — TODO confirm),
# then convert the HWC numpy array to a CHW torch tensor.
val_transform = A.Compose(
    [
        A.Resize(height=HW, width=HW),
        A.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
        ),
        ToTensorV2(),
    ]
)
# ---------- Model loading ----------
print("🔍 Loading model...")
# NOTE(security): weights_only=False lets torch.load unpickle arbitrary Python
# objects, which can execute code. Only load checkpoints from a trusted source.
model = torch.load(model_path, map_location=device, weights_only=False)
model.to(device)
model.eval()  # inference mode for dropout / batch-norm layers
# ---------- Collect all image paths ----------
def collect_image_paths(root_dir, *, extensions=(".jpg", ".jpeg", ".png", ".bmp")):
    """Recursively collect image file paths under ``root_dir``.

    Args:
        root_dir: Directory to scan. Sub-directories are included.
        extensions: File suffix (or tuple of suffixes) treated as images.
            Matching is case-insensitive.

    Returns:
        A list of paths joined onto ``root_dir``, in a deterministic order.
        Directories are visited in sorted order, and the files in each
        directory are sorted by name. A missing ``root_dir`` yields ``[]``,
        because ``os.walk`` ignores the error by default.
    """
    if isinstance(extensions, str):
        extensions = (extensions,)
    suffixes = tuple(ext.lower() for ext in extensions)

    image_paths = []
    for dirpath, dirnames, filenames in os.walk(root_dir):
        # Sorting dirnames in place also fixes the order os.walk descends into
        # sub-directories. Without this, traversal order is filesystem-dependent.
        dirnames.sort()
        image_paths.extend(
            os.path.join(dirpath, name)
            for name in sorted(filenames)
            if name.lower().endswith(suffixes)
        )
    return image_paths
# ---------- Inference ----------
def _unique_dst_path(dst_dir, filename, used_paths):
    """Return a path in ``dst_dir`` for ``filename`` not already claimed this run.

    ``filename`` is used unchanged when it is free. Otherwise ``_1``, ``_2``, ...
    is inserted before the extension. The chosen path is added to ``used_paths``.
    """
    stem, ext = os.path.splitext(filename)
    candidate = os.path.join(dst_dir, filename)
    suffix = 1
    while candidate in used_paths:
        candidate = os.path.join(dst_dir, f"{stem}_{suffix}{ext}")
        suffix += 1
    used_paths.add(candidate)
    return candidate


def infer_and_save(image_paths):
    """Classify each image by viewpoint and copy it into a per-label folder.

    Each image is loaded as RGB and run through ``val_transform`` and
    ``model``. The softmax probabilities then select a label from
    ``idx_to_label``. If the top-1 class is "others" but its probability is
    below ``τ_others``, the runner-up class is used instead. The file is then
    copied (not moved) into ``output_root/<label>/``.

    Images gathered from different sub-directories can share a base name.
    To stop one silently overwriting another within the same run, later
    duplicates get a ``_1``, ``_2``, ... suffix before the extension.

    Errors on individual images are printed and skipped, so one bad file does
    not abort the whole batch.

    Args:
        image_paths: Iterable of image file paths to classify.
    """
    os.makedirs(output_root, exist_ok=True)
    for label in idx_to_label.values():
        os.makedirs(os.path.join(output_root, label), exist_ok=True)

    # Look up the "others" class by name instead of hard-coding its index.
    others_idx = next(
        (idx for idx, name in idx_to_label.items() if name == "others"), None
    )
    used_dst_paths = set()

    for img_path in tqdm(image_paths, desc="推理中"):
        try:
            # The context manager closes the underlying file handle promptly.
            with Image.open(img_path) as pil_img:
                image = np.array(pil_img.convert("RGB"))
            tensor = val_transform(image=image)["image"].unsqueeze(0).to(device)
            with torch.no_grad():
                logits = model(tensor)
            probs = F.softmax(logits, dim=1).cpu().numpy()[0]

            sorted_indices = np.argsort(probs)[::-1]
            top1_idx = int(sorted_indices[0])
            if (
                top1_idx == others_idx
                and probs[top1_idx] < τ_others
                and len(sorted_indices) > 1
            ):
                # Low-confidence "others": fall back to the runner-up class.
                pred_idx = int(sorted_indices[1])
            else:
                pred_idx = top1_idx
            pred_label = idx_to_label[pred_idx]

            # Copy the image into its label folder without clobbering earlier copies.
            dst_path = _unique_dst_path(
                os.path.join(output_root, pred_label),
                os.path.basename(img_path),
                used_dst_paths,
            )
            shutil.copy(img_path, dst_path)
        except Exception as e:
            print(f"❌ Error processing {img_path}: {e}")
# ---------- Entry point ----------
if __name__ == "__main__":
    # Scan the input tree, then classify and copy every image found.
    found_paths = collect_image_paths(image_root)
    print(f"✅ Found {len(found_paths)} images.")
    infer_and_save(found_paths)
    print("🎉 分类完成 ✅")