mutli-gen-huggging.py
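# Dependency sketch (an assumption; the original gist pins nothing):
#   pip install gradio torch diffusers transformers accelerate sentencepiece psutil spaces
# `spaces` is the Hugging Face Spaces helper; the @spaces.GPU decorator below
# only has an effect when the script runs on a ZeroGPU Space.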
import gradio as gr
import numpy as np
import random
import torch
from diffusers import (
    DiffusionPipeline, FluxPipeline, PixArtSigmaPipeline,
    AuraFlowPipeline, Kandinsky3Pipeline, HunyuanDiTPipeline,
    LuminaText2ImgPipeline
)
import spaces
import gc
import inspect
import os
import psutil
import threading
from pathlib import Path
import shutil
import time
# Constants
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TORCH_DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32
# Model configurations
MODEL_CONFIGS = {
    "Stable Diffusion 3.5": {
        "repo_id": "stabilityai/stable-diffusion-3.5-large",
        "pipeline_class": DiffusionPipeline
    },
    "FLUX": {
        "repo_id": "black-forest-labs/FLUX.1-dev",
        "pipeline_class": FluxPipeline
    },
    "PixArt": {
        "repo_id": "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",
        "pipeline_class": PixArtSigmaPipeline
    },
    "AuraFlow": {
        "repo_id": "fal/AuraFlow",
        "pipeline_class": AuraFlowPipeline
    },
    "Kandinsky": {
        "repo_id": "kandinsky-community/kandinsky-3",
        "pipeline_class": Kandinsky3Pipeline
    },
    "Hunyuan": {
        "repo_id": "Tencent-Hunyuan/HunyuanDiT-Diffusers",
        "pipeline_class": HunyuanDiTPipeline
    },
    "Lumina": {
        "repo_id": "Alpha-VLLM/Lumina-Next-SFT-diffusers",
        "pipeline_class": LuminaText2ImgPipeline
    }
}
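# Access note (an assumption based on the current Hugging Face listings):
# several of these repos, e.g. stabilityai/stable-diffusion-3.5-large and
# black-forest-labs/FLUX.1-dev, are gated — from_pretrained only works after
# accepting their licenses and authenticating via `huggingface-cli login`.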
# Dictionary to store model pipelines, plus one lock per model so two
# requests never load/clean the same pipeline concurrently
pipes = {}
model_locks = {model_name: threading.Lock() for model_name in MODEL_CONFIGS.keys()}
def get_process_memory():
    """Get memory usage of the current process in GB"""
    process = psutil.Process(os.getpid())
    return process.memory_info().rss / 1024 / 1024 / 1024
def clear_torch_cache():
    """Clear PyTorch's CUDA cache"""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
def remove_cache_dir(model_name):
    """Remove the model's download cache (hub layout: models--{org}--{name})"""
    repo_id = MODEL_CONFIGS[model_name]['repo_id']
    cache_dir = Path.home() / '.cache' / 'huggingface' / 'hub' / ('models--' + repo_id.replace('/', '--'))
    if cache_dir.exists():
        shutil.rmtree(cache_dir, ignore_errors=True)
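# Design note: deleting the cache trades disk space for bandwidth — every
# later generation with a given model re-downloads its full weights (tens of
# GB for some of the repos above).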
def deep_cleanup(model_name, pipe):
    """Perform deep cleanup of model resources"""
    try:
        # 1. Move model to CPU first (helps prevent CUDA memory fragmentation)
        if hasattr(pipe, 'to'):
            pipe.to('cpu')
        # 2. Delete all model components explicitly
        for attr_name in list(pipe.__dict__.keys()):
            if hasattr(pipe, attr_name):
                delattr(pipe, attr_name)
        # 3. Remove from pipes dictionary
        if model_name in pipes:
            del pipes[model_name]
        # 4. Clear CUDA cache
        clear_torch_cache()
        # 5. Run garbage collection multiple times
        for _ in range(3):
            gc.collect()
        # 6. Remove cached files
        remove_cache_dir(model_name)
        # 7. Additional CUDA cleanup if available
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        # 8. Wait a small amount of time to ensure cleanup
        time.sleep(1)
    except Exception as e:
        print(f"Error during cleanup of {model_name}: {str(e)}")
    finally:
        # Final garbage collection
        gc.collect()
        clear_torch_cache()
def load_pipeline(model_name):
    """Load model pipeline with memory tracking"""
    initial_memory = get_process_memory()
    config = MODEL_CONFIGS[model_name]
    pipe = config["pipeline_class"].from_pretrained(
        config["repo_id"],
        torch_dtype=TORCH_DTYPE
    )
    # enable_model_cpu_offload manages device placement itself; moving the
    # pipeline to CUDA first would defeat the offloading, so only call
    # .to(DEVICE) when offloading is unavailable.
    if hasattr(pipe, 'enable_model_cpu_offload') and DEVICE == "cuda":
        pipe.enable_model_cpu_offload()
    else:
        pipe = pipe.to(DEVICE)
    final_memory = get_process_memory()
    print(f"Memory used by {model_name}: {final_memory - initial_memory:.2f} GB")
    return pipe
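# Minimal usage sketch (assumes the weights are accessible and fit in memory):
#   pipe = load_pipeline("PixArt")
#   image = pipe("a red cube on a table", num_inference_steps=20).images[0]
#   image.save("out.png")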
#@spaces.GPU(duration=180)
def generate_image(
    model_name,
    prompt,
    negative_prompt="",
    seed=42,
    randomize_seed=False,
    width=1024,
    height=1024,
    guidance_scale=4.5,
    num_inference_steps=40,
    progress=gr.Progress(track_tqdm=True)
):
    with model_locks[model_name]:
        try:
            progress(0, desc=f"Loading {model_name} model...")

            # Load model if not already loaded
            if model_name not in pipes:
                pipes[model_name] = load_pipeline(model_name)
            pipe = pipes[model_name]

            if randomize_seed:
                seed = random.randint(0, MAX_SEED)
            generator = torch.Generator(DEVICE).manual_seed(seed)

            progress(0.3, desc=f"Generating image with {model_name}...")
            # Not every pipeline accepts the same arguments (FLUX, for one,
            # has no negative_prompt), so pass only the kwargs each
            # pipeline's __call__ actually supports.
            call_kwargs = {
                "prompt": prompt,
                "negative_prompt": negative_prompt,
                "guidance_scale": guidance_scale,
                "num_inference_steps": num_inference_steps,
                "width": width,
                "height": height,
                "generator": generator,
            }
            supported = inspect.signature(pipe.__call__).parameters
            call_kwargs = {k: v for k, v in call_kwargs.items() if k in supported}

            # Generate image
            image = pipe(**call_kwargs).images[0]

            progress(0.9, desc=f"Cleaning up {model_name} resources...")
            # Cleanup after generation
            deep_cleanup(model_name, pipe)

            progress(1.0, desc=f"Generation complete with {model_name}")
            return image, seed
        except Exception as e:
            print(f"Error with {model_name}: {str(e)}")
            # Ensure cleanup happens even if generation fails
            if model_name in pipes:
                deep_cleanup(model_name, pipes[model_name])
            raise
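# Usage sketch outside the UI (an assumption — a CUDA machine and a default
# gr.Progress that tolerates running without an active Gradio request):
#   img, used_seed = generate_image("Lumina", "a watercolor fox", seed=0)
#   img.save("lumina_fox.png")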
# Gradio Interface
css = """
#col-container {
    margin: 0 auto;
    max-width: 1024px;
}
"""

#run_test_safe.zerogpu = True
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# Multi-Model Image Generation")

        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Generate", scale=0, variant="primary")

        with gr.Accordion("Advanced Settings", open=False):
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=1,
                placeholder="Enter a negative prompt",
            )
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=512,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=512,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=7.5,
                    step=0.1,
                    value=4.5,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=40,
                )

        # Memory usage indicator
        memory_indicator = gr.Markdown("Current memory usage: 0 GB")

        # Create tabs for each model
        with gr.Tabs() as tabs:
            results = {}
            seeds = {}
            for model_name in MODEL_CONFIGS.keys():
                with gr.Tab(model_name):
                    results[model_name] = gr.Image(label=f"{model_name} Result")
                    seeds[model_name] = gr.Number(label="Seed used", visible=False)

        examples = [
            "A capybara wearing a suit holding a sign that reads Hello World",
            "A serene landscape with mountains and a lake at sunset",
        ]
        gr.Examples(examples=examples, inputs=[prompt])
    def update_memory_usage():
        """Update memory usage display"""
        memory_gb = get_process_memory()
        if torch.cuda.is_available():
            cuda_memory_gb = torch.cuda.memory_allocated() / 1024 / 1024 / 1024
            return f"Current memory usage: System RAM: {memory_gb:.2f} GB, CUDA: {cuda_memory_gb:.2f} GB"
        return f"Current memory usage: System RAM: {memory_gb:.2f} GB"
    # Handle generation for each model
    @spaces.GPU(duration=180)
    def generate_all(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, progress=gr.Progress()):
        outputs = []
        for model_name in MODEL_CONFIGS.keys():
            try:
                image, used_seed = generate_image(
                    model_name, prompt, negative_prompt, seed,
                    randomize_seed, width, height, guidance_scale,
                    num_inference_steps, progress
                )
                outputs.extend([image, used_seed])
            except Exception as e:
                outputs.extend([None, None])
                print(f"Error generating with {model_name}: {str(e)}")
        # Components cannot be updated imperatively from inside a handler, so
        # the refreshed memory readout is returned as the final output value
        # and wired to memory_indicator below.
        outputs.append(update_memory_usage())
        return outputs
    # Set up the generation trigger
    output_components = []
    for model_name in MODEL_CONFIGS.keys():
        output_components.extend([results[model_name], seeds[model_name]])
    # Last output slot receives the memory-usage text from generate_all
    output_components.append(memory_indicator)

    run_button.click(
        fn=generate_all,
        inputs=[
            prompt,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=output_components,
    )
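    # Note (an assumption): for multi-minute runs it is common to enable
    # Gradio's queue, e.g. `demo.queue(max_size=8)` before launch, so
    # concurrent clicks wait their turn instead of timing out.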
if __name__ == "__main__":
    demo.launch()
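# Running locally (sketch): `python mutli-gen-huggging.py` serves the UI on
# http://127.0.0.1:7860 by default. On a ZeroGPU Space, @spaces.GPU allocates
# a GPU for at most 180 s per call (the duration set above).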