68 lines
2.1 KiB
Python
68 lines
2.1 KiB
Python
|
# u2net_saliency_only.py
|
|||
|
import torch
|
|||
|
import cv2
|
|||
|
import torchvision.transforms as transforms
|
|||
|
import numpy as np
|
|||
|
from PIL import Image
|
|||
|
from model.u2net import U2NETP
|
|||
|
import os
|
|||
|
# Cache the loaded network so repeated calls do not reload weights from disk.
_model = None
# Checkpoint path the cached network was loaded from; a different path forces a reload.
_model_path = None


def load_u2net_model(model_path="./saved_models/u2netp.pth"):
    """Load a ``U2NETP(3, 1)`` network from a state-dict checkpoint and cache it.

    Args:
        model_path: Path to the state-dict checkpoint for ``U2NETP(3, 1)``.

    Returns:
        The network in eval mode, placed on CUDA when available and on the CPU
        otherwise. Later calls with the same ``model_path`` return the same
        cached object without touching the disk.

    Raises:
        Whatever ``torch.load`` / ``load_state_dict`` raise for a missing or
        incompatible checkpoint. The cache is left unchanged in that case, so an
        untrained network is never handed out.
    """
    global _model, _model_path
    if _model is None or _model_path != model_path:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        net = U2NETP(3, 1)
        # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
        net.load_state_dict(torch.load(model_path, map_location=device))
        net.to(device)
        net.eval()
        # Publish to the cache only after the weights loaded successfully.
        _model, _model_path = net, model_path
    return _model
|
|||
|
|
|||
|
|
|||
|
def preprocess(image):
    """Build a normalized, batched model input from an image.

    The image is resized to 320x320, converted to a tensor with
    ``ToTensor``, normalized per channel with the mean/std below, and given a
    leading batch dimension of size 1. Assumes an RGB PIL image (the caller in
    this file converts to RGB first).
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    pipeline = transforms.Compose(
        [
            transforms.Resize((320, 320)),
            transforms.ToTensor(),
            transforms.Normalize(mean=channel_mean, std=channel_std),
        ]
    )
    single = pipeline(image)
    # Indexing with None prepends the batch axis (same as unsqueeze(0)).
    return single[None]
|
|||
|
|
|||
|
|
|||
|
def postprocess(output, original_size):
    """Convert a single-channel prediction tensor into a uint8 saliency image.

    Args:
        output: Prediction tensor that squeezes down to a 2-D map (presumably
            1x1x320x320 for one image — TODO confirm for other models). Values
            are expected in [0, 1]; anything outside that range is clipped.
        original_size: ``(width, height)`` to resize back to, i.e. the
            ``PIL.Image.size`` tuple, which matches ``cv2.resize``'s ``dsize``
            order.

    Returns:
        A uint8 ``numpy.ndarray`` resized to ``original_size``.
    """
    # Detach first so the device-to-host copy is not tracked by autograd.
    saliency = output.detach().squeeze().cpu().numpy()
    # Clip before the uint8 cast: out-of-range floats would otherwise wrap.
    saliency = (np.clip(saliency, 0.0, 1.0) * 255).astype(np.uint8)
    return cv2.resize(saliency, original_size)
|
|||
|
|
|||
|
|
|||
|
def generate_saliency_map(input_image_path, output_image_path, model_path="./saved_models/u2netp.pth"):
    """Run U2NETP on one image and write the resulting saliency map to disk.

    Args:
        input_image_path: Path of the image to process (converted to RGB).
        output_image_path: Destination path for the grayscale saliency map.
            Missing parent directories are created.
        model_path: Checkpoint path forwarded to ``load_u2net_model``.

    Raises:
        OSError: If ``cv2.imwrite`` reports that the file could not be written.
    """
    net = load_u2net_model(model_path)

    # The context manager closes the source file once the RGB copy is made.
    with Image.open(input_image_path) as img:
        original_image = img.convert("RGB")
    original_size = original_image.size

    # Send the input to whichever device the model's weights live on.
    device = next(net.parameters()).device
    input_tensor = preprocess(original_image).to(device)

    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        output = net(input_tensor)[0]

    saliency_map = postprocess(output, original_size)

    out_dir = os.path.dirname(output_image_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    # cv2.imwrite reports failure via its return value rather than raising.
    if not cv2.imwrite(output_image_path, saliency_map):
        raise OSError(f"Failed to write saliency map to: {output_image_path}")

    print(f"显著图已保存至: {output_image_path}")
|
|||
|
|
|||
|
|
|||
|
if __name__ == '__main__':
    triplets = [
        # (label, source image path, saliency / mask output path)
        ('front', './image/front.jpg', './saliency/front.png'),  # front view
        ('rear', './image/rear.jpg', './saliency/rear.png'),  # rear view
        ('side', './image/side.jpg', './saliency/side.png'),  # side view (used for circle detection)
    ]

    # NOTE(review): thresh_dir is created but not used in the visible code —
    # presumably consumed by a later processing stage.
    thresh_dir = './thresh'
    os.makedirs(thresh_dir, exist_ok=True)

    # cv2.imwrite does not create directories, so make sure every output folder exists.
    for _, _, saliency_path in triplets:
        os.makedirs(os.path.dirname(saliency_path) or '.', exist_ok=True)

    # ===== Generate saliency maps (can be commented out if they are produced by u2net_saliency) =====
    for tag, image_path, saliency_path in triplets:
        print(f"处理 {tag} 图像中...")
        generate_saliency_map(image_path, saliency_path)
|