@venetanji
Created March 21, 2025 00:48
ComfyUI tool for CrewAI
from typing import Type
import websocket #NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
import uuid
import json
import urllib.request
import urllib.parse
from crewai.tools import BaseTool
from pydantic import BaseModel, Field
from datetime import datetime
from newstoimage.texts import comfy_flow_json

server_address = "localhost:8188"
class ImgGenParams(BaseModel):
    """Input schema for ImageGeneration tool."""
    prompt: str = Field(..., description="prompt for image generation")

def queue_prompt(prompt, client_id):
    p = {"prompt": prompt, "client_id": client_id}
    data = json.dumps(p).encode('utf-8')
    req = urllib.request.Request("http://{}/prompt".format(server_address), data=data)
    return json.loads(urllib.request.urlopen(req).read())

def get_image(filename, subfolder, folder_type):
    data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
    url_values = urllib.parse.urlencode(data)
    with urllib.request.urlopen("http://{}/view?{}".format(server_address, url_values)) as response:
        return response.read()

def get_history(prompt_id):
    with urllib.request.urlopen("http://{}/history/{}".format(server_address, prompt_id)) as response:
        return json.loads(response.read())

def get_images(ws, prompt, client_id):
    prompt_id = queue_prompt(prompt, client_id)['prompt_id']
    output_images = {}
    current_node = ""
    while True:
        out = ws.recv()
        if isinstance(out, str):
            message = json.loads(out)
            if message['type'] == 'executing':
                data = message['data']
                if data['prompt_id'] == prompt_id:
                    if data['node'] is None:
                        break  # execution is done
                    else:
                        current_node = data['node']
        else:
            # Binary frames carry the image bytes; skip the 8-byte ComfyUI header
            if current_node == 'save_image_websocket_node':
                images_output = output_images.get(current_node, [])
                images_output.append(out[8:])
                output_images[current_node] = images_output
    return output_images

class ImageGeneration(BaseTool):
    name: str = "ImageGeneration tool"
    description: str = (
        "This tool generates images starting from a text prompt"
    )
    args_schema: Type[BaseModel] = ImgGenParams

    def _run(self, prompt: str) -> str:
        # style = " black ink outline and cross-hatching style with vibrant cel-shaded coloring. Dramatic lighting and contrasts. Naturalistic features, western-style comic drawing. Realistic faces. "
        client_id = str(uuid.uuid4())
        comfy_flow = json.loads(comfy_flow_json)
        # Set the text prompt on the positive CLIPTextEncode node
        comfy_flow["6"]["inputs"]["text"] = prompt  # + style
        # Optionally fix the seed on the KSampler node:
        # comfy_flow["31"]["inputs"]["seed"] = 5
        ws = websocket.WebSocket()
        ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id))
        images = get_images(ws, comfy_flow, client_id)
        # Close the socket so repeated calls (e.g. from a Gradio app) don't hit random connection timeouts
        ws.close()
        # Display the output images
        for node_id in images:
            for image_data in images[node_id]:
                from PIL import Image
                import io
                image = Image.open(io.BytesIO(image_data))
                image.show()
                # image.save(....)
        return "Image generated and displayed."
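
A minimal usage sketch follows, assuming a CrewAI project that hands this tool to an agent. The agent role, goal, backstory, and task text are illustrative placeholders, not part of the gist; the only pieces taken from the code above are the ImageGeneration class and the expectation that a ComfyUI server is listening on localhost:8188 with a workflow (comfy_flow_json) exported in API format, whose positive CLIPTextEncode is node "6" and whose websocket output node id is 'save_image_websocket_node'.

# Hypothetical wiring into a CrewAI crew; the names below are illustrative assumptions.
from crewai import Agent, Task, Crew

illustrator = Agent(
    role="Illustrator",
    goal="Turn short text briefs into generated images",
    backstory="An artist agent that renders prompts through a local ComfyUI server.",
    tools=[ImageGeneration()],
)

draw_task = Task(
    description="Generate an image for: {brief}",
    expected_output="A short confirmation that the image was generated",
    agent=illustrator,
)

crew = Crew(agents=[illustrator], tasks=[draw_task])
crew.kickoff(inputs={"brief": "a lighthouse at dawn in heavy fog"})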