@RageshAntonyHM
Created December 29, 2024 16:54
mutli-gen-huggging.py
import gradio as gr
import numpy as np
import random
import torch
from diffusers import (
    DiffusionPipeline, FluxPipeline, PixArtSigmaPipeline,
    AuraFlowPipeline, Kandinsky3Pipeline, HunyuanDiTPipeline,
    LuminaText2ImgPipeline
)
import spaces
import gc
import os
import psutil
import threading
from pathlib import Path
import shutil
import time
# Constants
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TORCH_DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32
# Model configurations
MODEL_CONFIGS = {
"Stable Diffusion 3.5": {
"repo_id": "stabilityai/stable-diffusion-3.5-large",
"pipeline_class": DiffusionPipeline
},
"FLUX": {
"repo_id": "black-forest-labs/FLUX.1-dev",
"pipeline_class": FluxPipeline
},
"PixArt": {
"repo_id": "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",
"pipeline_class": PixArtSigmaPipeline
},
"AuraFlow": {
"repo_id": "fal/AuraFlow",
"pipeline_class": AuraFlowPipeline
},
"Kandinsky": {
"repo_id": "kandinsky-community/kandinsky-3",
"pipeline_class": Kandinsky3Pipeline
},
"Hunyuan": {
"repo_id": "Tencent-Hunyuan/HunyuanDiT-Diffusers",
"pipeline_class": HunyuanDiTPipeline
},
"Lumina": {
"repo_id": "Alpha-VLLM/Lumina-Next-SFT-diffusers",
"pipeline_class": LuminaText2ImgPipeline
}
}
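
# Adding another backend is just a matter of appending an entry to MODEL_CONFIGS.
# For example (hypothetical, not part of the original gist):
#
#   MODEL_CONFIGS["SDXL"] = {
#       "repo_id": "stabilityai/stable-diffusion-xl-base-1.0",
#       "pipeline_class": DiffusionPipeline,
#   }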
# Dictionary to store model pipelines
pipes = {}
model_locks = {model_name: threading.Lock() for model_name in MODEL_CONFIGS.keys()}
def get_process_memory():
"""Get memory usage of current process in GB"""
process = psutil.Process(os.getpid())
return process.memory_info().rss / 1024 / 1024 / 1024
def clear_torch_cache():
"""Clear PyTorch's CUDA cache"""
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
def remove_cache_dir(model_name):
"""Remove the model's cache directory"""
cache_dir = Path.home() / '.cache' / 'huggingface' / 'diffusers' / MODEL_CONFIGS[model_name]['repo_id'].replace('/', '--')
if cache_dir.exists():
shutil.rmtree(cache_dir, ignore_errors=True)
def deep_cleanup(model_name, pipe):
"""Perform deep cleanup of model resources"""
try:
# 1. Move model to CPU first (helps prevent CUDA memory fragmentation)
if hasattr(pipe, 'to'):
pipe.to('cpu')
# 2. Delete all model components explicitly
for attr_name in list(pipe.__dict__.keys()):
if hasattr(pipe, attr_name):
delattr(pipe, attr_name)
# 3. Remove from pipes dictionary
if model_name in pipes:
del pipes[model_name]
# 4. Clear CUDA cache
clear_torch_cache()
# 5. Run garbage collection multiple times
for _ in range(3):
gc.collect()
# 6. Remove cached files
remove_cache_dir(model_name)
# 7. Additional CUDA cleanup if available
if torch.cuda.is_available():
torch.cuda.synchronize()
# 8. Wait a small amount of time to ensure cleanup
time.sleep(1)
except Exception as e:
print(f"Error during cleanup of {model_name}: {str(e)}")
finally:
# Final garbage collection
gc.collect()
clear_torch_cache()
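
# Illustrative sanity check (a sketch, not called anywhere in the app): compare
# process RAM and allocated CUDA memory before and after deep_cleanup to confirm
# that a pipeline's resources are actually released. "FLUX" is just an example key.
#
#   before_ram = get_process_memory()
#   before_vram = torch.cuda.memory_allocated() / 1024 ** 3
#   deep_cleanup("FLUX", pipes["FLUX"])
#   after_ram = get_process_memory()
#   after_vram = torch.cuda.memory_allocated() / 1024 ** 3
#   print(f"RAM {before_ram:.2f} -> {after_ram:.2f} GB, VRAM {before_vram:.2f} -> {after_vram:.2f} GB")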
def load_pipeline(model_name):
"""Load model pipeline with memory tracking"""
initial_memory = get_process_memory()
config = MODEL_CONFIGS[model_name]
pipe = config["pipeline_class"].from_pretrained(
config["repo_id"],
torch_dtype=TORCH_DTYPE
)
pipe = pipe.to(DEVICE)
if hasattr(pipe, 'enable_model_cpu_offload'):
pipe.enable_model_cpu_offload()
final_memory = get_process_memory()
print(f"Memory used by {model_name}: {final_memory - initial_memory:.2f} GB")
return pipe
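
# Illustrative standalone usage (a sketch, assuming enough GPU memory for the chosen
# model): load one pipeline, generate a single image, then release it. The prompt and
# output filename are placeholders.
#
#   pipe = load_pipeline("PixArt")
#   image = pipe(prompt="A watercolor fox in a forest", num_inference_steps=20).images[0]
#   image.save("pixart_sample.png")
#   deep_cleanup("PixArt", pipe)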
#@spaces.GPU(duration=180)
def generate_image(
    model_name,
    prompt,
    negative_prompt="",
    seed=42,
    randomize_seed=False,
    width=1024,
    height=1024,
    guidance_scale=4.5,
    num_inference_steps=40,
    progress=gr.Progress(track_tqdm=True)
):
    with model_locks[model_name]:
        try:
            progress(0, desc=f"Loading {model_name} model...")

            # Load model if not already loaded
            if model_name not in pipes:
                pipes[model_name] = load_pipeline(model_name)
            pipe = pipes[model_name]

            if randomize_seed:
                seed = random.randint(0, MAX_SEED)
            generator = torch.Generator(DEVICE).manual_seed(seed)

            progress(0.3, desc=f"Generating image with {model_name}...")

            # Generate image
            image = pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                width=width,
                height=height,
                generator=generator,
            ).images[0]

            progress(0.9, desc=f"Cleaning up {model_name} resources...")

            # Cleanup after generation
            deep_cleanup(model_name, pipe)

            progress(1.0, desc=f"Generation complete with {model_name}")
            return image, seed

        except Exception as e:
            print(f"Error with {model_name}: {str(e)}")
            # Ensure cleanup happens even if generation fails
            if model_name in pipes:
                deep_cleanup(model_name, pipes[model_name])
            raise e
# Gradio Interface
css = """
#col-container {
    margin: 0 auto;
    max-width: 1024px;
}
"""
#run_test_safe.zerogpu = True
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# Multi-Model Image Generation")

        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Generate", scale=0, variant="primary")

        with gr.Accordion("Advanced Settings", open=False):
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=1,
                placeholder="Enter a negative prompt",
            )
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=512,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=512,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )

            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=7.5,
                    step=0.1,
                    value=4.5,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=40,
                )

        # Memory usage indicator
        memory_indicator = gr.Markdown("Current memory usage: 0 GB")

        # Create tabs for each model
        with gr.Tabs() as tabs:
            results = {}
            seeds = {}
            for model_name in MODEL_CONFIGS.keys():
                with gr.Tab(model_name):
                    results[model_name] = gr.Image(label=f"{model_name} Result")
                    seeds[model_name] = gr.Number(label="Seed used", visible=False)

        examples = [
            "A capybara wearing a suit holding a sign that reads Hello World",
            "A serene landscape with mountains and a lake at sunset",
        ]
        gr.Examples(examples=examples, inputs=[prompt])
    def update_memory_usage():
        """Update memory usage display"""
        memory_gb = get_process_memory()
        if torch.cuda.is_available():
            cuda_memory_gb = torch.cuda.memory_allocated() / 1024 / 1024 / 1024
            return f"Current memory usage: System RAM: {memory_gb:.2f} GB, CUDA: {cuda_memory_gb:.2f} GB"
        return f"Current memory usage: System RAM: {memory_gb:.2f} GB"
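
    # One way to populate the memory indicator when the page loads (illustrative;
    # kept commented out since exact event wiring can vary across Gradio versions):
    # demo.load(update_memory_usage, inputs=None, outputs=memory_indicator)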
    # Handle generation for each model
    @spaces.GPU(duration=180)
    def generate_all(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, progress=gr.Progress()):
        outputs = []
        for model_name in MODEL_CONFIGS.keys():
            try:
                image, used_seed = generate_image(
                    model_name, prompt, negative_prompt, seed,
                    randomize_seed, width, height, guidance_scale,
                    num_inference_steps, progress
                )
                outputs.extend([image, used_seed])
                # Log memory usage after each model; calling memory_indicator.update()
                # from inside a handler does not refresh the UI, so print it instead
                print(update_memory_usage())
            except Exception as e:
                outputs.extend([None, None])
                print(f"Error generating with {model_name}: {str(e)}")
        return outputs
    # Set up the generation trigger
    output_components = []
    for model_name in MODEL_CONFIGS.keys():
        output_components.extend([results[model_name], seeds[model_name]])

    run_button.click(
        fn=generate_all,
        inputs=[
            prompt,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=output_components,
    )
if __name__ == "__main__":
    demo.launch()