68 lines
2.1 KiB
Python
68 lines
2.1 KiB
Python
# u2net_saliency_only.py
|
||
import torch
|
||
import cv2
|
||
import torchvision.transforms as transforms
|
||
import numpy as np
|
||
from PIL import Image
|
||
from model.u2net import U2NETP
|
||
import os
|
||
# Cache the loaded model so it is not reloaded on every call
|
||
_model = None
|
||
|
||
|
||
def load_u2net_model(model_path="./saved_models/u2netp.pth"):
    """Return the process-wide U2NETP model, loading it on first use.

    The network is built as ``U2NETP(3, 1)``, populated from the checkpoint
    at ``model_path``, moved to CUDA when available (CPU otherwise) and put
    into eval mode. The result is cached in the module-level ``_model``.

    Args:
        model_path: Path to the checkpoint holding the state dict. Only
            consulted on the first successful call; later calls return the
            cached model regardless of the path given.

    Returns:
        The cached model instance, in eval mode.

    Raises:
        Whatever ``torch.load`` / ``load_state_dict`` raise for a missing or
        incompatible checkpoint. Nothing is cached in that case, so a later
        call retries the load.
    """
    global _model
    if _model is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        net = U2NETP(3, 1)
        # map_location lets a checkpoint saved from a GPU load on CPU-only hosts.
        net.load_state_dict(torch.load(model_path, map_location=device))
        net.to(device)
        net.eval()
        # Publish to the cache only once fully initialised; assigning earlier
        # would leave an unloaded network cached if the checkpoint load failed.
        _model = net
    return _model
|
||
|
||
|
||
def preprocess(image):
    """Turn an image into a single-item batch tensor for the network.

    The image is resized to 320x320, converted to a tensor, and normalised
    per channel with the mean/std constants below; a leading batch dimension
    is then added with ``unsqueeze(0)``.

    Args:
        image: Image accepted by the torchvision transforms used here (the
            caller in this file passes a PIL image converted to RGB).

    Returns:
        The transformed tensor with one extra leading dimension.
    """
    steps = [
        transforms.Resize((320, 320)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
    pipeline = transforms.Compose(steps)
    tensor = pipeline(image)
    return tensor.unsqueeze(0)
|
||
|
||
|
||
def postprocess(output, original_size):
    """Convert a network output tensor into an 8-bit image of the given size.

    Args:
        output: Tensor produced by the model; singleton dimensions are
            squeezed away. Values are scaled by 255, so they are presumably
            expected in [0, 1] — confirm against the model's output layer.
        original_size: Target size passed straight to ``cv2.resize`` as its
            ``dsize`` argument, i.e. ``(width, height)``.

    Returns:
        A ``numpy.uint8`` array resized to ``original_size``.
    """
    saliency = output.detach().squeeze().cpu().numpy()
    # Clamp before the uint8 cast so out-of-range values saturate at 0/255
    # instead of wrapping around (e.g. 256 -> 0).
    saliency = np.clip(saliency * 255, 0, 255).astype(np.uint8)
    return cv2.resize(saliency, original_size)
|
||
|
||
|
||
def generate_saliency_map(input_image_path, output_image_path, model_path="./saved_models/u2netp.pth"):
    """Run the cached U2NETP model on one image and write the result to disk.

    Args:
        input_image_path: Path of the image to read; it is converted to RGB.
        output_image_path: Path the saliency map is written to with
            ``cv2.imwrite``. Its parent directory is created if missing.
        model_path: Checkpoint path forwarded to ``load_u2net_model``.

    Raises:
        OSError: If ``cv2.imwrite`` reports that the file was not written.
    """
    net = load_u2net_model(model_path)
    # Feed the input on whatever device the model's weights live on, rather
    # than assuming CUDA.
    device = next(net.parameters()).device

    # Use a context manager so the underlying file handle is closed;
    # convert() returns an independent in-memory image.
    with Image.open(input_image_path) as source:
        original_image = source.convert("RGB")
    # PIL reports size as (width, height), the order cv2.resize expects.
    original_size = original_image.size
    input_tensor = preprocess(original_image).to(device)

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        # First element of the model's outputs — presumably the fused
        # prediction; confirm against model.u2net.
        output = net(input_tensor)[0]
    saliency_map = postprocess(output, original_size)

    out_dir = os.path.dirname(output_image_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    # cv2.imwrite signals failure by returning False instead of raising, so
    # check it before claiming success.
    if not cv2.imwrite(output_image_path, saliency_map):
        raise OSError(f"Failed to write saliency map to: {output_image_path}")
    print(f"显著图已保存至: {output_image_path}")
|
||
|
||
|
||
if __name__ == '__main__':
    # './thresh' is created up front; it is not otherwise used by the code
    # visible in this script.
    thresh_dir = './thresh'
    os.makedirs(thresh_dir, exist_ok=True)

    # Each entry: (tag, source image path, saliency / mask output path)
    triplets = [
        ('front', './image/front.jpg', './saliency/front.png'),  # front view
        ('rear', './image/rear.jpg', './saliency/rear.png'),  # rear view
        ('side', './image/side.jpg', './saliency/side.png'),  # side view (used for circle detection)
    ]

    # ===== Generate saliency maps (can be commented out when they are
    # produced by u2net_saliency instead) =====
    for entry in triplets:
        tag, src_path, dst_path = entry
        print(f"处理 {tag} 图像中...")
        generate_saliency_map(src_path, dst_path)