# Source file: image_to_pixle_params_yoloSAM/main/u2net_saliency.py
# Repository viewer metadata: 68 lines, 2.1 KiB, Python
# History timestamp: 2025-07-14 17:36:53 +08:00
# u2net_saliency_only.py
import torch
import cv2
import torchvision.transforms as transforms
import numpy as np
from PIL import Image
from model.u2net import U2NETP
import os
# Cache the loaded model so repeated calls do not reload the weights.
_model = None
# Checkpoint path the cached model was loaded from (None until first load).
_model_path = None


def load_u2net_model(model_path="./saved_models/u2netp.pth"):
    """Return a cached U2NETP network loaded from ``model_path`` in eval mode.

    The network is built as ``U2NETP(3, 1)`` (presumably 3 input channels and
    1 output channel -- confirm against ``model.u2net``), its weights are
    loaded from ``model_path``, and it is moved to CUDA when available,
    otherwise it stays on the CPU.

    The result is cached at module level. A later call with the same
    ``model_path`` returns the cached instance; a call with a different
    ``model_path`` reloads and replaces the cache.

    Args:
        model_path: Path to the state-dict checkpoint file.

    Returns:
        The loaded network, in evaluation mode.

    Raises:
        Whatever ``torch.load`` / ``load_state_dict`` raise for a missing or
        incompatible checkpoint. In that case the cache is left untouched.
    """
    global _model, _model_path
    if _model is None or _model_path != model_path:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        net = U2NETP(3, 1)
        # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
        state_dict = torch.load(model_path, map_location=device)
        net.load_state_dict(state_dict)
        net.to(device)
        net.eval()
        # Publish to the cache only after a fully successful load, so a failed
        # load can never leave an untrained network behind for later callers.
        _model = net
        _model_path = model_path
    return _model
def preprocess(image):
    """Turn an image into a normalized, batched tensor for the network.

    The image is resized to 320x320, converted with ``ToTensor`` and
    normalized per channel with fixed mean/std constants. A leading batch
    dimension is then added with ``unsqueeze(0)``.

    Args:
        image: Image accepted by the torchvision transforms (the caller in
            this file passes an RGB ``PIL.Image``).

    Returns:
        The transformed tensor with an extra dimension at index 0.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    steps = [
        transforms.Resize((320, 320)),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    pipeline = transforms.Compose(steps)
    tensor = pipeline(image)
    return tensor.unsqueeze(0)
def postprocess(output, original_size):
    """Convert a network output tensor into an 8-bit image array.

    Singleton dimensions are dropped, the tensor is moved to the CPU and
    detached, values are scaled by 255 and cast to ``uint8``, and the result
    is resized with ``cv2.resize``.

    Args:
        output: Tensor produced by the network (assumed to hold values in
            [0, 1] -- the cast to uint8 does not clip; confirm against the
            model's output activation).
        original_size: Target size handed to ``cv2.resize`` as ``dsize``
            (the caller in this file passes ``PIL.Image.size``).

    Returns:
        The resized ``uint8`` numpy array.
    """
    saliency = output.squeeze().cpu().detach().numpy()
    scaled = saliency * 255
    as_bytes = scaled.astype(np.uint8)
    return cv2.resize(as_bytes, original_size)
def generate_saliency_map(input_image_path, output_image_path, model_path="./saved_models/u2netp.pth"):
    """Run the saliency network on one image and write the map to disk.

    The image is opened as RGB, preprocessed, and passed through the cached
    model. The first element of the model's output is post-processed back to
    the original image size and written with ``cv2.imwrite``.

    Args:
        input_image_path: Path of the image to read.
        output_image_path: Path the saliency map is written to. Its parent
            directory is created if it does not exist.
        model_path: Checkpoint path forwarded to ``load_u2net_model``.

    Raises:
        OSError: If ``cv2.imwrite`` reports that the file was not written.
    """
    net = load_u2net_model(model_path)
    # Put the input on whichever device the model lives on instead of
    # hard-coding CUDA.
    device = next(net.parameters()).device

    # The context manager closes the underlying file; convert() returns an
    # independent in-memory copy that outlives it.
    with Image.open(input_image_path) as source_image:
        original_image = source_image.convert("RGB")
    original_size = original_image.size
    input_tensor = preprocess(original_image).to(device)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        output = net(input_tensor)[0]

    saliency_map = postprocess(output, original_size)

    # cv2.imwrite returns False (rather than raising) when the target
    # directory is missing, so create it and check the result explicitly.
    output_dir = os.path.dirname(output_image_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    if not cv2.imwrite(output_image_path, saliency_map):
        raise OSError(f"Failed to write saliency map to: {output_image_path}")
    print(f"显著图已保存至: {output_image_path}")
if __name__ == '__main__':
    # Script entry point: generate a saliency map for each configured view.
    triplets = [
        # (tag, source image path, saliency / mask output path)
        ('front', './image/front.jpg', './saliency/front.png'),  # front view
        ('rear', './image/rear.jpg', './saliency/rear.png'),  # rear view
        ('side', './image/side.jpg', './saliency/side.png'),  # side view (used for circle detection)
    ]
    thresh_dir = './thresh'
    os.makedirs(thresh_dir, exist_ok=True)
    # ======================= Generate saliency maps =======================
    # NOTE(review): the original comment suggests this step can be commented
    # out when the maps are produced separately by u2net_saliency -- confirm.
    for tag, image_path, saliency_path in triplets:
        # Make sure the output directory exists; cv2.imwrite does not create
        # it and fails without raising when it is missing.
        saliency_dir = os.path.dirname(saliency_path)
        if saliency_dir:
            os.makedirs(saliency_dir, exist_ok=True)
        print(f"处理 {tag} 图像中...")
        generate_saliency_map(image_path, saliency_path)