Created
December 3, 2023 17:26
-
-
Save city96/103c394ef9cf9300aca67d1c2a2d28b5 to your computer and use it in GitHub Desktop.
ComfyUI anime segmentation custom node
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# | |
# Simple custom node to segment anime images using https://github.com/SkyTNT/anime-segmentation | |
# To install the custom node, copy this file to your `ComfyUI/custom_nodes` folder | |
# To install the requirements, run `pip install onnxruntime huggingface-hub` inside your VENV | |
# If using the standalone, navigate to the folder where your .bat file is and run this command: | |
# .\python_embeded\python.exe -m pip install onnxruntime huggingface-hub | |
# | |
import torch | |
import numpy as np | |
from PIL import Image | |
from onnxruntime import InferenceSession | |
from huggingface_hub import hf_hub_download | |
class AnimeSeg: | |
def __init__(self): | |
pass | |
@classmethod | |
def INPUT_TYPES(s): | |
return { | |
"required": { | |
"image": ("IMAGE",), | |
} | |
} | |
RETURN_TYPES = ("MASK",) | |
RETURN_NAMES = ("mask",) | |
FUNCTION = "segment" | |
CATEGORY = "bootleg" | |
TITLE = "Anime Segmentation" | |
def get_mask(self, src): | |
raw_size = src.size | |
seg_size = self.session.get_inputs()[0].shape[2:4] | |
src.thumbnail(seg_size) | |
dst_size = int((seg_size[0]-src.size[0])/2), int((seg_size[1]-src.size[1])/2) | |
img = Image.new('RGB',size=seg_size,color=(0,0,0)) | |
img.paste(src, dst_size) | |
img = np.array(img) | |
img = img[:, :, ::-1] # PIL RGB to OpenCV BGR | |
img = img.transpose((2, 0, 1)) # N, C, W, H | |
img = img.astype(np.float32) / 255.0 | |
img = np.expand_dims(img, 0) | |
in_name = self.session.get_inputs()[0].name | |
out_name = self.session.get_outputs()[0].name | |
mask = self.session.run([out_name], {in_name: img})[0] | |
mask = torch.clamp(torch.from_numpy(mask), 0.0, 1.0).transpose(3, 2) | |
mask = mask[:, :, dst_size[0]:(seg_size[0]-dst_size[0]), dst_size[1]:seg_size[1]-dst_size[1]] | |
mask = torch.nn.functional.interpolate(mask, raw_size, mode="bilinear") | |
return mask.transpose(3, 2) | |
def segment(self, image): | |
self.session = InferenceSession( | |
str(hf_hub_download(repo_id="skytnt/anime-seg", filename="isnetis.onnx")), | |
providers=["CPUExecutionProvider"] | |
) | |
img = Image.fromarray((image[0] * 255.0).to(torch.uint8).numpy(), mode='RGB') | |
mask = self.get_mask(img) | |
del self.session | |
return (mask,) | |
NODE_CLASS_MAPPINGS = { "SimpleAnimeSeg": AnimeSeg} | |
NODE_DISPLAY_NAME_MAPPINGS = {"SimpleAnimeSeg": AnimeSeg.TITLE} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment