Chroma Diffusers Inference Code (working, but *unoptimized* - about 30s on 4090 for 1024px)
# docker run --rm -it --gpus all -v /home/USER/Desktop/chroma:/workspace -w /workspace pytorch/pytorch:2.6.0-cuda12.6-cudnn9-devel bash
# export HF_HOME=/workspace/cache/huggingface
# pip3 install diffusers==0.32.2 transformers==4.49.0 para-attn==0.3.22 pillow==11.1.0 optimum-quanto==0.2.7 scipy==1.15.2 sentencepiece==0.2.0 protobuf==6.30.1 accelerate==1.5.2
# # Apply some patches to diffusers. The first is a bugfix (the next diffusers release will include it); the rest allow negative prompts for Chroma, since diffusers assumes that positive and negative prompts are the same length: https://github.com/huggingface/diffusers/pull/11120
# DIFFUSERS_PATH=$(python -c "import diffusers; print(diffusers.__path__[0])")
# sed -i 's/do_true_cfg = true_cfg_scale > 1 and negative_prompt is not None$/do_true_cfg = true_cfg_scale > 1 and negative_prompt is not None or ( negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None )/' "$DIFFUSERS_PATH/pipelines/flux/pipeline_flux.py"
# sed -i '/ if do_true_cfg:/,$ {0,/_,/s//negative_text_ids,/}' "$DIFFUSERS_PATH/pipelines/flux/pipeline_flux.py"
# sed -i '/encoder_hidden_states=negative_prompt_embeds/,$ {0,/txt_ids=text_ids/s//txt_ids=negative_text_ids/}' "$DIFFUSERS_PATH/pipelines/flux/pipeline_flux.py"
# sed -i 's|if prompt_embeds is not None and negative_prompt_embeds is not None:|if False:|g' "$DIFFUSERS_PATH/pipelines/flux/pipeline_flux.py"
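# # Optional sanity check (my addition, not part of the original patch set): the sed edits above
# # introduce the identifier "negative_text_ids" on two lines, so a quick grep should report at
# # least 2 matches if all of the patches applied cleanly.
# grep -c "negative_text_ids" "$DIFFUSERS_PATH/pipelines/flux/pipeline_flux.py"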
# # Create a Schnell-compatible variant of Chroma by downloading both the Chroma and Schnell safetensors files and copying Chroma's matching weights over to Schnell. This works because lodestone *distilled* the guidance layers instead of completely pruning them, so we can actually just use Schnell's guidance layers. This comes at the cost of bloating the model back to Schnell's original size, but it's probably the easiest approach for diffusers compatibility for now.
# CHROMA_VERSION="15"
# apt-get update && apt-get install aria2 -y && pip3 install safetensors
# cd /workspace
# aria2c -x 16 -s 16 -o chroma.safetensors "https://huggingface.co/lodestones/Chroma/resolve/main/chroma-unlocked-v${CHROMA_VERSION}.safetensors?download=true"
# aria2c -x 16 -s 16 -o flux1-schnell.safetensors "https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/flux1-schnell.safetensors?download=true"
# python3 -c '
# from safetensors import safe_open
# from safetensors.torch import save_file
# with safe_open("/workspace/chroma.safetensors", framework="pt", device="cpu") as chroma, safe_open("/workspace/flux1-schnell.safetensors", framework="pt", device="cpu") as schnell:
#     chroma_tensors = {key: chroma.get_tensor(key) for key in chroma.keys()}
#     schnell_tensors = {key: schnell.get_tensor(key) for key in schnell.keys()}
#     matching_keys = set(chroma_tensors).intersection(schnell_tensors)
#     for key in matching_keys:
#         schnell_tensors[key] = chroma_tensors[key]
#     save_file(schnell_tensors, "/workspace/chroma-schnell-compat.safetensors")
# '
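# # Optional sanity check (my addition): confirm the merged file was written, and count how many
# # keys came from Chroma vs. were kept from Schnell (the leftovers should be the guidance layers
# # that Chroma pruned).
# python3 -c '
# from safetensors import safe_open
# with safe_open("/workspace/chroma-schnell-compat.safetensors", framework="pt", device="cpu") as merged, safe_open("/workspace/chroma.safetensors", framework="pt", device="cpu") as chroma:
#     merged_keys, chroma_keys = set(merged.keys()), set(chroma.keys())
#     print(len(merged_keys & chroma_keys), "keys copied from Chroma")
#     print(len(merged_keys - chroma_keys), "keys kept from Schnell")
# '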
import torch
from diffusers import FluxTransformer2DModel, FluxPipeline, FlowMatchEulerDiscreteScheduler
from transformers import T5EncoderModel, T5Tokenizer
from optimum.quanto import freeze, qfloat8, quantize
# Example output image, with Chroma v15, using the prompt/settings below:
# https://github.com/user-attachments/assets/89158f03-b2cf-48ca-87e9-816a57007f12
# Tip: As of writing, Chroma is still training and has not undergone aesthetic fine-tuning, so you should add aesthetics-related keywords to your prompt to get highly aesthetic images.
# For example, here's a "maximum aesthetics" prompt template I've been using:
# "An artistically pleasing image. {YOUR PROMPT HERE}. The artist's skill is evident, and the piece is very well-executed with high aesthetics; clearly a masterpiece-level work of art. The image is high-quality, with impressive and exquisite attention to detail. The image is visually captivating and the quality of the background is impressive, and the composition is excellent. The level of detail is remarkable. High-quality, please."
prompt = "An aesthetically pleasing digital painting of a cat, holding a sign that says \"Chroma\". It has a charmingly painterly style, with visible brush strokes."
negative_prompt = ""
image_width = 1024
image_height = 1024
seed = 6
cfg = 5
num_inference_steps = 20
# Load the merged Chroma/Schnell transformer and quantize its weights to 8-bit float, roughly
# halving weight memory vs. bfloat16. freeze() bakes the quantized weights in place.
transformer = FluxTransformer2DModel.from_single_file("chroma-schnell-compat.safetensors", torch_dtype=torch.bfloat16)
quantize(transformer, weights=qfloat8)
freeze(transformer)
transformer.to("cuda")
# Load the T5 tokenizer and encoder outside of the pipeline because we need to apply custom truncation.
# The encoder stays on the CPU, and is also quantized to qfloat8 to save memory.
t5_tokenizer = T5Tokenizer.from_pretrained("google/t5-v1_1-xxl")
t5_encoder = T5EncoderModel.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="text_encoder_2").to("cpu")
quantize(t5_encoder, weights=qfloat8)
freeze(t5_encoder)
def embed_prompt(prompt):
    prompt_tokens = t5_tokenizer(prompt, padding="max_length", max_length=512, truncation=True, return_tensors="pt")
    prompt_embeds = t5_encoder(prompt_tokens.input_ids, attention_mask=prompt_tokens.attention_mask)[0]
    max_len = min(prompt_tokens.attention_mask.sum() + 1, 512)
    prompt_embeds = prompt_embeds[:, :max_len]  # Truncate to prompt length + 1 (i.e. leave one padding token at the end)
    return prompt_embeds
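# Illustrative shape note (my addition): T5-XXL embeddings are [batch, seq_len, 4096], and the
# truncation above makes seq_len variable per prompt (token count + 1 pad) rather than a fixed 512.
# assert embed_prompt("a cat").shape[-1] == 4096  # uncomment to sanity-check the encoder output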
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    transformer=None,     # supplied separately below (the quantized Chroma-compat transformer)
    text_encoder_2=None,  # T5 is handled manually above so we can apply custom truncation
    torch_dtype=torch.bfloat16,
    scheduler=FlowMatchEulerDiscreteScheduler(use_dynamic_shifting=True, base_shift=0.5, max_shift=1.15, use_beta_sigmas=True),
)
pipe.transformer = transformer
pipe.to("cuda")
pipe.enable_model_cpu_offload()
image = pipe(
    prompt_embeds=embed_prompt(prompt),
    negative_prompt_embeds=embed_prompt(negative_prompt),
    # Zeros for CLIP (Chroma doesn't use the CLIP text encoder, but the pipeline expects pooled embeds):
    pooled_prompt_embeds=torch.zeros(1, 768, device="cuda"),
    negative_pooled_prompt_embeds=torch.zeros(1, 768, device="cuda"),
    num_inference_steps=num_inference_steps,
    true_cfg_scale=cfg,  # true classifier-free guidance via a separate negative-prompt pass (enabled by the patches above)
    guidance_scale=0,    # no distilled-guidance conditioning; true CFG handles guidance instead
    width=image_width,
    height=image_height,
    generator=torch.Generator("cuda").manual_seed(seed),
).images[0]
image.save(f"chroma_seed_{seed}.png")
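# Optional seed sweep (my addition; a sketch reusing the pipeline and settings above). The text
# embeddings don't depend on the seed, so compute them once and reuse them across generations:
# prompt_embeds = embed_prompt(prompt)
# negative_embeds = embed_prompt(negative_prompt)
# for s in range(4):
#     pipe(
#         prompt_embeds=prompt_embeds,
#         negative_prompt_embeds=negative_embeds,
#         pooled_prompt_embeds=torch.zeros(1, 768, device="cuda"),
#         negative_pooled_prompt_embeds=torch.zeros(1, 768, device="cuda"),
#         num_inference_steps=num_inference_steps,
#         true_cfg_scale=cfg,
#         guidance_scale=0,
#         width=image_width,
#         height=image_height,
#         generator=torch.Generator("cuda").manual_seed(s),
#     ).images[0].save(f"chroma_seed_{s}.png")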