@nagadomi · Last active January 21, 2026
generates human masks with SAM3
# python generate_person_mask_sam3.py -i ./test_images -o output
#
# This script generates human masks from images.
# It works not only with photos but also with anime-style images.
# The latest development version of transformers is required:
#     pip install git+https://github.com/huggingface/transformers.git
# Using SAM3 also requires accepting the model license on Hugging Face.
import os
from os import path
import argparse

import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from PIL import Image
from tqdm import tqdm
from transformers import Sam3Processor, Sam3Model

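# --- Mask morphology helpers ---
# With stride=1 and "same" padding, F.max_pool2d returns each pixel's
# neighborhood maximum, which is exactly a grayscale dilation of the mask;
# the erosion/opening/closing helpers below are built from that primitive.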
def dilate(mask, kernel_size=3):
    if isinstance(kernel_size, (list, tuple)):
        pad = (kernel_size[0] // 2, kernel_size[1] // 2)
    else:
        pad = kernel_size // 2
    return F.max_pool2d(mask, kernel_size=kernel_size, stride=1, padding=pad)

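# Erosion is the dual of dilation: a min-pool, implemented by negating the
# mask, max-pooling, and negating back (min(x) == -max(-x)).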
def erode(mask, kernel_size=3):
    if isinstance(kernel_size, (list, tuple)):
        pad = (kernel_size[0] // 2, kernel_size[1] // 2)
    else:
        pad = kernel_size // 2
    return -F.max_pool2d(-mask, kernel_size=kernel_size, stride=1, padding=pad)

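# Closing = dilation followed by erosion: fills small holes and gaps in the
# mask while leaving its outer boundary mostly unchanged.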
def closing(mask, kernel_size=3, n_iter=3):
    mask = mask.float()
    for _ in range(n_iter):
        mask = dilate(mask, kernel_size=kernel_size)
    for _ in range(n_iter):
        mask = erode(mask, kernel_size=kernel_size)
    return mask

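# Opening = erosion followed by dilation: removes small isolated specks of
# foreground while keeping larger regions intact.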
def opening(mask, kernel_size=3, n_iter=3):
    mask = mask.float()
    for _ in range(n_iter):
        mask = erode(mask, kernel_size=kernel_size)
    for _ in range(n_iter):
        mask = dilate(mask, kernel_size=kernel_size)
    return mask

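# Quick sanity check for the helpers (illustrative only; assumes a float
# mask of shape (1, H, W) with values in [0, 1]):
#   >>> m = torch.zeros(1, 7, 7); m[0, 3, 3] = 1.0
#   >>> dilate(m).sum() > m.sum()   # a single pixel grows to a 3x3 block
#   tensor(True)
#   >>> opening(m).sum() == 0       # opening removes such isolated specks
#   tensor(True)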
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", "-i", type=str, required=True, help="input dir")
    parser.add_argument("--output", "-o", type=str, required=True, help="output dir")
    parser.add_argument("--prompt", "-p", type=str, default="person", help="text prompt")
    parser.add_argument("--threshold", type=float, default=0.3, help="detection and mask threshold")
    parser.add_argument("--device", type=str, default="cuda", choices=["cpu", "cuda", "mps", "xpu"], help="device")
    parser.add_argument("--rgba", action="store_true", help="output RGBA image instead of a grayscale mask")
    parser.add_argument("--skip-clean", action="store_true", help="skip morphological mask cleaning")
    args = parser.parse_args()
    os.makedirs(args.output, exist_ok=True)

    IMAGE_EXT = {".png", ".jpg", ".jpeg", ".webp"}
    device = torch.device(args.device)
    model = Sam3Model.from_pretrained("facebook/sam3").to(device)
    processor = Sam3Processor.from_pretrained("facebook/sam3")

    for fn in tqdm(os.listdir(args.input)):
        if path.splitext(fn)[-1].lower() not in IMAGE_EXT:
            continue
        input_file = path.join(args.input, fn)
        with Image.open(input_file) as im:
            im.load()
            im = im.convert("RGB")

        # Text-prompted instance segmentation: one mask per detected instance
        inputs = processor(images=im, text=args.prompt, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        results = processor.post_process_instance_segmentation(
            outputs,
            threshold=args.threshold,
            mask_threshold=args.threshold,
            target_sizes=inputs.get("original_sizes").tolist()
        )[0]

        # Merge all instance masks into a single binary mask
        mask = torch.zeros((1, im.height, im.width), dtype=torch.float32).to(device)
        if len(results["masks"]) > 0:
            for m in results["masks"]:
                m = m.unsqueeze(0).float()
                if not args.skip_clean:
                    m = opening(m)  # drop small false-positive specks
                    m = closing(m)  # fill small holes inside the mask
                mask += m
        mask = mask.clamp(0, 1)

        output_filename = path.join(args.output, path.splitext(fn)[0] + ".png")
        if args.rgba:
            # RGBA image: the original image with the mask as alpha channel
            im.putalpha(TF.to_pil_image(mask))
            im.save(output_filename)
        else:
            # Grayscale mask
            TF.to_pil_image(mask).save(output_filename)


if __name__ == "__main__":
    main()
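
# Example: applying a generated grayscale mask afterwards (a sketch; the
# filenames below are hypothetical and assume the default grayscale output):
#   from PIL import Image
#   im = Image.open("test_images/sample.jpg").convert("RGB")
#   mask = Image.open("output/sample.png").convert("L")
#   im.putalpha(mask)  # equivalent to running the script with --rgba
#   im.save("sample_rgba.png")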