
@sayakpaul
Last active February 2, 2025 17:54
Shows how to AoT compile the Flux.1 Dev Transformer with int8 quant and perform inference.
aot_compile_with_int8_quant.py

import torch
from diffusers import FluxTransformer2DModel
import torch.utils.benchmark as benchmark
from torchao.quantization import quantize_, int8_weight_only
from torchao.utils import unwrap_tensor_subclass
import torch._inductor

torch._inductor.config.mixed_mm_choice = "triton"


def get_example_inputs():
    # Load the serialized transformer inputs and move them to the GPU.
    example_inputs = torch.load("serialized_inputs.pt", weights_only=True)
    example_inputs = {k: v.to("cuda") for k, v in example_inputs.items()}
    example_inputs.update({"joint_attention_kwargs": None, "return_dict": False})
    return example_inputs


def benchmark_fn(f, *args, **kwargs):
    # Report the mean wall-clock time of `f` over several measured runs.
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)",
        globals={"args": args, "kwargs": kwargs, "f": f},
        num_threads=torch.get_num_threads(),
    )
    return f"{(t0.blocked_autorange().mean):.3f}"


@torch.no_grad()
def load_model():
    model = FluxTransformer2DModel.from_pretrained(
        "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
    ).to("cuda")
    return model


def aot_compile(name, model, **sample_kwargs):
    # Export the model and package the AoT Inductor binary at ./{name}.pt2.
    path = f"./{name}.pt2"
    options = {
        "max_autotune": True,
        "triton.cudagraphs": True,
    }
    return torch._inductor.aoti_compile_and_package(
        torch.export.export(model, (), sample_kwargs),
        (),
        sample_kwargs,
        package_path=path,
        inductor_configs=options,
    )


def aot_load(path):
    return torch._inductor.aoti_load_package(path)


@torch.no_grad()
def f(model, **kwargs):
    return model(**kwargs)


if __name__ == "__main__":
    model = load_model()
    # Apply int8 weight-only quantization in place, then unwrap the tensor
    # subclasses so the quantized model can be exported.
    quantize_(model, int8_weight_only())
    inputs1 = get_example_inputs()
    unwrap_tensor_subclass(model)

    path = aot_compile("bs_1_1024", model, **inputs1)
    print(f"AoT compiled path {path}")

    compiled_func = aot_load(path)
    print(f"{compiled_func(**inputs1)[0].shape=}")

    # Warm up before benchmarking.
    for _ in range(5):
        _ = compiled_func(**inputs1)[0]

    time = benchmark_fn(f, compiled_func, **inputs1)
    print(f"{time=} seconds.")
inference.py

import torch
from diffusers import DiffusionPipeline

# Load the pipeline without a transformer; the AoT-compiled binary is plugged in below.
pipeline = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    transformer=None,
    torch_dtype=torch.bfloat16,
).to("cuda")
pipeline.transformer = torch._inductor.aoti_load_package("./bs_1_1024.pt2")

image = pipeline("cute dog", guidance_scale=3.5, max_sequence_length=512, num_inference_steps=50).images[0]
image.save("aot_compiled.png")

inference.py produces:

[generated image: aot_compiled.png]

You're welcome to try out other quantization techniques from torchao and benefit from torch.compile(). diffusers-torchao provides a handy reference.
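For instance, here is a minimal sketch of trying a different scheme. It assumes torchao's int8_dynamic_activation_int8_weight config and uses the regular JiT torch.compile path rather than the AoT flow above; swap in whichever config you want to experiment with:

import torch
from diffusers import FluxTransformer2DModel
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight

model = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
).to("cuda")

# Quantize weights to int8 and dynamically quantize activations, in place.
quantize_(model, int8_dynamic_activation_int8_weight())

# JiT compile; the first forward pass pays the compilation cost.
model = torch.compile(model, mode="max-autotune", fullgraph=True)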

Library versions:

  • diffusers: Installed from the main branch.
  • torchao: Installed from the main branch.
  • torch: 2.6.0.dev20241027+cu121

Tested on H100.

serialized_inputs.pt in aot_compile_with_int8_quant.py was obtained by serializing the inputs to self.transformer (from here). You can download it from here.
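For reference, one plausible way to capture such inputs (a sketch under my own assumptions, not the exact code used for this gist) is to wrap the transformer's forward, run the pipeline once, and save the captured tensor kwargs:

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

captured = {}
original_forward = pipe.transformer.forward

def capturing_forward(*args, **kwargs):
    # Keep only tensor kwargs so the file can later be loaded with weights_only=True.
    captured.update({k: v.detach().to("cpu") for k, v in kwargs.items() if torch.is_tensor(v)})
    return original_forward(*args, **kwargs)

pipe.transformer.forward = capturing_forward
_ = pipe("cute dog", guidance_scale=3.5, max_sequence_length=512, num_inference_steps=1)
torch.save(captured, "serialized_inputs.pt")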

Additionally, to perform inference with this AoT compiled binary with DiffusionPipeline as shown in inference.py, the following changes are needed to the pipeline_flux.py file:

diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py
index 040d935f1..f24cd28c5 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux.py
@@ -680,7 +680,7 @@ class FluxPipeline(
         )
 
         # 4. Prepare latent variables
-        num_channels_latents = self.transformer.config.in_channels // 4
+        num_channels_latents = self.transformer.config.in_channels // 4 if isinstance(self.transformer, torch.nn.Module) else 16
         latents, latent_image_ids = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,
@@ -714,7 +714,7 @@ class FluxPipeline(
         self._num_timesteps = len(timesteps)
 
         # handle guidance
-        if self.transformer.config.guidance_embeds:
+        if (isinstance(self.transformer, torch.nn.Module) and self.transformer.config.guidance_embeds) or isinstance(self.transformer, Callable):
             guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
             guidance = guidance.expand(latents.shape[0])
         else:

The compiled binary file ("bs_1_1024.pt2") used in inference.py can be found here.

Thanks to PyTorch folks (especially @jerryzh168) who provided guidance in this thread.

@Manojbhat09

Maybe this needs a note on using AoT:

Compiling with AoT reduces latency, but the resulting images are widely different from the ones generated without compilation. This is true for both the dev and schnell models. Hence the similarity of the output images to the originals might be around 0.5 (a rough way to measure this is sketched after the examples below).

When the FLUX.1-schnell transformer is AoT compiled:

prompt-"hydrosulphuryl, matterless, autecologist, sensory, supersyndicate, bestialism, Murph, unsimple"
original (non-compiled)
original_0
AOT compiled
compiled_0

prompt-"nonenclosure, Waiilatpuan, intertrinitarian, ecclesiast, moschatelline, heaver, metamorphose, alkaliferous"
original (non-compiled)
original_1

AOT compiled
compiled_1

prompt-"forayer, batterfang, diphead, Semaeostomata, backless"
original (non-compiled)
original_2

AOT compiled
compiled_2

prompt-"Guatoan, otalgic, crumenal, Protohydra, colporrhaphy, unhemmed, archiblastic, bosher"
original (non-compiled)
original_3

AOT compiled
compiled_3
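A rough sketch of how such a similarity number could be computed (my assumption; SSIM via scikit-image, with the placeholder filenames from the pairs above):

import numpy as np
from PIL import Image
from skimage.metrics import structural_similarity

baseline = np.asarray(Image.open("original_0.png").convert("RGB"))
compiled = np.asarray(Image.open("compiled_0.png").convert("RGB"))

# channel_axis=-1 treats the last dimension as the RGB channels.
score = structural_similarity(baseline, compiled, channel_axis=-1)
print(f"SSIM: {score:.3f}")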

@sayakpaul
Author

  • Did you use the same seed? (A minimal seeded-comparison sketch is below.)
  • What is your JiT compile code?
  • Can you present the results in a tabular format? They are hard to parse as posted.
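For concreteness, this is the kind of seeded comparison I mean (a sketch; `pipeline` and `aot_pipeline` are placeholders for the non-compiled pipeline and the one using the AoT-loaded transformer, as in inference.py):

import torch

prompt = "hydrosulphuryl, matterless, autecologist, sensory, supersyndicate, bestialism, Murph, unsimple"

# Same seed for both runs so any difference comes from quantization/compilation,
# not from the initial noise.
baseline_image = pipeline(
    prompt, guidance_scale=3.5, max_sequence_length=512, num_inference_steps=50,
    generator=torch.Generator("cuda").manual_seed(0),
).images[0]

aot_image = aot_pipeline(
    prompt, guidance_scale=3.5, max_sequence_length=512, num_inference_steps=50,
    generator=torch.Generator("cuda").manual_seed(0),
).images[0]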
