"""Classify images with a trained EfficientNet and copy them into per-label folders.

Walks ``image_root`` recursively, runs the model stored at ``model_path`` on each
image, and copies every image to ``output_root/<predicted label>/``. Images found
in sub-directories keep their path relative to ``image_root`` under the label
folder, so equally-named files from different sub-directories do not overwrite
each other.
"""
import os
import shutil

import albumentations as A
import numpy as np
import torch
from albumentations.pytorch import ToTensorV2
from efficientnet_pytorch.model import EfficientNet as EfficientNetClass
from PIL import Image
from torch.nn import functional as F
from torch.serialization import add_safe_globals
from tqdm import tqdm

# Allow-lists the EfficientNet class for torch.load's restricted unpickler.
# NOTE(review): this only matters when torch.load runs with weights_only=True;
# the load below passes weights_only=False, so this call has no effect there.
add_safe_globals([EfficientNetClass])

model_path = "model_save/best_model.pth"
image_root = "test_images"  # Root folder of the source images (sub-directories supported)
output_root = "classified_images"  # Folder that receives the classified copies
HW = 260  # Images are resized to HW x HW before inference
device = "cuda" if torch.cuda.is_available() else "cpu"

idx_to_label = {
    0: "front",
    1: "rear",
    2: "side",
    3: "front-left",
    4: "front-right",
    5: "rear-left",
    6: "rear-right",
    7: "others",
}
num_classes = len(idx_to_label)

# A top-1 "others" prediction whose softmax probability is below this threshold
# is replaced by the runner-up class.
τ_others = 0.75
# Index of the "others" class, derived from the label map instead of hard-coded.
others_idx = next(idx for idx, label in idx_to_label.items() if label == "others")

# ---------- Image preprocessing ----------
val_transform = A.Compose([
    A.Resize(HW, HW),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2(),
])

# ---------- Load model ----------
print("🔍 Loading model...")
# SECURITY: weights_only=False unpickles arbitrary Python objects from the file;
# only load checkpoints that come from a trusted source.
model = torch.load(model_path, map_location=device, weights_only=False)
model.eval().to(device)


# ---------- Collect image paths ----------
def collect_image_paths(root_dir):
    """Return the paths of all .jpg/.jpeg/.png/.bmp files under *root_dir* (recursive)."""
    image_paths = []
    for dirpath, _, filenames in os.walk(root_dir):
        for filename in filenames:
            if filename.lower().endswith((".jpg", ".jpeg", ".png", ".bmp")):
                image_paths.append(os.path.join(dirpath, filename))
    return image_paths


# ---------- Inference ----------
def _predict_label(img_path):
    """Run the model on the image at *img_path* and return the predicted label string."""
    # Context manager closes the underlying file handle once pixels are read.
    with Image.open(img_path) as img:
        image = np.array(img.convert("RGB"))
    tensor = val_transform(image=image)["image"].unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(tensor)
    probs = F.softmax(logits, dim=1).cpu().numpy()[0]

    sorted_indices = np.argsort(probs)[::-1]  # class indices, most probable first
    top1_idx = int(sorted_indices[0])
    # Low-confidence "others" -> fall back to the second most probable class.
    if top1_idx == others_idx and probs[top1_idx] < τ_others and len(sorted_indices) > 1:
        return idx_to_label[int(sorted_indices[1])]
    return idx_to_label[top1_idx]


def _destination_path(img_path, label):
    """Return the copy destination for *img_path* inside the *label* output folder.

    The path relative to ``image_root`` is preserved so that files sharing a
    basename in different sub-directories do not overwrite each other. Paths that
    are not under ``image_root`` fall back to the bare filename.
    """
    try:
        rel_path = os.path.relpath(img_path, image_root)
    except ValueError:  # e.g. a different drive on Windows
        rel_path = os.path.basename(img_path)
    escapes_root = rel_path == os.pardir or rel_path.startswith(os.pardir + os.sep)
    if escapes_root or os.path.isabs(rel_path):
        rel_path = os.path.basename(img_path)
    return os.path.join(output_root, label, rel_path)


def infer_and_save(image_paths):
    """Classify every image in *image_paths* and copy it into its predicted label folder.

    Failures on individual images are printed and skipped so one bad file does
    not abort the whole run.
    """
    os.makedirs(output_root, exist_ok=True)
    for label in idx_to_label.values():
        os.makedirs(os.path.join(output_root, label), exist_ok=True)

    for img_path in tqdm(image_paths, desc="推理中"):
        try:
            pred_label = _predict_label(img_path)
            # Copy the image into the folder of its predicted class.
            dst_path = _destination_path(img_path, pred_label)
            os.makedirs(os.path.dirname(dst_path), exist_ok=True)
            shutil.copy(img_path, dst_path)
        except Exception as e:  # best-effort per image: report and continue
            print(f"❌ Error processing {img_path}: {e}")


# ---------- Entry point ----------
if __name__ == "__main__":
    image_paths = collect_image_paths(image_root)
    print(f"✅ Found {len(image_paths)} images.")
    infer_and_save(image_paths)
    print("🎉 分类完成 ✅")