# image_to_pixle_params_yoloSAM/segment-anything-main/test_box.py
#
# 57 lines, 2.3 KiB, Python
# Timestamp shown with the file: 2025-07-14 17:36:53 +08:00
import argparse
import os
import cv2
import numpy as np
import torch
from segment_anything import sam_model_registry, SamPredictor
def _parse_box_line(line):
    """Parse one non-empty box-file line into ``(image_name, box)``.

    The expected format is five whitespace-separated fields:
    ``<image name> <x1> <y1> <x2> <y2>``.

    Args:
        line: A stripped, non-empty line from the box file.

    Returns:
        A tuple ``(image_name, box)`` where ``box`` is a float ``np.ndarray``
        holding ``[x1, y1, x2, y2]``, or ``None`` when the line does not have
        exactly five fields or a coordinate is not a finite number.
    """
    parts = line.split()
    if len(parts) != 5:
        return None
    img_name, *coords = parts
    try:
        # float() accepts both integer and fractional coordinates.
        box = np.array([float(c) for c in coords], dtype=float)
    except ValueError:
        return None
    # float() happily parses "nan" / "inf"; those are not usable box corners.
    if not np.all(np.isfinite(box)):
        return None
    return img_name, box


def main():
    """Run SAM box-prompt segmentation for every entry of a box file.

    Command-line arguments:
        --input-dir:  directory containing the input images.
        --box-file:   text file, one entry per line: ``name x1 y1 x2 y2``.
        --checkpoint: path to the SAM model weights.
        --model-type: key into ``sam_model_registry`` (default ``vit_h``).
        --output-dir: directory that receives ``<image stem>_mask.png`` files.

    Malformed lines, unreadable images and failed writes are reported on
    stdout and skipped, so one bad entry does not abort the batch.
    """
    parser = argparse.ArgumentParser(description="SAM box prompt segmentation 批量处理")
    parser.add_argument("--input-dir", type=str, required=True, help="输入图片文件夹路径")
    parser.add_argument("--box-file", type=str, required=True, help="box信息文件路径每行: 图片名 x1 y1 x2 y2")
    parser.add_argument("--checkpoint", type=str, required=True, help="SAM模型权重路径")
    parser.add_argument("--model-type", type=str, default="vit_h", help="模型类型默认vit_h")
    parser.add_argument("--output-dir", type=str, required=True, help="输出掩码图片文件夹路径")
    args = parser.parse_args()

    # Fail with a usage error instead of a bare KeyError traceback.
    if args.model_type not in sam_model_registry:
        parser.error(f"未知的模型类型: {args.model_type}")

    # Pick the device; fall back to CPU when CUDA is unavailable.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cpu":
        print("警告未检测到CUDA正在使用CPU速度会较慢。")

    # Load the model and wrap it in a predictor.
    sam = sam_model_registry[args.model_type](checkpoint=args.checkpoint)
    sam.to(device)
    predictor = SamPredictor(sam)

    # Create the output directory.
    os.makedirs(args.output_dir, exist_ok=True)

    # Read the box file and process it entry by entry.
    with open(args.box_file, 'r', encoding='utf-8') as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue

            parsed = _parse_box_line(line)
            if parsed is None:
                print(f"格式错误: {line}")
                continue
            img_name, box = parsed

            img_path = os.path.join(args.input_dir, img_name)
            image = cv2.imread(img_path)
            if image is None:
                print(f"无法读取图片: {img_path}")
                continue

            # cv2.imread yields BGR; the image is converted to RGB before
            # being handed to the predictor.
            image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            predictor.set_image(image_rgb)
            masks, _scores, _logits = predictor.predict(box=box[None, :], multimask_output=False)
            mask = masks[0].astype(np.uint8) * 255

            # NOTE: the output name depends only on the image name, so several
            # boxes for the same image overwrite the same mask file.
            out_path = os.path.join(args.output_dir, f"{os.path.splitext(img_name)[0]}_mask.png")
            out_parent = os.path.dirname(out_path)
            if out_parent:
                # The image name may contain a sub-directory component.
                os.makedirs(out_parent, exist_ok=True)

            # cv2.imwrite signals failure through its return value.
            if not cv2.imwrite(out_path, mask):
                print(f"掩码保存失败: {out_path}")
                continue
            print(f"掩码已保存到: {out_path}")
# Script entry point: run the batch job only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()