"""Batch SAM segmentation driven by box prompts.

Reads a text file in which every non-empty line is ``<image name> x1 y1 x2 y2``
and, for each line, runs a Segment Anything predictor on the named image with
that box as the prompt, writing the predicted binary mask (0/255 PNG) to the
output directory.
"""
import argparse
import os

import numpy as np


def parse_box_line(line):
    """Parse one box-file line into ``(image_name, box)``.

    The last four whitespace-separated fields are the coordinates; everything
    before them is the image name, so image names containing spaces work.

    Args:
        line: A stripped, non-empty line from the box file.

    Returns:
        ``(image_name, box)`` where ``box`` is a float ``np.ndarray`` of shape
        ``(4,)`` holding ``[x1, y1, x2, y2]``.

    Raises:
        ValueError: if the line lacks an image name plus four coordinates, a
            coordinate is not a finite number, or the box is empty
            (``x1 >= x2`` or ``y1 >= y2``).
    """
    parts = line.rsplit(None, 4)
    if len(parts) != 5:
        raise ValueError("需要5个字段: 图片名 x1 y1 x2 y2")
    img_name, *coords = parts
    # float() accepts both "10" and "10.5"; int() rejected the latter and
    # crashed the whole batch.
    try:
        box = np.array([float(c) for c in coords])
    except ValueError:
        raise ValueError("坐标不是有效数字") from None
    if not np.all(np.isfinite(box)):
        raise ValueError("坐标不是有效数字")
    x1, y1, x2, y2 = box
    if x1 >= x2 or y1 >= y2:
        raise ValueError("坐标必须满足 x1 < x2 且 y1 < y2")
    return img_name, box


def mask_output_path(output_dir, img_name, index=0):
    """Return the path for the mask of the *index*-th box of *img_name*.

    The first box keeps the plain ``<stem>_mask.png`` name; later boxes for the
    same output stem get ``<stem>_mask_<index>.png`` so they don't overwrite it.
    """
    stem = os.path.splitext(img_name)[0]
    suffix = "_mask.png" if index == 0 else f"_mask_{index}.png"
    return os.path.join(output_dir, stem + suffix)


def main():
    """Parse CLI arguments and write one mask per valid line of the box file."""
    parser = argparse.ArgumentParser(description="SAM box prompt segmentation 批量处理")
    parser.add_argument("--input-dir", type=str, required=True, help="输入图片文件夹路径")
    parser.add_argument("--box-file", type=str, required=True, help="box信息文件路径,每行: 图片名 x1 y1 x2 y2")
    parser.add_argument("--checkpoint", type=str, required=True, help="SAM模型权重路径")
    parser.add_argument("--model-type", type=str, default="vit_h", help="模型类型,默认vit_h")
    parser.add_argument("--output-dir", type=str, required=True, help="输出掩码图片文件夹路径")
    args = parser.parse_args()

    # Heavy third-party imports are deferred until after argument parsing so
    # that --help / argument errors respond without loading torch, OpenCV or
    # SAM, and so the pure helpers above are importable without them.
    import cv2
    import torch
    from segment_anything import sam_model_registry, SamPredictor

    # Prefer the GPU; warn when falling back to CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cpu":
        print("警告:未检测到CUDA,正在使用CPU,速度会较慢。")

    # Load the model.
    sam = sam_model_registry[args.model_type](checkpoint=args.checkpoint)
    sam.to(device)
    predictor = SamPredictor(sam)

    os.makedirs(args.output_dir, exist_ok=True)

    # Masks already written per normalized output base path; used to give
    # every additional box of the same image (or stem) its own file.
    masks_written = {}
    # Path of the image whose embedding is currently loaded in `predictor`.
    embedded_path = None

    # utf-8-sig strips a leading BOM (common in files saved by Windows
    # editors); with plain utf-8 the first image name would start with U+FEFF.
    with open(args.box_file, "r", encoding="utf-8-sig") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue
            try:
                img_name, box = parse_box_line(line)
            except ValueError as err:
                print(f"格式错误: {line} ({err})")
                continue

            img_path = os.path.join(args.input_dir, img_name)
            # set_image computes the expensive image embedding; reuse it when
            # consecutive lines prompt the same image.
            if img_path != embedded_path:
                image = cv2.imread(img_path)
                if image is None:
                    print(f"无法读取图片: {img_path}")
                    continue
                predictor.set_image(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
                embedded_path = img_path

            masks, _scores, _logits = predictor.predict(box=box[None, :], multimask_output=False)
            mask = masks[0].astype(np.uint8) * 255

            key = os.path.normpath(mask_output_path(args.output_dir, img_name))
            index = masks_written.get(key, 0)
            out_path = mask_output_path(args.output_dir, img_name, index)
            # Image names may contain subdirectories; imwrite does not create them.
            parent = os.path.dirname(out_path)
            if parent:
                os.makedirs(parent, exist_ok=True)
            # imwrite reports failure by returning False rather than raising.
            if not cv2.imwrite(out_path, mask):
                print(f"无法保存掩码: {out_path}")
                continue
            masks_written[key] = index + 1
            print(f"掩码已保存到: {out_path}")


if __name__ == "__main__":
    main()