Flux Dev (dynamic shapes) Benchmark

Setup

import os
import torch
from einops import rearrange
from PIL import Image
from transformers import pipeline
# Note: denoise must be the modified version with a use_tqdm flag (see the end of this gist)
from flux.sampling import denoise, get_noise, get_schedule, unpack, prepare
from flux.util import embed_watermark, load_ae, load_clip, load_flow_model, load_t5

# Cache Inductor artifacts locally so repeated runs can reuse compiled graphs
os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
os.environ["TORCHINDUCTOR_CACHE_DIR"] = "./flux-reduce-overhead-cache"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

with torch.inference_mode():
    t5 = load_t5(device, max_length=512)
    clip = load_clip(device)
    model = load_flow_model("flux-dev", device=device, hf_download=True)
    ae = load_ae("flux-dev", device=device)
    nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)

torch_device = torch.device(device)
rng = torch.Generator(device="cpu")
seed = 1

Benchmark CUDA Decorator

from functools import wraps

def benchmark_cuda(iters=1, skip_warmup=False):
    """Decorate a CUDA-bound function so it returns (result, mean elapsed time per iteration in ms)."""
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            if not skip_warmup:
                # Warm-up run
                func(*args, **kwargs)
                torch.cuda.synchronize()

            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)

            start_event.record()
            for _ in range(iters):
                result = func(*args, **kwargs)
            end_event.record()

            torch.cuda.synchronize()
            elapsed_time = start_event.elapsed_time(end_event) / iters  # in milliseconds

            return result, elapsed_time
        return wrapper
    return decorator
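
Note that a decorated function no longer returns its original value: it returns a (result, elapsed_ms) tuple, so callers have to unpack both. A minimal usage sketch (dummy_matmul is purely illustrative and not part of the benchmark):

@benchmark_cuda(iters=10)
def dummy_matmul(a, b):
    return a @ b

a = torch.randn(1024, 1024, device="cuda", dtype=torch.bfloat16)
b = torch.randn(1024, 1024, device="cuda", dtype=torch.bfloat16)
out, t_ms = dummy_matmul(a, b)  # out: last iteration's result, t_ms: mean time per iteration in ms
print(f"matmul: {t_ms:.3f} ms/iter")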

Benchmark Functions

@benchmark_cuda(iters=1, skip_warmup=True)
def get_flux_inputs(W, H, seed=1):
    with torch.inference_mode():
        x = get_noise(1, W, H, device=torch_device, dtype=torch.bfloat16, seed=seed)
        timesteps = get_schedule(20, (x.shape[-1] * x.shape[-2]) // 4, shift=True)
        inp = prepare(t5=t5, clip=clip, img=x, prompt="A cat with a hat")
    return inp, timesteps

@benchmark_cuda(iters=1, skip_warmup=True)
def benchmark_denoise(model, inp, timesteps, guidance):
    with torch.inference_mode():
        x = denoise(model, **inp, timesteps=timesteps, guidance=guidance, use_tqdm=False)
    return x

@benchmark_cuda(iters=1, skip_warmup=True)
def benchmark_ae(x, w, h):
    with torch.inference_mode():
        x = unpack(x.float(), w, h)
        with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
            x = ae.decode(x)
    return x

@benchmark_cuda(iters=1, skip_warmup=True)
def benchmark_imgproc(x):
    with torch.inference_mode():
        x = x.clamp(-1, 1)
        x = embed_watermark(x.float())
        x = rearrange(x[0], "c h w -> h w c")
    return Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())

Optimization with mark_dynamic

# For example, with W = 1344 and H = 768, you can get the input tensors (inp) as follows.
# Note that get_flux_inputs is wrapped by benchmark_cuda, so it returns ((inp, timesteps), elapsed_ms).
W, H = 1344, 768
(inp, timesteps), _ = get_flux_inputs(W, H, seed=1)

Note: We observe that the shapes of the tensors inp["img"] and inp["img_ids"] vary with the image dimensions (W x H).

# Hence, we can use mark_dynamic on dimension 1 (the packed sequence length) to handle this variability:
torch._dynamo.mark_dynamic(inp["img"], 1, min=1024, max=4096)
torch._dynamo.mark_dynamic(inp["img_ids"], 1, min=1024, max=4096)

model_compiled = torch.compile(model, mode="reduce-overhead")
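
For intuition (this check is not in the original gist): inp["img"] is the packed latent of shape (1, seq_len, 64) and inp["img_ids"] holds the matching position ids of shape (1, seq_len, 3), where seq_len grows with the image area, so dimension 1 is the one marked dynamic above. A quick way to see the variation:

(inp_a, _), _ = get_flux_inputs(1344, 768, seed=1)
(inp_b, _), _ = get_flux_inputs(1024, 1024, seed=1)
print(inp_a["img"].shape, inp_a["img_ids"].shape)  # expected: (1, 4032, 64) and (1, 4032, 3)
print(inp_b["img"].shape, inp_b["img_ids"].shape)  # expected: (1, 4096, 64) and (1, 4096, 3)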

Shapes used for this experiment

widths = [1408, 1344, 1344, 1280, 1216, 1152, 1152, 1088, 1088, 1024, 1024, 960, 960, 896, 896, 832, 832, 768, 768, 704, 704, 640, 640, 576, 576, 512]
heights = [704, 704, 768, 768, 832, 832, 896, 896, 960, 960, 1024, 1024, 1088, 1088, 1152, 1152, 1216, 1280, 1344, 1344, 1408, 1536, 1600, 1664, 1728, 1024]

shapes = [{"width": w, "height": h} for w, h in zip(widths, heights)]
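As a sanity check (not part of the original gist), the packed sequence length for each shape is (width/16) * (height/16); all of the shapes above fall inside the [1024, 4096] bounds passed to mark_dynamic:

for s in shapes:
    seq_len = (s["width"] // 16) * (s["height"] // 16)  # packed latent tokens per image
    assert 1024 <= seq_len <= 4096, (s, seq_len)
print("All sequence lengths are within the mark_dynamic bounds")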

Inference times over different shapes

benchmark_results = {}

for shape in shapes:
    w = shape["width"]
    h = shape["height"]
    emb_times = []
    flux_times = []
    ae_times = []
    img_proc_times = []
    num_runs = 2

    for _ in range(num_runs):
        (inp, timesteps), t_emb = get_flux_inputs(w, h, seed=1)
        emb_times.append(t_emb)

        x_latents, t_flux = benchmark_denoise(model_compiled, inp, timesteps, 4.0)
        flux_times.append(t_flux)

        x_decoded, t_vae = benchmark_ae(x_latents, w, h)
        ae_times.append(t_vae)

        img, t_img = benchmark_imgproc(x_decoded)
        img_proc_times.append(t_img)

    avg_emb_time = sum(emb_times) / num_runs
    avg_flux_time = sum(flux_times) / num_runs
    avg_ae_time = sum(ae_times) / num_runs
    avg_img_proc_time = sum(img_proc_times) / num_runs
    total_avg_time = avg_emb_time + avg_flux_time + avg_ae_time + avg_img_proc_time

    benchmark_results[(w, h)] = {
        "emb": avg_emb_time,
        "flux": avg_flux_time,
        "ae": avg_ae_time,
        "img_proc": avg_img_proc_time,
        "total": total_avg_time
    }

    print(f"Width: {w}, Height: {h}")
    print(f"Average Embedding Time: {avg_emb_time:.4f}s")
    print(f"Average Flux Time: {avg_flux_time:.4f}s")
    print(f"Average VAE Time: {avg_ae_time:.4f}s")
    print(f"Average Image Processing Time: {avg_img_proc_time:.4f}s")
    print(f"Total Average Time: {total_avg_time:.4f}s")
    print()

print("Summary of average times per shape:")
for (w, h), times in benchmark_results.items():
    print(f"{w}x{h}: {times['total']:.4f}s")

Results

All times are in milliseconds.

Shape Embedding Flux AE Image Processing Total
1408x704 28.6099 2980.3540 111.3747 130.1227 3250.4613
1344x704 28.4531 2762.1140 106.5527 122.3098 3019.4295
1344x768 28.4854 3079.0841 115.3007 131.6007 3354.4709
1280x768 28.8066 2933.4331 110.6439 125.2368 3198.1203
1216x832 28.3692 2997.4711 113.9547 125.8699 3265.6648
1152x832 28.4896 2908.8693 107.7649 120.4591 3165.5829
1152x896 28.4290 3089.0392 115.4949 128.0042 3360.9673
1088x896 28.6872 2936.0631 109.3522 125.3897 3199.4922
1088x960 28.4836 3104.3700 117.0150 129.8313 3379.6999
1024x960 28.4406 2928.7407 110.6539 123.0594 3190.8946
1024x1024 28.4637 3104.5332 116.9265 141.5787 3391.5022
960x1024 28.8508 2942.2141 110.4937 133.1331 3214.6918
960x1088 28.3978 3112.7682 117.5497 129.2910 3388.0067
896x1088 28.4170 2935.7063 109.5556 120.9787 3194.6576
896x1152 28.4508 3093.2997 115.4391 129.1997 3366.3894
832x1152 28.8455 2921.6091 107.8484 123.1314 3181.4345
832x1216 28.3867 3018.7200 113.5671 126.6900 3287.3637
768x1280 28.4646 2937.9644 110.7459 124.4567 3201.6316
768x1344 28.4804 3099.9993 115.6256 128.5179 3372.6232
704x1344 28.4841 2774.8572 106.9782 119.0040 3029.3234
704x1408 28.4302 2997.4070 111.3863 123.9205 3261.1441
640x1536 28.4634 2937.9116 111.1030 124.3213 3201.7994
640x1600 28.4554 3089.8820 114.6994 126.8730 3359.9098
576x1664 28.4345 2910.0455 107.7801 120.1497 3166.4098
576x1728 28.4017 2980.9469 111.7591 123.6519 3244.7595
512x1024 28.2752 1522.6722 58.5688 72.6834 1682.1996

Modified denoise function to include tqdm

from tqdm import tqdm
from torch import Tensor
from flux.model import Flux


def denoise(
    model: Flux,
    img: Tensor,
    img_ids: Tensor,
    txt: Tensor,
    txt_ids: Tensor,
    vec: Tensor,
    timesteps: list[float],
    guidance: float = 4.0,
    use_tqdm: bool = True
):
    guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)

    iterator = tqdm(
        zip(timesteps[:-1], timesteps[1:]),
        total=len(timesteps)-1,
        desc="Denoising",
        unit="it",
        bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]"
    ) if use_tqdm else zip(timesteps[:-1], timesteps[1:])

    for t_curr, t_prev in iterator:
        t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
        pred = model(
            img=img,
            img_ids=img_ids,
            txt=txt,
            txt_ids=txt_ids,
            y=vec,
            timesteps=t_vec,
            guidance=guidance_vec,
        )
        img = img + (t_prev - t_curr) * pred

    return img
Author's note (@gradjitta): results are on an H100 GPU.
