image_to_pixle_params_yoloSAM/segment-anything-main/test_box.py

57 lines
2.3 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import argparse
import os
import cv2
import numpy as np
import torch
from segment_anything import sam_model_registry, SamPredictor
def _parse_box_line(line):
    """Parse one line of the box file into ``(img_name, box)``.

    Expected format: ``<image name> x1 y1 x2 y2`` separated by whitespace.
    Coordinates may be integers or floats (detectors commonly emit
    fractional pixel values).

    Args:
        line: A raw line from the box file; surrounding whitespace is ignored.

    Returns:
        ``None`` for a blank line, otherwise ``(img_name, box)`` where ``box``
        is a float ``np.ndarray`` ``[x_min, y_min, x_max, y_max]``.

    Raises:
        ValueError: if the line does not have exactly five fields, or a
            coordinate is not a finite number.
    """
    parts = line.split()
    if not parts:
        return None
    if len(parts) != 5:
        raise ValueError(f"expected 5 fields, got {len(parts)}")
    img_name, *coords = parts
    x1, y1, x2, y2 = (float(c) for c in coords)  # raises ValueError on junk
    # Normalize corner order so a box written as (x2, y2, x1, y1) still gives
    # SAM a proper top-left / bottom-right pair.
    box = np.array([min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2)])
    if not np.isfinite(box).all():
        raise ValueError("non-finite coordinate")
    return img_name, box


def main():
    """Batch-segment images with SAM, one box prompt per box-file line.

    Reads ``--box-file`` line by line (``<image name> x1 y1 x2 y2``), loads
    the image from ``--input-dir``, runs ``SamPredictor.predict`` with the box
    as the prompt (``multimask_output=False``) and writes the resulting mask
    scaled to 0/255 as a PNG into ``--output-dir``.

    The first box for an image stem is saved as ``<stem>_mask.png``; further
    boxes for the same stem are saved as ``<stem>_mask_1.png``,
    ``<stem>_mask_2.png``, ... instead of overwriting earlier masks.
    Malformed lines, unreadable images and failed writes are reported and
    skipped so a single bad entry does not abort the batch.
    """
    parser = argparse.ArgumentParser(description="SAM box prompt segmentation 批量处理")
    parser.add_argument("--input-dir", type=str, required=True, help="输入图片文件夹路径")
    parser.add_argument("--box-file", type=str, required=True, help="box信息文件路径每行: 图片名 x1 y1 x2 y2")
    parser.add_argument("--checkpoint", type=str, required=True, help="SAM模型权重路径")
    parser.add_argument(
        "--model-type",
        type=str,
        default="vit_h",
        # Validate against the registry up front instead of failing later
        # with a bare KeyError after argument parsing.
        choices=sorted(sam_model_registry),
        help="模型类型默认vit_h",
    )
    parser.add_argument("--output-dir", type=str, required=True, help="输出掩码图片文件夹路径")
    args = parser.parse_args()

    # Prefer CUDA; warn when falling back to CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cpu":
        print("警告未检测到CUDA正在使用CPU速度会较慢。")

    # Load the model and wrap it in a predictor.
    sam = sam_model_registry[args.model_type](checkpoint=args.checkpoint)
    sam.to(device)
    predictor = SamPredictor(sam)

    os.makedirs(args.output_dir, exist_ok=True)

    current_path = None  # image whose embedding is currently set on predictor
    saved_per_stem = {}  # output stem -> number of masks written so far

    # "utf-8-sig" also strips a leading BOM (e.g. from Windows Notepad), which
    # would otherwise be glued onto the first image name.
    with open(args.box_file, 'r', encoding='utf-8-sig') as f:
        for line in f:
            line = line.strip()
            try:
                parsed = _parse_box_line(line)
            except ValueError:
                print(f"格式错误: {line}")
                continue
            if parsed is None:  # blank line
                continue
            img_name, box = parsed
            img_path = os.path.join(args.input_dir, img_name)

            # set_image() runs the image encoder; reuse its result when
            # consecutive lines refer to the same image (several boxes).
            if img_path != current_path:
                image = cv2.imread(img_path)
                if image is None:
                    print(f"无法读取图片: {img_path}")
                    continue
                image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                predictor.set_image(image_rgb)
                current_path = img_path

            masks, _, _ = predictor.predict(box=box[None, :], multimask_output=False)
            mask = masks[0].astype(np.uint8) * 255

            # Give every additional box for the same stem its own file
            # instead of silently overwriting the previous mask.
            stem = os.path.splitext(img_name)[0]
            index = saved_per_stem.get(stem, 0)
            suffix = "_mask" if index == 0 else f"_mask_{index}"
            out_path = os.path.join(args.output_dir, f"{stem}{suffix}.png")
            # cv2.imwrite signals failure through its return value, not an
            # exception, so it must be checked explicitly.
            if not cv2.imwrite(out_path, mask):
                print(f"保存掩码失败: {out_path}")
                continue
            saved_per_stem[stem] = index + 1
            print(f"掩码已保存到: {out_path}")
# Script entry point: run the batch job only when executed directly, not on import.
if __name__ == "__main__":
    raise SystemExit(main())