image_to_pixle_params_yoloSAM/main/u2net_saliency.py

68 lines
2.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# u2net_saliency_only.py
import torch
import cv2
import torchvision.transforms as transforms
import numpy as np
from PIL import Image
from model.u2net import U2NETP
import os
# Cache of the loaded network (and the weights path it was loaded from) so
# repeated calls do not reload the weights from disk.
_model = None
_model_path = None


def load_u2net_model(model_path="./saved_models/u2netp.pth"):
    """Return a U2NETP network in eval mode, loading it on first use.

    The network is cached at module level and reused while ``model_path``
    stays the same; asking for a different path triggers a fresh load.

    Args:
        model_path: Path to a state-dict file for ``U2NETP(3, 1)``.

    Returns:
        The cached ``U2NETP`` instance, placed on CUDA when available,
        otherwise on the CPU.
    """
    global _model, _model_path
    if _model is not None and _model_path == model_path:
        return _model

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Build into a local and publish to the cache only after the weights have
    # loaded, so a failed load (missing/corrupt file) cannot leave an
    # untrained network cached for later callers.
    net = U2NETP(3, 1)
    # map_location lets weights saved from a GPU load on a CPU-only machine.
    net.load_state_dict(torch.load(model_path, map_location=device))
    net.to(device)
    net.eval()

    _model, _model_path = net, model_path
    return _model
def preprocess(image):
    """Convert a PIL image into a normalised batch of one for U2NETP.

    The image is resized to 320x320 and converted to a float tensor. It is
    then normalised channel-wise with the standard ImageNet mean/std values
    (three channels, so an RGB input is expected). A leading batch dimension
    is added.

    Args:
        image: Input image accepted by ``torchvision.transforms`` (e.g. a
            ``PIL.Image``).

    Returns:
        Tensor of shape (1, C, 320, 320).
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    steps = [
        transforms.Resize((320, 320)),
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ]
    tensor = transforms.Compose(steps)(image)
    return tensor.unsqueeze(0)
def postprocess(output, original_size):
    """Convert a network output tensor into a uint8 saliency image.

    Args:
        output: Tensor produced by the network for a single image. All
            singleton dimensions are squeezed away, so it is expected to
            reduce to a 2-D (H, W) map. Values are presumably in [0, 1];
            TODO confirm against U2NETP's final activation.
        original_size: ``(width, height)`` to resize the map back to. This is
            the order returned by ``PIL.Image.size`` and the order expected
            by ``cv2.resize``'s ``dsize`` argument.

    Returns:
        ``numpy.ndarray`` of dtype ``uint8`` with values in [0, 255], resized
        to ``original_size``.
    """
    saliency = output.detach().squeeze().cpu().numpy()
    # Clip before scaling: casting out-of-range floats to uint8 is undefined
    # and typically wraps around. Slight under/overshoot would otherwise turn
    # into bright/dark speckles.
    saliency = np.clip(saliency, 0.0, 1.0)
    saliency = (saliency * 255).astype(np.uint8)
    return cv2.resize(saliency, original_size)
def generate_saliency_map(input_image_path, output_image_path, model_path="./saved_models/u2netp.pth"):
    """Run U2NETP on one image and write its saliency map to disk.

    Args:
        input_image_path: Path of the image to analyse; it is converted to RGB.
        output_image_path: Destination file for the saliency map. Missing
            parent directories are created.
        model_path: Weights file passed to ``load_u2net_model``.

    Raises:
        OSError: If OpenCV fails to write ``output_image_path``.
    """
    net = load_u2net_model(model_path)
    with Image.open(input_image_path) as img:
        original_image = img.convert("RGB")
    original_size = original_image.size  # (width, height)

    # Put the input on whatever device the model's weights live on instead of
    # hard-coding CUDA.
    device = next(net.parameters()).device
    input_tensor = preprocess(original_image).to(device)

    # Inference only: no_grad avoids building (and holding memory for) the
    # autograd graph. [0] selects the first network output; presumably the
    # fused saliency map (TODO confirm against U2NETP.forward).
    with torch.no_grad():
        output = net(input_tensor)[0]
    saliency_map = postprocess(output, original_size)

    # cv2.imwrite signals failure (e.g. missing directory) by returning False
    # rather than raising. Create the directory and check the result so we
    # never report a save that did not happen.
    out_dir = os.path.dirname(output_image_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    if not cv2.imwrite(output_image_path, saliency_map):
        raise OSError(f"Failed to write saliency map to {output_image_path}")
    print(f"显著图已保存至: {output_image_path}")
if __name__ == '__main__':
    triplets = [
        # (tag, source image path, saliency / mask output path)
        ('front', './image/front.jpg', './saliency/front.png'),  # front view
        ('rear', './image/rear.jpg', './saliency/rear.png'),  # rear view
        ('side', './image/side.jpg', './saliency/side.png'),  # side view (used for circle detection)
    ]
    # NOTE(review): thresh_dir is not used by the code visible here; confirm
    # whether another step relies on it existing.
    thresh_dir = './thresh'
    os.makedirs(thresh_dir, exist_ok=True)

    # The saliency output directories must exist: cv2.imwrite does not create
    # them and fails without raising.
    for _, _, saliency_path in triplets:
        os.makedirs(os.path.dirname(saliency_path) or '.', exist_ok=True)

    # ===== Generate saliency maps (can be commented out if they were already
    # produced by a separate u2net_saliency run) =====
    for tag, image_path, saliency_path in triplets:
        print(f"处理 {tag} 图像中...")
        generate_saliency_map(image_path, saliency_path)