Created
March 21, 2025 00:48
-
-
Save venetanji/ad4e866b74e5ab97837612cc140c0d5e to your computer and use it in GitHub Desktop.
ComfyUI tool for CrewAI
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Type | |
import websocket #NOTE: websocket-client (https://github.com/websocket-client/websocket-client) | |
import uuid | |
import json | |
import urllib.request | |
import urllib.parse | |
from crewai.tools import BaseTool | |
from pydantic import BaseModel, Field | |
from datetime import datetime | |
from newstoimage.texts import comfy_flow_json | |
# host:port of the server that every HTTP request and the websocket connection
# in this module are sent to.
server_address = "localhost:8188"
class ImgGenParams(BaseModel):
    """Input schema for ImageGeneration tool."""
    # Required text prompt (Field(...) declares no default). ImageGeneration
    # uses this model as its args_schema and writes the prompt into the
    # workflow's text input.
    prompt: str = Field(..., description="prompt for image generation")
def queue_prompt(prompt, client_id):
    """Submit a workflow to the server's ``/prompt`` endpoint.

    Args:
        prompt: JSON-serialisable workflow; sent under the ``"prompt"`` key.
        client_id: Identifier sent under the ``"client_id"`` key.

    Returns:
        The decoded JSON body of the HTTP response.

    Raises:
        urllib.error.URLError: If the request cannot be completed.
    """
    payload = json.dumps({"prompt": prompt, "client_id": client_id}).encode('utf-8')
    # Supplying ``data`` makes urllib issue a POST request.
    req = urllib.request.Request("http://{}/prompt".format(server_address), data=payload)
    # Use the response as a context manager so the HTTP connection is always
    # released, the same way get_image() and get_history() do.
    with urllib.request.urlopen(req) as response:
        return json.loads(response.read())
def get_image(filename, subfolder, folder_type):
    """Fetch one file from the server's ``/view`` endpoint.

    The three arguments are URL-encoded as the ``filename``, ``subfolder`` and
    ``type`` query parameters.

    Returns:
        The raw response body as bytes.
    """
    query = urllib.parse.urlencode(
        {"filename": filename, "subfolder": subfolder, "type": folder_type}
    )
    url = "http://{}/view?{}".format(server_address, query)
    with urllib.request.urlopen(url) as resp:
        body = resp.read()
    return body
def get_history(prompt_id):
    """Return the decoded JSON served at ``/history/<prompt_id>``."""
    url = "http://{}/history/{}".format(server_address, prompt_id)
    with urllib.request.urlopen(url) as resp:
        raw = resp.read()
    return json.loads(raw)
def get_images(ws, prompt, client_id):
    """Queue a workflow and collect the binary frames it streams back.

    Args:
        ws: Connected websocket whose ``recv()`` yields ``str`` status
            messages and non-``str`` (binary) frames.
        prompt: Workflow passed on to ``queue_prompt``.
        client_id: Client identifier passed on to ``queue_prompt``.

    Returns:
        A dict mapping a node id to the list of binary payloads received
        while that node was executing. Only frames that arrive while the
        current node is ``'save_image_websocket_node'`` are kept.
    """
    prompt_id = queue_prompt(prompt, client_id)['prompt_id']
    collected = {}
    active_node = ""
    while True:
        frame = ws.recv()

        if not isinstance(frame, str):
            # Binary frame: keep it only when it comes from the websocket
            # image-saving node. The first 8 bytes are dropped — presumably a
            # protocol header; confirm against the server's websocket format.
            if active_node == 'save_image_websocket_node':
                collected.setdefault(active_node, []).append(frame[8:])
            continue

        message = json.loads(frame)
        if message['type'] != 'executing':
            continue

        data = message['data']
        if data['prompt_id'] != prompt_id:
            continue  # status update for a different queued prompt

        if data['node'] is None:
            break  # a null node marks the end of execution for our prompt
        active_node = data['node']

    return collected
# Tool that sends a text prompt through the workflow stored in
# ``comfy_flow_json`` and displays every image streamed back over the websocket.
class ImageGeneration(BaseTool):
    name: str = "ImageGeneration tool"
    description: str = (
        "This tool generates images starting from a text prompt"
    )
    args_schema: Type[BaseModel] = ImgGenParams

    def _run(self, prompt: str) -> str:
        """Generate images for ``prompt`` and open each one in an image viewer.

        Args:
            prompt: Text written into the ``text`` input of workflow node "6".

        Returns:
            A short summary stating how many images were generated and shown.
        """
        client_id = str(uuid.uuid4())
        comfy_flow = json.loads(comfy_flow_json)

        # Node "6" is the positive CLIPTextEncode node of the workflow.
        comfy_flow["6"]["inputs"]["text"] = prompt
        # To pin the KSampler seed, set comfy_flow["31"]["inputs"]["seed"].

        ws = websocket.WebSocket()
        ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id))
        try:
            images = get_images(ws, comfy_flow, client_id)
        finally:
            # Always close the socket, even when generation fails: the tool
            # may be called repeatedly and leaked connections lead to
            # connection timeouts.
            ws.close()

        shown = 0
        if images:
            # Imported lazily (once, not per image) so Pillow is only needed
            # when there is something to display.
            import io
            from PIL import Image

            for node_images in images.values():
                for image_data in node_images:
                    Image.open(io.BytesIO(image_data)).show()
                    shown += 1

        # The signature promises a string; report the outcome to the caller
        # instead of implicitly returning None.
        return "Generated and displayed {} image(s).".format(shown)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment