97 lines
3.1 KiB
Python
97 lines
3.1 KiB
Python
import os
|
|
import torch
|
|
import numpy as np
|
|
from PIL import Image
|
|
import albumentations as A
|
|
from albumentations.pytorch import ToTensorV2
|
|
from torch.nn import functional as F
|
|
from tqdm import tqdm
|
|
import shutil
|
|
|
|
from torch.serialization import add_safe_globals
|
|
from efficientnet_pytorch.model import EfficientNet as EfficientNetClass
|
|
# Allow-list the EfficientNet class for torch.load's restricted unpickler.
# NOTE(review): the allow-list is only consulted when torch.load runs with
# weights_only=True; the model below is loaded with weights_only=False, so
# this call currently has no effect on loading.
add_safe_globals([EfficientNetClass])

# ---------- Configuration ----------
model_path = "model_save/best_model.pth"  # checkpoint passed to torch.load
image_root = "test_images"                # source images (sub-directories are scanned too)
output_root = "classified_images"         # output folder, one sub-folder per label
HW = 260                                  # images are resized to HW x HW before inference

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Class index -> label name.
# NOTE(review): the order presumably matches the training label encoding; confirm.
idx_to_label = dict(enumerate([
    "front",
    "rear",
    "side",
    "front-left",
    "front-right",
    "rear-left",
    "rear-right",
    "others",
]))
num_classes = len(idx_to_label)

# Minimum top-1 probability for an "others" prediction to be accepted;
# below it, inference falls back to the second most probable class.
τ_others = 0.75

# ---------- Image preprocessing ----------
# Per-channel normalization statistics (the commonly used ImageNet values).
_NORM_MEAN = (0.485, 0.456, 0.406)
_NORM_STD = (0.229, 0.224, 0.225)

# Resize to HW x HW, normalize, then convert the HWC numpy image to a torch tensor.
val_transform = A.Compose(
    [
        A.Resize(height=HW, width=HW),
        A.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
        ToTensorV2(),
    ]
)

# ---------- Load model ----------
print("🔍 Loading model...")
# SECURITY: weights_only=False makes torch.load unpickle arbitrary Python
# objects, which can execute code. Only point model_path at trusted files.
# The loaded object is used directly as a module (.to()/.eval()), so the
# checkpoint is expected to hold a whole pickled model, not a state_dict.
model = torch.load(model_path, map_location=device, weights_only=False)
model.to(device)
model.eval()  # switch layers such as dropout / batch-norm to inference behaviour

# ---------- Collect image paths ----------
_IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp")


def collect_image_paths(root_dir, extensions=_IMAGE_EXTENSIONS):
    """Recursively collect image files under ``root_dir``.

    Args:
        root_dir: Directory to scan; sub-directories are included.
        extensions: Accepted file suffixes (a tuple of strings or a single
            string), matched case-insensitively. Defaults to
            .jpg / .jpeg / .png / .bmp.

    Returns:
        A sorted list of paths (each joined onto ``root_dir``). Sorting makes
        the processing order, and therefore the output, reproducible:
        ``os.walk`` yields entries in file-system-dependent order. A missing
        ``root_dir`` yields an empty list, because ``os.walk`` ignores errors
        by default.
    """
    if isinstance(extensions, str):
        # A bare string would otherwise be iterated character by character.
        extensions = (extensions,)
    suffixes = tuple(ext.lower() for ext in extensions)

    image_paths = []
    for dirpath, _, filenames in os.walk(root_dir):
        image_paths.extend(
            os.path.join(dirpath, name)
            for name in filenames
            if name.lower().endswith(suffixes)
        )
    return sorted(image_paths)

# ---------- Inference ----------
def _select_class(probs, others_idx, threshold):
    """Return the predicted class index for one softmax probability vector.

    Normally this is the arg-max. When the arg-max is the "others" class
    (``others_idx``) and its probability is below ``threshold``, the
    runner-up class is returned instead, so a low-confidence "others"
    prediction falls back to the best concrete class.
    """
    ranked = np.argsort(probs)[::-1]  # class indices, most probable first
    best = int(ranked[0])
    if best == others_idx and probs[best] < threshold and len(ranked) > 1:
        return int(ranked[1])
    return best


def _unique_destination(dst_dir, filename, taken):
    """Return a path in ``dst_dir`` for ``filename`` not yet used in this run.

    Input images can live in different sub-folders yet share a basename;
    copying them into one flat label folder would silently overwrite the
    earlier copy. On a clash, ``_1``, ``_2``, ... is inserted before the
    extension. ``taken`` (a set of normalized paths) is updated in place.
    Files left over from a previous run are still overwritten, as before.
    """
    stem, ext = os.path.splitext(filename)
    candidate = os.path.join(dst_dir, filename)
    suffix = 1
    while os.path.normcase(candidate) in taken:
        candidate = os.path.join(dst_dir, f"{stem}_{suffix}{ext}")
        suffix += 1
    taken.add(os.path.normcase(candidate))
    return candidate


def infer_and_save(image_paths):
    """Classify every image and copy it into ``output_root/<predicted label>/``.

    Uses the module-level ``model``, ``val_transform``, ``device``,
    ``idx_to_label`` and ``τ_others``. One output sub-folder per label is
    created up front. Any image that fails (unreadable file, model error,
    copy error, ...) is reported and skipped, so one bad file does not abort
    the batch.

    Args:
        image_paths: Iterable of image file paths to classify.
    """
    os.makedirs(output_root, exist_ok=True)
    for label in idx_to_label.values():
        os.makedirs(os.path.join(output_root, label), exist_ok=True)

    # Look up the "others" index by name instead of hard-coding 7, so the
    # fallback rule keeps working if the label map is reordered. If there is
    # no such label, the fallback is simply disabled.
    others_idx = next(
        (idx for idx, name in idx_to_label.items() if name == "others"), None
    )
    taken = set()  # destination paths already handed out in this run

    for img_path in tqdm(image_paths, desc="推理中"):
        try:
            # The context manager closes the underlying file promptly.
            with Image.open(img_path) as img:
                rgb = np.array(img.convert("RGB"))
            batch = val_transform(image=rgb)["image"].unsqueeze(0).to(device)

            with torch.no_grad():
                logits = model(batch)
            probs = F.softmax(logits, dim=1).cpu().numpy()[0]

            pred_label = idx_to_label[_select_class(probs, others_idx, τ_others)]

            # Copy the image into the folder of its predicted label.
            dst_dir = os.path.join(output_root, pred_label)
            dst_path = _unique_destination(dst_dir, os.path.basename(img_path), taken)
            shutil.copy(img_path, dst_path)
        except Exception as e:
            # Deliberate best-effort: report and continue. tqdm.write prints
            # without corrupting the progress bar, unlike a plain print.
            tqdm.write(f"❌ Error processing {img_path}: {e}")

# ---------- Entry point ----------
def main():
    """Collect images under ``image_root``, classify them, and copy them by label."""
    image_paths = collect_image_paths(image_root)
    print(f"✅ Found {len(image_paths)} images.")
    infer_and_save(image_paths)
    print("🎉 分类完成 ✅")


if __name__ == "__main__":
    main()