Skip to content

Instantly share code, notes, and snippets.

@cbensimon
Created September 4, 2025 15:22
Show Gist options
  • Save cbensimon/1b887c93e40d1b16b3c29692751b1988 to your computer and use it in GitHub Desktop.
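PyTorch regional AoT compilation with dynamic shapes (QwenImageEdit)
"""
PyTorch regional AoT compilation with dynamic shapes (QwenImageEdit).

Exports the first transformer block of the Qwen-Image-Edit pipeline with
dynamic image/text sequence lengths, AoT-compiles it once, and reuses the
compiled artifact for every transformer block, each with its own weights.
"""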
import spaces
import torch
from diffusers import DiffusionPipeline
from diffusers.utils import load_image
from spaces.zero.torch.aoti import ZeroGPUCompiledModel
from spaces.zero.torch.aoti import ZeroGPUWeights
from torch.utils._pytree import tree_map
# Load the Qwen-Image-Edit pipeline in bfloat16 and move it to the GPU.
pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image-Edit", torch_dtype=torch.bfloat16).to('cuda')

# Example input image and edit instruction.
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png")
image = image.convert('RGB')
prompt = "Make Pikachu hold a sign that says 'Qwen Edit is awesome', yarn art style, detailed, vibrant colors"

# Keyword arguments shared by the capture run below and the final run at the
# end of the script.
# NOTE: `torch.manual_seed(0)` seeds the global RNG and returns a generator
# object; that single object is reused by every call that receives these
# kwargs, so any random draws made by one call advance it for the next.
pipe_kwargs = {
    "image": image,
    "prompt": prompt,
    "negative_prompt": " ",
    "true_cfg_scale": 4.0,
    "num_inference_steps": 50,
    "generator": torch.manual_seed(0),
}

# Run the pipeline under `spaces.aoti_capture` so that `call.args` /
# `call.kwargs` hold the inputs seen by the first transformer block; they are
# used below as example inputs for `torch.export`. (Presumably the capture
# only needs a single block call — see the `spaces` ZeroGPU AoT docs.)
# The pipeline call must be indented inside the `with` body.
with spaces.aoti_capture(pipe.transformer.transformer_blocks[0]) as call:
    pipe(**pipe_kwargs)
# Symbolic sizes for the two sequence lengths allowed to vary between calls:
# the image-token sequence and the text-token sequence.
TRANSFORMER_IMAGE_SEQ_LENGTH_DIM = torch.export.Dim('image_seq_length')
TRANSFORMER_TEXT_SEQ_LENGTH_DIM = torch.export.Dim('text_seq_length')

# Dynamic-shape spec for the block's keyword arguments: each inner dict maps a
# tensor axis index to the Dim it may vary along. `image_rotary_emb` is a
# pair, so it gets a pair of specs (image part, text part).
TRANSFORMER_DYNAMIC_SHAPES = {
    'hidden_states': {1: TRANSFORMER_IMAGE_SEQ_LENGTH_DIM},
    'encoder_hidden_states': {1: TRANSFORMER_TEXT_SEQ_LENGTH_DIM},
    'encoder_hidden_states_mask': {1: TRANSFORMER_TEXT_SEQ_LENGTH_DIM},
    'image_rotary_emb': (
        {0: TRANSFORMER_IMAGE_SEQ_LENGTH_DIM},
        {0: TRANSFORMER_TEXT_SEQ_LENGTH_DIM},
    ),
}

# Begin with "fully static" (None at every leaf of the captured kwargs, same
# nesting), then let the explicit dynamic specs above take precedence.
dynamic_shapes = {
    **tree_map(lambda _leaf: None, call.kwargs),
    **TRANSFORMER_DYNAMIC_SHAPES,
}

# Trace block 0 with the captured example inputs into an exported program.
exported = torch.export.export(
    pipe.transformer.transformer_blocks[0],
    args=call.args,
    kwargs=call.kwargs,
    dynamic_shapes=dynamic_shapes,
)
# Ahead-of-time compile the exported program of transformer block 0.
compiled = spaces.aoti_compile(exported)

# "Regional" compilation: the single compiled archive is reused for every
# transformer block, each bound to that block's own weights. This presumes
# all blocks share block 0's architecture — TODO confirm for other models.
# The loop body must be indented under the `for`.
for block in pipe.transformer.transformer_blocks:
    weights = ZeroGPUWeights(block.state_dict())
    compiled_block = ZeroGPUCompiledModel(compiled.archive_file, weights)
    block.forward = compiled_block

# The generator in `pipe_kwargs` is the same object the capture run used, so
# it may already have been advanced; re-seed so this run starts from seed 0.
pipe_kwargs["generator"] = torch.manual_seed(0)
pipe(**pipe_kwargs).images[0].save('edited.png')
@cbensimon
Copy link
Author

edited

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment