Last active
January 21, 2026 02:10
-
-
Save nagadomi/c9fdfd6febdb852cba873656c4321d82 to your computer and use it in GitHub Desktop.
generates human masks with SAM3
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # python generate_person_mask_sam3.py -i ./test_images -o output | |
| # | |
| # This script generates human masks from images. | |
| # It works not only with photos but also with anime-style images. | |
| # The latest version of transformers is required. | |
| # `pip install git+https://github.com/huggingface/transformers.git` | |
| # Using SAM3 requires proper license authorization. | |
| import os | |
| from os import path | |
| from tqdm import tqdm | |
| import argparse | |
| from PIL import Image | |
| import torch | |
| import torchvision.transforms.functional as TF | |
| import torch.nn.functional as F | |
| from transformers import Sam3Processor, Sam3Model | |
def dilate(mask, kernel_size=3):
    """Morphological dilation of a (soft) binary mask.

    Implemented as a stride-1 max filter with "same"-style padding, so odd
    kernel sizes preserve the spatial dimensions.

    Args:
        mask: float tensor of shape (N, C, H, W) or (C, H, W).
        kernel_size: int, or a (kh, kw) pair for a rectangular window.
    """
    if isinstance(kernel_size, (list, tuple)):
        kh, kw = kernel_size
        padding = (kh // 2, kw // 2)
    else:
        padding = kernel_size // 2
    return F.max_pool2d(mask, kernel_size=kernel_size, stride=1, padding=padding)
def erode(mask, kernel_size=3):
    """Morphological erosion of a (soft) binary mask.

    Uses the identity min(x) == -max(-x): a max filter over the negated
    mask acts as a min filter. max_pool2d's implicit padding behaves like
    +inf after the negation, so border pixels are not eroded by padding.

    Args:
        mask: float tensor of shape (N, C, H, W) or (C, H, W).
        kernel_size: int, or a (kh, kw) pair for a rectangular window.
    """
    if isinstance(kernel_size, (list, tuple)):
        kh, kw = kernel_size
        padding = (kh // 2, kw // 2)
    else:
        padding = kernel_size // 2
    return -F.max_pool2d(mask.neg(), kernel_size=kernel_size, stride=1, padding=padding)
def closing(mask, kernel_size=3, n_iter=3):
    """Morphological closing: dilate n_iter times, then erode n_iter times.

    Fills holes and gaps in the foreground up to roughly n_iter kernel
    radii wide. The input is cast to float before processing.
    """
    out = mask.float()
    for op in [dilate] * n_iter + [erode] * n_iter:
        out = op(out, kernel_size=kernel_size)
    return out
def opening(mask, kernel_size=3, n_iter=3):
    """Morphological opening: erode n_iter times, then dilate n_iter times.

    Removes small isolated foreground specks up to roughly n_iter kernel
    radii wide. The input is cast to float before processing.
    """
    out = mask.float()
    for op in [erode] * n_iter + [dilate] * n_iter:
        out = op(out, kernel_size=kernel_size)
    return out
def main():
    """Generate a person mask for every image in the input directory.

    For each supported image, SAM3 is prompted with a text query
    (default "person"); every detected instance mask is optionally
    cleaned with morphological opening/closing, the instances are merged
    into a single binary mask, and the result is written as a grayscale
    PNG or as the alpha channel of an RGBA PNG.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", "-i", type=str, required=True, help="input dir")
    parser.add_argument("--output", "-o", type=str, required=True, help="output dir")
    parser.add_argument("--prompt", "-p", type=str, default="person", help="text prompt")
    parser.add_argument("--threshold", type=float, default=0.3, help="threshold")
    parser.add_argument("--device", type=str, default="cuda", choices=["cpu", "cuda", "mps", "xpu"], help="device")
    parser.add_argument("--rgba", action="store_true", help="output RGBA image")
    parser.add_argument("--skip-clean", action="store_true", help="skip mask cleaning")
    args = parser.parse_args()

    os.makedirs(args.output, exist_ok=True)
    supported_exts = {".png", ".jpg", ".jpeg", ".webp"}
    device = torch.device(args.device)
    model = Sam3Model.from_pretrained("facebook/sam3").to(device)
    processor = Sam3Processor.from_pretrained("facebook/sam3")

    for name in tqdm(os.listdir(args.input)):
        stem, ext = path.splitext(name)
        if ext.lower() not in supported_exts:
            continue
        with Image.open(path.join(args.input, name)) as img:
            img.load()
            img = img.convert("RGB")

        inputs = processor(images=img, text=args.prompt, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        # Instance segmentation results at the image's original resolution.
        results = processor.post_process_instance_segmentation(
            outputs,
            threshold=args.threshold,
            mask_threshold=args.threshold,
            target_sizes=inputs.get("original_sizes").tolist()
        )[0]

        # Union of all detected instance masks, clamped to [0, 1].
        merged = torch.zeros((1, img.height, img.width), dtype=torch.float32).to(device)
        for instance in results["masks"]:
            instance = instance.unsqueeze(0).float()
            if not args.skip_clean:
                # Opening drops small specks; closing fills small holes.
                instance = closing(opening(instance))
            merged += instance
        merged = merged.clamp(0, 1)

        out_path = path.join(args.output, stem + ".png")
        if args.rgba:
            # Attach the mask as the alpha channel of the source image.
            img.putalpha(TF.to_pil_image(merged))
            img.save(out_path)
        else:
            # Save the mask itself as a grayscale image.
            TF.to_pil_image(merged).save(out_path)
| if __name__ == "__main__": | |
| main() |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment