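"""Batch SAM (Segment Anything) box-prompt segmentation.

Reads ``<image_name> x1 y1 x2 y2`` lines from a box file, runs SAM on each
referenced image with that box as the prompt, and writes the predicted mask
as a PNG into the output directory.
"""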
import argparse
|
||
import os
|
||
import cv2
|
||
import numpy as np
|
||
import torch
|
||
from segment_anything import sam_model_registry, SamPredictor
|
||
|
||
|
||
def _parse_box_line(line):
    """Parse one box-file line of the form ``<image_name> x1 y1 x2 y2``.

    Coordinates may be integers or decimals (e.g. exported by a labeling
    tool). The corners are normalised so that ``x1 <= x2`` and ``y1 <= y2``,
    i.e. top-left / bottom-right (XYXY) order as used by SAM's box prompt.

    Args:
        line: A stripped, non-empty line from the box file.

    Returns:
        ``(image_name, box)`` where ``box`` is a float array of shape (4,).

    Raises:
        ValueError: if the line does not have exactly five whitespace-separated
            fields, or a coordinate is not a finite number.
    """
    parts = line.split()
    if len(parts) != 5:
        raise ValueError(f"expected 5 fields, got {len(parts)}")
    img_name = parts[0]
    # float() raises ValueError on non-numeric text, which the caller reports.
    x1, y1, x2, y2 = (float(v) for v in parts[1:])
    box = np.array([min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2)])
    if not np.isfinite(box).all():
        raise ValueError("box coordinates must be finite numbers")
    return img_name, box


def main():
    """Run SAM box-prompt segmentation for every line of a box file.

    Command-line arguments:
        --input-dir   folder containing the images named in the box file
        --box-file    text file, one ``<image_name> x1 y1 x2 y2`` per line
        --checkpoint  SAM weights file passed to ``sam_model_registry``
        --model-type  registry key of the model architecture (default vit_h)
        --output-dir  folder where ``<stem>_mask.png`` files are written

    Malformed lines, unreadable images and failed writes are reported and
    skipped, so one bad entry does not abort the whole batch. When the same
    image stem appears on several lines, each additional mask gets a numeric
    suffix (``<stem>_mask_1.png``, ``<stem>_mask_2.png``, ...) instead of
    overwriting the previous one.
    """
    parser = argparse.ArgumentParser(description="SAM box prompt segmentation 批量处理")
    parser.add_argument("--input-dir", type=str, required=True, help="输入图片文件夹路径")
    parser.add_argument("--box-file", type=str, required=True, help="box信息文件路径,每行: 图片名 x1 y1 x2 y2")
    parser.add_argument("--checkpoint", type=str, required=True, help="SAM模型权重路径")
    # Restrict to the registry's keys so a typo gives a clear argparse error
    # instead of a KeyError when the model is built.
    parser.add_argument(
        "--model-type",
        type=str,
        default="vit_h",
        choices=sorted(sam_model_registry),
        help="模型类型,默认vit_h",
    )
    parser.add_argument("--output-dir", type=str, required=True, help="输出掩码图片文件夹路径")
    args = parser.parse_args()

    # Use CUDA when available; otherwise warn and fall back to CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cpu":
        print("警告:未检测到CUDA,正在使用CPU,速度会较慢。")

    # Build the model and wrap it in a predictor.
    sam = sam_model_registry[args.model_type](checkpoint=args.checkpoint)
    sam.to(device)
    predictor = SamPredictor(sam)

    os.makedirs(args.output_dir, exist_ok=True)

    loaded_name = None  # image whose embedding the predictor currently holds
    written = {}        # mask-name stem -> number of masks already written

    # "utf-8-sig" reads plain UTF-8 unchanged but also strips a leading BOM,
    # which would otherwise end up glued to the first image name.
    with open(args.box_file, "r", encoding="utf-8-sig") as f:
        for raw in f:
            line = raw.strip()
            if not line:
                continue
            try:
                img_name, box = _parse_box_line(line)
            except ValueError:
                print(f"格式错误: {line}")
                continue

            # set_image computes the (expensive) image embedding; reuse it
            # when consecutive lines refer to the same image.
            if img_name != loaded_name:
                img_path = os.path.join(args.input_dir, img_name)
                image = cv2.imread(img_path)
                if image is None:
                    print(f"无法读取图片: {img_path}")
                    continue
                # cv2 loads BGR; convert to RGB before handing it to SAM.
                predictor.set_image(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
                loaded_name = img_name

            masks, _, _ = predictor.predict(box=box[None, :], multimask_output=False)
            mask = masks[0].astype(np.uint8) * 255

            stem = os.path.splitext(img_name)[0]
            index = written.get(stem, 0)
            suffix = f"_{index}" if index else ""
            out_path = os.path.join(args.output_dir, f"{stem}_mask{suffix}.png")
            # img_name may include sub-folders; make sure they exist.
            os.makedirs(os.path.dirname(out_path), exist_ok=True)
            if not cv2.imwrite(out_path, mask):
                print(f"保存掩码失败: {out_path}")
                continue
            written[stem] = index + 1
            print(f"掩码已保存到: {out_path}")
|
||
|
||
# Run the batch job only when executed as a script, not when imported.
if __name__ == "__main__":
    main()