import os
from functools import wraps

import torch
from einops import rearrange
from PIL import Image
from transformers import pipeline

from flux.sampling import denoise, get_noise, get_schedule, unpack, prepare
from flux.util import embed_watermark, load_ae, load_clip, load_flow_model, load_t5

os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
os.environ["TORCHINDUCTOR_CACHE_DIR"] = "./flux-reduce-overhead-cache"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

with torch.inference_mode():
    t5 = load_t5(device, max_length=512)
    clip = load_clip(device)
    model = load_flow_model("flux-dev", device=device, hf_download=True)
    ae = load_ae("flux-dev", device=device)
    nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)

torch_device = torch.device(device)
rng = torch.Generator(device="cpu")
seed = 1
def benchmark_cuda(iters=1, skip_warmup=False):
    """Decorator that times a CUDA function and returns (result, elapsed_ms)."""
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            if not skip_warmup:
                # Warm-up run
                func(*args, **kwargs)
            torch.cuda.synchronize()
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
            start_event.record()
            for _ in range(iters):
                result = func(*args, **kwargs)
            end_event.record()
            torch.cuda.synchronize()
            elapsed_time = start_event.elapsed_time(end_event) / iters  # in milliseconds
            return result, elapsed_time
        return wrapper
    return decorator
@benchmark_cuda(iters=1, skip_warmup=True)
def get_flux_inputs(W, H, seed=1):
    with torch.inference_mode():
        x = get_noise(1, W, H, device=torch_device, dtype=torch.bfloat16, seed=seed)
        timesteps = get_schedule(20, (x.shape[-1] * x.shape[-2]) // 4, shift=True)
        inp = prepare(t5=t5, clip=clip, img=x, prompt="A cat with a hat")
    return inp, timesteps
@benchmark_cuda(iters=1, skip_warmup=True)
def benchmark_denoise(model, inp, timesteps, guidance):
    with torch.inference_mode():
        x = denoise(model, **inp, timesteps=timesteps, guidance=guidance, use_tqdm=False)
    return x
@benchmark_cuda(iters=1, skip_warmup=True)
def benchmark_ae(x, w, h):
    with torch.inference_mode():
        x = unpack(x.float(), w, h)
        with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
            x = ae.decode(x)
    return x
@benchmark_cuda(iters=1, skip_warmup=True)
def benchmark_imgproc(x):
    with torch.inference_mode():
        x = x.clamp(-1, 1)
        x = embed_watermark(x.float())
        x = rearrange(x[0], "c h w -> h w c")
    return Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
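The nsfw_classifier loaded at the top is not part of the timed pipeline; if you also want to screen the generated images the way the Flux reference demo does, a minimal sketch could look like the following (the threshold value is an assumption, not taken from this script):

NSFW_THRESHOLD = 0.85  # assumed cut-off, adjust as needed

def passes_nsfw_check(img):
    # The image-classification pipeline returns a list of {"label", "score"} dicts for a PIL image.
    nsfw_score = next((p["score"] for p in nsfw_classifier(img) if p["label"] == "nsfw"), 0.0)
    return nsfw_score < NSFW_THRESHOLD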
# For example, with W = 1344 and H = 768 you can get the input tensors (inp) as follows.
# Note that the decorated function returns (result, elapsed_ms), so unpack both:
W, H = 1344, 768
(inp, timesteps), _ = get_flux_inputs(W, H, seed=1)
Note: we observe that the shapes of the tensors inp["img"] and inp["img_ids"] vary with the image dimensions (W x H). Hence, we can use mark_dynamic on those dimensions so that torch.compile treats them as dynamic instead of recompiling for every resolution:
torch._dynamo.mark_dynamic(inp["img"], 1, min=1024, max=4096)
torch._dynamo.mark_dynamic(inp["img_ids"], 1, min=1024, max=4096)
model_compiled = torch.compile(model, mode="reduce-overhead")
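To check that the dynamic marking behaves as intended, you can optionally enable dynamo's recompilation logging before running the loop below. This is only a diagnostic aid and assumes the torch._logging API available in recent PyTorch releases; it is not part of the original benchmark:

# Optional diagnostic: log graph recompilations and their reasons.
# With the sequence dimension marked dynamic, changing the resolution
# should not trigger a new compilation after the initial ones.
torch._logging.set_logs(recompiles=True)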
widths = [1408, 1344, 1344, 1280, 1216, 1152, 1152, 1088, 1088, 1024, 1024, 960, 960, 896, 896, 832, 832, 768, 768, 704, 704, 640, 640, 576, 576, 512]
heights = [704, 704, 768, 768, 832, 832, 896, 896, 960, 960, 1024, 1024, 1088, 1088, 1152, 1152, 1216, 1280, 1344, 1344, 1408, 1536, 1600, 1664, 1728, 1024]
shapes = [{"width": w, "height": h} for w, h in zip(widths, heights)]
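As a quick sanity check (not in the original script), the packed sequence length seen by the model is (W/16) * (H/16) for these resolutions, which stays within the 1024-4096 range passed to mark_dynamic above:

# Sanity check: every benchmarked resolution yields a packed sequence length
# inside the range declared via mark_dynamic (all W and H here are multiples of 16).
for s in shapes:
    seq_len = (s["width"] // 16) * (s["height"] // 16)
    assert 1024 <= seq_len <= 4096, (s, seq_len)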
benchmark_results = {}
for shape in shapes:
    w = shape["width"]
    h = shape["height"]
    emb_times = []
    flux_times = []
    ae_times = []
    img_proc_times = []
    num_runs = 2
    for _ in range(num_runs):
        (inp, timesteps), t_emb = get_flux_inputs(w, h, seed=1)
        emb_times.append(t_emb)
        x_latents, t_flux = benchmark_denoise(model_compiled, inp, timesteps, 4.0)
        flux_times.append(t_flux)
        x_decoded, t_vae = benchmark_ae(x_latents, w, h)
        ae_times.append(t_vae)
        img, t_img = benchmark_imgproc(x_decoded)
        img_proc_times.append(t_img)
    avg_emb_time = sum(emb_times) / num_runs
    avg_flux_time = sum(flux_times) / num_runs
    avg_ae_time = sum(ae_times) / num_runs
    avg_img_proc_time = sum(img_proc_times) / num_runs
    total_avg_time = avg_emb_time + avg_flux_time + avg_ae_time + avg_img_proc_time
    benchmark_results[(w, h)] = {
        "emb": avg_emb_time,
        "flux": avg_flux_time,
        "ae": avg_ae_time,
        "img_proc": avg_img_proc_time,
        "total": total_avg_time,
    }
    print(f"Width: {w}, Height: {h}")
    print(f"Average Embedding Time: {avg_emb_time:.4f} ms")
    print(f"Average Flux Time: {avg_flux_time:.4f} ms")
    print(f"Average VAE Time: {avg_ae_time:.4f} ms")
    print(f"Average Image Processing Time: {avg_img_proc_time:.4f} ms")
    print(f"Total Average Time: {total_avg_time:.4f} ms")
    print()
print("Summary of average times per shape:")
for (w, h), times in benchmark_results.items():
print(f"{w}x{h}: {times['total']:.4f}s")
Shape | Embedding (ms) | Flux (ms) | AE (ms) | Image Processing (ms) | Total (ms) |
---|---|---|---|---|---|
1408x704 | 28.6099 | 2980.3540 | 111.3747 | 130.1227 | 3250.4613 |
1344x704 | 28.4531 | 2762.1140 | 106.5527 | 122.3098 | 3019.4295 |
1344x768 | 28.4854 | 3079.0841 | 115.3007 | 131.6007 | 3354.4709 |
1280x768 | 28.8066 | 2933.4331 | 110.6439 | 125.2368 | 3198.1203 |
1216x832 | 28.3692 | 2997.4711 | 113.9547 | 125.8699 | 3265.6648 |
1152x832 | 28.4896 | 2908.8693 | 107.7649 | 120.4591 | 3165.5829 |
1152x896 | 28.4290 | 3089.0392 | 115.4949 | 128.0042 | 3360.9673 |
1088x896 | 28.6872 | 2936.0631 | 109.3522 | 125.3897 | 3199.4922 |
1088x960 | 28.4836 | 3104.3700 | 117.0150 | 129.8313 | 3379.6999 |
1024x960 | 28.4406 | 2928.7407 | 110.6539 | 123.0594 | 3190.8946 |
1024x1024 | 28.4637 | 3104.5332 | 116.9265 | 141.5787 | 3391.5022 |
960x1024 | 28.8508 | 2942.2141 | 110.4937 | 133.1331 | 3214.6918 |
960x1088 | 28.3978 | 3112.7682 | 117.5497 | 129.2910 | 3388.0067 |
896x1088 | 28.4170 | 2935.7063 | 109.5556 | 120.9787 | 3194.6576 |
896x1152 | 28.4508 | 3093.2997 | 115.4391 | 129.1997 | 3366.3894 |
832x1152 | 28.8455 | 2921.6091 | 107.8484 | 123.1314 | 3181.4345 |
832x1216 | 28.3867 | 3018.7200 | 113.5671 | 126.6900 | 3287.3637 |
768x1280 | 28.4646 | 2937.9644 | 110.7459 | 124.4567 | 3201.6316 |
768x1344 | 28.4804 | 3099.9993 | 115.6256 | 128.5179 | 3372.6232 |
704x1344 | 28.4841 | 2774.8572 | 106.9782 | 119.0040 | 3029.3234 |
704x1408 | 28.4302 | 2997.4070 | 111.3863 | 123.9205 | 3261.1441 |
640x1536 | 28.4634 | 2937.9116 | 111.1030 | 124.3213 | 3201.7994 |
640x1600 | 28.4554 | 3089.8820 | 114.6994 | 126.8730 | 3359.9098 |
576x1664 | 28.4345 | 2910.0455 | 107.7801 | 120.1497 | 3166.4098 |
576x1728 | 28.4017 | 2980.9469 | 111.7591 | 123.6519 | 3244.7595 |
512x1024 | 28.2752 | 1522.6722 | 58.5688 | 72.6834 | 1682.1996 |
For reference, here is the denoise function with the use_tqdm flag used by benchmark_denoise above (a modified version of flux.sampling.denoise):

from tqdm import tqdm

def denoise(
    model: Flux,
    img: Tensor,
    img_ids: Tensor,
    txt: Tensor,
    txt_ids: Tensor,
    vec: Tensor,
    timesteps: list[float],
    guidance: float = 4.0,
    use_tqdm: bool = True,
):
    guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
    iterator = (
        tqdm(
            zip(timesteps[:-1], timesteps[1:]),
            total=len(timesteps) - 1,
            desc="Denoising",
            unit="it",
            bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]",
        )
        if use_tqdm
        else zip(timesteps[:-1], timesteps[1:])
    )
    for t_curr, t_prev in iterator:
        t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
        pred = model(
            img=img,
            img_ids=img_ids,
            txt=txt,
            txt_ids=txt_ids,
            y=vec,
            timesteps=t_vec,
            guidance=guidance_vec,
        )
        img = img + (t_prev - t_curr) * pred
    return img
The results in the table above were measured on an NVIDIA H100 GPU; all times are reported in milliseconds.