inference.py produces:

You're welcome to try out other quantization techniques from torchao and benefit from torch.compile(). diffusers-torchao provides a handy reference.
Library versions:

diffusers: installed from the main branch.
torchao: installed from the main branch.
torch: 2.6.0.dev20241027+cu121

Tested on H100.
serialized_inputs.pt in aot_compile_with_int8_quant.py was obtained by serializing the inputs to self.transformer (from here). You can download it from here.
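One way to capture such inputs (a generic sketch, not the exact code used to produce serialized_inputs.pt) is a forward pre-hook that records the call's args/kwargs, which are then saved with torch.save:

```python
# Sketch: capture the positional and keyword inputs of a submodule with a
# forward pre-hook, then serialize them. The Linear module stands in for
# self.transformer; all names are illustrative.
import torch

captured = {}

def grab_inputs(module, args, kwargs):
    # Record exactly what the module was called with.
    captured["args"] = args
    captured["kwargs"] = kwargs

transformer = torch.nn.Linear(8, 8)
handle = transformer.register_forward_pre_hook(grab_inputs, with_kwargs=True)

transformer(torch.randn(2, 8))  # one forward pass populates `captured`
handle.remove()

torch.save(captured, "serialized_inputs_demo.pt")
loaded = torch.load("serialized_inputs_demo.pt")
print(loaded["args"][0].shape)
```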
Additionally, to perform inference with this AoT-compiled binary through DiffusionPipeline as shown in inference.py, the following changes are needed in the pipeline_flux.py file:
diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py
index 040d935f1..f24cd28c5 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux.py
@@ -680,7 +680,7 @@ class FluxPipeline(
)
# 4. Prepare latent variables
- num_channels_latents = self.transformer.config.in_channels // 4
+ num_channels_latents = self.transformer.config.in_channels // 4 if isinstance(self.transformer, torch.nn.Module) else 16
latents, latent_image_ids = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
@@ -714,7 +714,7 @@ class FluxPipeline(
self._num_timesteps = len(timesteps)
# handle guidance
- if self.transformer.config.guidance_embeds:
+ if (isinstance(self.transformer, torch.nn.Module) and self.transformer.config.guidance_embeds) or isinstance(self.transformer, Callable):
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
guidance = guidance.expand(latents.shape[0])
else:
The compiled binary file ("bs_1_1024.pt2") used in inference.py can be found here.

Thanks to the PyTorch folks (especially @jerryzh168) who provided guidance in this thread.
Note on using AoT: while AoT compilation reduces latency, the resulting images are widely different from the ones generated without compilation. This is true for both the dev and schnell models. The similarity of the output images to the originals may be only around 0.5.
When the Flux.1-schnell transformer is AoT compiled:

prompt: "hydrosulphuryl, matterless, autecologist, sensory, supersyndicate, bestialism, Murph, unsimple"


original (non-compiled)
AOT compiled
prompt: "nonenclosure, Waiilatpuan, intertrinitarian, ecclesiast, moschatelline, heaver, metamorphose, alkaliferous"

original (non-compiled)
AOT compiled

prompt: "forayer, batterfang, diphead, Semaeostomata, backless"

original (non-compiled)
AOT compiled

prompt: "Guatoan, otalgic, crumenal, Protohydra, colporrhaphy, unhemmed, archiblastic, bosher"

original (non-compiled)
AOT compiled
