# u2net_saliency_only.py
"""Generate saliency maps for images with a U2NETP network.

The script loads a U2NETP checkpoint once (cached at module level), runs each
input image through the network and writes the result as an 8-bit grayscale
image.
"""
import os

import cv2
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image

from model.u2net import U2NETP

# Cached model, to avoid reloading the checkpoint on every call.
_model = None
# Checkpoint path the cached model was loaded from.
_model_path = None

# Input pipeline: resize to 320x320, convert to a tensor, then normalize with
# the mean/std constants below. Built once because it never changes.
_TRANSFORM = transforms.Compose([
    transforms.Resize((320, 320)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])


def _select_device():
    """Return the CUDA device when one is available, otherwise the CPU."""
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def load_u2net_model(model_path="./saved_models/u2netp.pth"):
    """Return a U2NETP model in eval mode, loading it on first use.

    The model is cached at module level. It is reloaded only when a different
    ``model_path`` is requested.

    Args:
        model_path: Path to the checkpoint holding the model's state dict.

    Returns:
        The cached ``U2NETP`` instance, placed on the GPU when CUDA is
        available and on the CPU otherwise.

    Raises:
        Whatever ``torch.load`` / ``load_state_dict`` raise for a missing or
        incompatible checkpoint. The cache is left untouched in that case.
    """
    global _model, _model_path
    if _model is None or _model_path != model_path:
        device = _select_device()
        net = U2NETP(3, 1)
        # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
        net.load_state_dict(torch.load(model_path, map_location=device))
        net.to(device)
        net.eval()
        # Publish to the cache only after a successful load, so a failed load
        # never leaves an uninitialised network behind for later callers.
        _model, _model_path = net, model_path
    return _model


def preprocess(image):
    """Convert a PIL image into a normalized ``(1, C, 320, 320)`` tensor.

    Args:
        image: PIL image; callers in this file pass an RGB image.

    Returns:
        A CPU tensor with a leading batch dimension of size 1.
    """
    return _TRANSFORM(image).unsqueeze(0)


def postprocess(output, original_size):
    """Turn a network output tensor into an 8-bit map of the original size.

    Args:
        output: Tensor holding a single map; singleton dimensions are
            squeezed away. Values are assumed to lie in [0, 1] (TODO confirm
            against the U2NETP implementation).
        original_size: ``(width, height)`` target size, as given by
            ``PIL.Image.size`` and expected by ``cv2.resize``.

    Returns:
        ``numpy.ndarray`` of dtype ``uint8`` resized to ``original_size``.
    """
    saliency = output.squeeze().cpu().detach().numpy()
    # Clip first: a value slightly outside [0, 1] would otherwise wrap around
    # when cast to uint8.
    saliency = (np.clip(saliency, 0.0, 1.0) * 255).astype(np.uint8)
    return cv2.resize(saliency, original_size)


def generate_saliency_map(input_image_path, output_image_path,
                          model_path="./saved_models/u2netp.pth"):
    """Compute the saliency map of one image and write it to disk.

    Args:
        input_image_path: Path of the image to analyse.
        output_image_path: Destination path; its parent directory is created
            when missing.
        model_path: Checkpoint passed to ``load_u2net_model``.

    Raises:
        IOError: If OpenCV fails to write the output image.
    """
    net = load_u2net_model(model_path)
    # Run the input on whichever device the model was placed on.
    device = next(net.parameters()).device

    # The context manager closes the underlying file; convert() returns an
    # independent, fully loaded copy.
    with Image.open(input_image_path) as source:
        original_image = source.convert("RGB")
    original_size = original_image.size  # (width, height)

    input_tensor = preprocess(original_image).to(device)
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        # The network returns several outputs; the first one is used here
        # (presumably the fused prediction - verify against model.u2net).
        output = net(input_tensor)[0]
    saliency_map = postprocess(output, original_size)

    output_dir = os.path.dirname(output_image_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    # cv2.imwrite reports failure through its return value, not an exception.
    if not cv2.imwrite(output_image_path, saliency_map):
        raise IOError(f"Failed to write saliency map to: {output_image_path}")
    print(f"显著图已保存至: {output_image_path}")


if __name__ == '__main__':
    triplets = [
        # (tag, source image path, saliency / mask path)
        ('front', './image/front.jpg', './saliency/front.png'),  # front view
        ('rear', './image/rear.jpg', './saliency/rear.png'),     # rear view
        ('side', './image/side.jpg', './saliency/side.png'),     # side view (used for circle detection)
    ]
    thresh_dir = './thresh'
    os.makedirs(thresh_dir, exist_ok=True)

    # ======== Generate saliency maps (can be commented out when they are
    # produced by u2net_saliency instead) ========
    for tag, image_path, saliency_path in triplets:
        print(f"处理 {tag} 图像中...")
        generate_saliency_map(image_path, saliency_path)