Flux with CUDA streams
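A single-file Flux (FLUX.1-dev) inference script that overlaps the context/text and image branches of each transformer block on a secondary CUDA stream, on top of QKV and AdaLN-linear weight fusion, precomputed guidance and timestep embeddings, optional torch.compile, CUDA graph capture, and ring/Ulysses context parallelism. Since setup_distributed initializes an NCCL process group unconditionally, the script is meant to be launched with torchrun. The following is a minimal, standalone sketch of the stream-overlap pattern the file relies on; the function and tensor names are illustrative and not taken from the script.

import torch
from typing import Tuple

def overlapped_branches(
    x_img: torch.Tensor,
    x_txt: torch.Tensor,
    img_proj: torch.nn.Module,
    txt_proj: torch.nn.Module,
    side_stream: torch.cuda.Stream,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # The side stream must first wait on the current stream so it sees the inputs.
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        out_txt = txt_proj(x_txt)  # runs on the side stream
    out_img = img_proj(x_img)      # runs concurrently on the current stream
    # Join before any consumer of out_txt is enqueued on the current stream.
    torch.cuda.current_stream().wait_stream(side_stream)
    return out_img, out_txt

if __name__ == "__main__":
    stream = torch.cuda.Stream()
    img = torch.randn(1, 4096, 3072, device="cuda", dtype=torch.bfloat16)
    txt = torch.randn(1, 512, 3072, device="cuda", dtype=torch.bfloat16)
    img_proj = torch.nn.Linear(3072, 3072, device="cuda", dtype=torch.bfloat16)
    txt_proj = torch.nn.Linear(3072, 3072, device="cuda", dtype=torch.bfloat16)
    out_img, out_txt = overlapped_branches(img, txt, img_proj, txt_proj, stream)
    torch.cuda.synchronize()

The overlap can only help when the two branches are independent and each is too small to saturate the GPU on its own, which is why the script applies it to the context-branch projections, MLPs, and AdaLN modulations while the main attention runs on the default stream.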
| import argparse | |
| import contextlib | |
| import functools | |
| import pathlib | |
| import math | |
| from dataclasses import dataclass | |
| from typing import Callable, List, Literal, Optional, Tuple | |
| import numpy as np | |
| import torch | |
| import torch.distributed as dist | |
| import torch.distributed._functional_collectives as funcol | |
| import torch.profiler._utils | |
| import torch._dynamo.config | |
| import torch._inductor.config | |
| import torch._higher_order_ops.auto_functionalize as af | |
| from torch.profiler import profile, record_function, ProfilerActivity | |
| from diffusers import AutoencoderKL | |
| from diffusers.image_processor import VaeImageProcessor | |
| from diffusers.configuration_utils import ConfigMixin, register_to_config | |
| from diffusers.loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin | |
| from diffusers.models.attention import FeedForward | |
| from diffusers.models.embeddings import get_1d_rotary_pos_embed | |
| from diffusers.models.cache_utils import CacheMixin | |
| from diffusers.models.embeddings import ( | |
| CombinedTimestepGuidanceTextProjEmbeddings, | |
| CombinedTimestepTextProjEmbeddings, | |
| ) | |
| from diffusers.models.modeling_utils import ModelMixin | |
| from kernels import get_kernel | |
| from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast | |
try:
    from flash_attn import flash_attn_func
except ImportError:
    flash_attn_func = None
    print("Flash Attention 2 not found.")
try:
    from flash_attn_interface import flash_attn_func as flash_attn_3_func
except ImportError:
    flash_attn_3_func = None
    print("Flash Attention 3 not found.")
| def apply_flags(): | |
| torch._dynamo.config.inline_inbuilt_nn_modules = False | |
| torch._dynamo.config.cache_size_limit = 128 | |
| torch._dynamo.config.error_on_recompile = True | |
| torch._inductor.config.conv_1x1_as_mm = True | |
| torch._inductor.config.coordinate_descent_check_all_directions = True | |
| torch._inductor.config.coordinate_descent_tuning = True | |
| torch._inductor.config.disable_progress = False | |
| torch._inductor.config.fx_graph_cache = True | |
| torch._inductor.config.epilogue_fusion = False | |
| torch._inductor.config.aggressive_fusion = True | |
| torch._inductor.config.shape_padding = True | |
| torch._inductor.config.triton.enable_persistent_tma_matmul = True | |
| torch.backends.cuda.matmul.allow_tf32 = True | |
| torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True | |
| torch.backends.cudnn.allow_tf32 = True | |
| af.auto_functionalized_v2._cacheable = True | |
| af.auto_functionalized._cacheable = True | |
| ROPE_PRECISION = torch.bfloat16 | |
| SPATIAL_COMPRESSION_RATIO = 8 | |
| PIXEL_UNSHUFFLING_DOWNSAMPLING_FACTOR = 2 | |
| T5_SEQUENCE_LENGTH = 512 | |
| DEFAULT_HEIGHT = 1024 | |
| DEFAULT_WIDTH = 1024 | |
| PATCH_SIZE = 1 | |
| LATENT_HEIGHT = DEFAULT_HEIGHT // (SPATIAL_COMPRESSION_RATIO * PATCH_SIZE) // 2 | |
| LATENT_WIDTH = DEFAULT_WIDTH // (SPATIAL_COMPRESSION_RATIO * PATCH_SIZE) // 2 | |
| SUPPORTED_BUCKET_LENGTHS = list(range(128, 512 + 1, 64)) | |
| SUPPORTED_GUIDANCE_SCALES = [i / 2 for i in range(41)] # 0, 0.5, 1.0, ..., 20.0 | |
| MIN_INFERENCE_STEPS = 2 | |
| MAX_INFERENCE_STEPS = 50 | |
| ATTENTION_OP = None | |
| STREAM = None | |
| BASE_IMAGE_SEQ_LEN = 256 | |
| MAX_IMAGE_SEQ_LEN = 4096 | |
| BASE_SHIFT = 0.5 | |
| MAX_SHIFT = 1.15 | |
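# M and B are the slope and intercept of mu = M * image_seq_len + B, which interpolates the schedule
# shift from BASE_SHIFT at BASE_IMAGE_SEQ_LEN to MAX_SHIFT at MAX_IMAGE_SEQ_LEN (see precompute_timestep_embeds).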
| M = (MAX_SHIFT - BASE_SHIFT) / (MAX_IMAGE_SEQ_LEN - BASE_IMAGE_SEQ_LEN) | |
| B = BASE_SHIFT - M * BASE_IMAGE_SEQ_LEN | |
| @dataclass | |
| class ContextParallelOptions: | |
| ring_degree: int = None | |
| ulysses_degree: int = None | |
| mode: Literal["ring", "ulysses", "unified"] = "ring" | |
| mesh: dist.DeviceMesh | None = None | |
| convert_to_fp32: bool = True | |
| attention_op: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]] | None = None | |
| _flattened_mesh: dist.DeviceMesh = None | |
| _ring_mesh: dist.DeviceMesh = None | |
| _ulysses_mesh: dist.DeviceMesh = None | |
| _ring_local_rank: int = None | |
| _ulysses_local_rank: int = None | |
| cp_options = ContextParallelOptions() | |
| class AdaLayerNormContinuous(torch.nn.Module): | |
| def __init__( | |
| self, embedding_dim: int, conditioning_embedding_dim: int, elementwise_affine=True, eps=1e-5, bias=True | |
| ): | |
| super().__init__() | |
| self.linear = torch.nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias) | |
| self.norm = torch.nn.LayerNorm(embedding_dim, eps, elementwise_affine, bias) | |
| def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: | |
| emb = self.linear(emb) | |
| scale, shift = emb.unsqueeze(1).chunk(2, dim=-1) | |
| x = self.norm(x) | |
| x = torch.addcmul(shift, x, 1 + scale) | |
| return x | |
| class AdaLayerNormZeroSingle(torch.nn.Module): | |
| def __init__(self, embedding_dim: int, bias=True): | |
| super().__init__() | |
| self.linear = torch.nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias) | |
| self.norm = torch.nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6) | |
| def forward(self, x: torch.Tensor, emb: torch.Tensor): | |
| shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=-1) | |
| x = self.norm(x) | |
| x = torch.addcmul(shift_msa, x, 1 + scale_msa) | |
| return x, gate_msa | |
| class AdaLayerNormZero(torch.nn.Module): | |
| def __init__(self, embedding_dim: int, bias=True): | |
| super().__init__() | |
| self.linear = torch.nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias) | |
| self.norm = torch.nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6) | |
| def forward(self, x: torch.Tensor, emb: torch.Tensor): | |
| shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=-1) | |
| x = self.norm(x) | |
| x = torch.addcmul(shift_msa, x, 1 + scale_msa) | |
| return x, gate_msa, shift_mlp, scale_mlp, gate_mlp | |
| class Attention(torch.nn.Module): | |
| def __init__( | |
| self, | |
| query_dim: int, | |
| heads: int = 8, | |
| dim_head: int = 64, | |
| dropout: float = 0.0, | |
| bias: bool = False, | |
| qk_norm: Optional[str] = None, | |
| added_kv_proj_dim: Optional[int] = None, | |
| added_proj_bias: Optional[bool] = True, | |
| out_bias: bool = True, | |
| eps: float = 1e-5, | |
| out_dim: int = None, | |
| context_pre_only=None, | |
| pre_only=False, | |
| elementwise_affine: bool = True, | |
| ): | |
| super().__init__() | |
| assert qk_norm == "rms_norm", "Flux uses RMSNorm" | |
| self.inner_dim = out_dim if out_dim is not None else dim_head * heads | |
| self.query_dim = query_dim | |
| self.use_bias = bias | |
| self.dropout = dropout | |
| self.out_dim = out_dim if out_dim is not None else query_dim | |
| self.context_pre_only = context_pre_only | |
| self.pre_only = pre_only | |
| self.heads = out_dim // dim_head if out_dim is not None else heads | |
| self.added_proj_bias = added_proj_bias | |
| self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) | |
| self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) | |
| self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) | |
| self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) | |
| self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) | |
| if not self.pre_only: | |
| self.to_out = torch.nn.ModuleList([]) | |
| self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) | |
| if added_kv_proj_dim is not None: | |
| self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps) | |
| self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps) | |
| self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) | |
| self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) | |
| self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) | |
| self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias) | |
| def forward( | |
| self, | |
| hidden_states: torch.Tensor, | |
| encoder_hidden_states: Optional[torch.Tensor] = None, | |
| attention_mask: Optional[torch.Tensor] = None, | |
| image_rotary_emb: Optional[torch.Tensor] = None, | |
| ) -> torch.Tensor: | |
| cos, sin = image_rotary_emb if image_rotary_emb is not None else (None, None) | |
| if encoder_hidden_states is not None: | |
| if STREAM is not None: | |
| STREAM.wait_stream(torch.cuda.current_stream()) | |
| with torch.cuda.stream(STREAM): | |
| if self.fused_projections: | |
| query_c, key_c, value_c = self.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1) | |
| else: | |
| query_c = self.add_q_proj(encoder_hidden_states) | |
| key_c = self.add_k_proj(encoder_hidden_states) | |
| value_c = self.add_v_proj(encoder_hidden_states) | |
| query_c, key_c, value_c = ( | |
| x.unflatten(2, (self.heads, -1)) for x in (query_c, key_c, value_c) | |
| ) | |
| query_c = self.norm_added_q(query_c) | |
| key_c = self.norm_added_k(key_c) | |
| else: | |
| if self.fused_projections: | |
| query_c, key_c, value_c = self.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1) | |
| else: | |
| query_c = self.add_q_proj(encoder_hidden_states) | |
| key_c = self.add_k_proj(encoder_hidden_states) | |
| value_c = self.add_v_proj(encoder_hidden_states) | |
| query_c, key_c, value_c = ( | |
| x.unflatten(2, (self.heads, -1)) for x in (query_c, key_c, value_c) | |
| ) | |
| query_c = self.norm_added_q(query_c) | |
| key_c = self.norm_added_k(key_c) | |
| if self.fused_projections: | |
| query, key, value = self.to_qkv(hidden_states).chunk(3, dim=-1) | |
| else: | |
| query = self.to_q(hidden_states) | |
| key = self.to_k(hidden_states) | |
| value = self.to_v(hidden_states) | |
| query, key, value = (x.unflatten(2, (self.heads, -1)) for x in (query, key, value)) | |
| query = self.norm_q(query) | |
| key = self.norm_k(key) | |
| if STREAM is not None: | |
| torch.cuda.current_stream().wait_stream(STREAM) | |
| if encoder_hidden_states is not None: | |
| query = torch.cat([query_c, query], dim=1) | |
| key = torch.cat([key_c, key], dim=1) | |
| value = torch.cat([value_c, value], dim=1) | |
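# Rotary embedding with interleaved (real, imaginary) pairs: x * cos + rotate_half(x) * sin,
# computed in ROPE_PRECISION and cast back to the activation dtype.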
| if image_rotary_emb is not None: | |
| x_real, x_imag = query.unflatten(-1, (-1, 2)).unbind(-1) | |
| x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) | |
| query = (query.to(ROPE_PRECISION) * cos + x_rotated.to(ROPE_PRECISION) * sin).type_as(query) | |
| x_real, x_imag = key.unflatten(-1, (-1, 2)).unbind(-1) | |
| x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) | |
| key = (key.to(ROPE_PRECISION) * cos + x_rotated.to(ROPE_PRECISION) * sin).type_as(key) | |
| hidden_states, lse = ATTENTION_OP(query, key, value) | |
| hidden_states = hidden_states.flatten(2) | |
| if encoder_hidden_states is not None: | |
| encoder_hidden_states, hidden_states = torch.split_with_sizes( | |
| hidden_states, | |
| [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], | |
| dim=1, | |
| ) | |
| hidden_states = self.to_out[0](hidden_states) | |
| encoder_hidden_states = self.to_add_out(encoder_hidden_states) | |
| return hidden_states, encoder_hidden_states | |
| return hidden_states | |
| @torch.no_grad() | |
| def fuse_projections(self): | |
| device = self.to_q.weight.data.device | |
| dtype = self.to_q.weight.data.dtype | |
| concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data]) | |
| in_features = concatenated_weights.shape[1] | |
| out_features = concatenated_weights.shape[0] | |
| self.to_qkv = torch.nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype) | |
| self.to_qkv.weight.copy_(concatenated_weights) | |
| if self.use_bias: | |
| concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data]) | |
| self.to_qkv.bias.copy_(concatenated_bias) | |
| if ( | |
| getattr(self, "add_q_proj", None) is not None | |
| and getattr(self, "add_k_proj", None) is not None | |
| and getattr(self, "add_v_proj", None) is not None | |
| ): | |
| concatenated_weights = torch.cat( | |
| [self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data] | |
| ) | |
| in_features = concatenated_weights.shape[1] | |
| out_features = concatenated_weights.shape[0] | |
| self.to_added_qkv = torch.nn.Linear( | |
| in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype | |
| ) | |
| self.to_added_qkv.weight.copy_(concatenated_weights) | |
| if self.added_proj_bias: | |
| concatenated_bias = torch.cat( | |
| [self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data] | |
| ) | |
| self.to_added_qkv.bias.copy_(concatenated_bias) | |
| for layer in ("to_q", "to_k", "to_v", "to_added_q", "to_added_k", "to_added_v"): | |
| if hasattr(self, layer): | |
| module = getattr(self, layer) | |
| module.to("meta") | |
| delattr(self, layer) | |
| self.fused_projections = True | |
| class FluxPosEmbed(torch.nn.Module): | |
| def __init__(self, theta: int, axes_dim: List[int]): | |
| super().__init__() | |
| self.theta = theta | |
| self.axes_dim = axes_dim | |
def forward(self, ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
| n_axes = ids.shape[-1] | |
| cos_out = [] | |
| sin_out = [] | |
| for i in range(n_axes): | |
| cos, sin = get_1d_rotary_pos_embed( | |
| self.axes_dim[i], | |
| ids[:, i], | |
| theta=self.theta, | |
| repeat_interleave_real=True, | |
| use_real=True, | |
| freqs_dtype=ROPE_PRECISION, | |
| ) | |
| cos_out.append(cos) | |
| sin_out.append(sin) | |
| freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device, dtype=ROPE_PRECISION)[None, :, None] | |
| freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device, dtype=ROPE_PRECISION)[None, :, None] | |
| return freqs_cos, freqs_sin | |
| class FluxSingleTransformerBlock(torch.nn.Module): | |
| def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0): | |
| super().__init__() | |
| self.mlp_hidden_dim = int(dim * mlp_ratio) | |
| self.norm = AdaLayerNormZeroSingle(dim) | |
| self.proj_mlp = torch.nn.Linear(dim, self.mlp_hidden_dim) | |
| self.act_mlp = torch.nn.GELU(approximate="tanh") | |
| self.attn = Attention( | |
| query_dim=dim, | |
| dim_head=attention_head_dim, | |
| heads=num_attention_heads, | |
| out_dim=dim, | |
| bias=True, | |
| qk_norm="rms_norm", | |
| eps=1e-6, | |
| pre_only=True, | |
| ) | |
| self.proj_out = torch.nn.Linear(dim + self.mlp_hidden_dim, dim) | |
| def forward( | |
| self, | |
| hidden_states: torch.Tensor, | |
| temb: torch.Tensor, | |
| image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, | |
| ) -> torch.Tensor: | |
| norm_hidden_states, gate = self.norm(hidden_states, emb=temb) | |
| if STREAM is not None: | |
| STREAM.wait_stream(torch.cuda.current_stream()) | |
| with torch.cuda.stream(STREAM): | |
| mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) | |
| else: | |
| mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) | |
| attn_output = self.attn( | |
| hidden_states=norm_hidden_states, | |
| image_rotary_emb=image_rotary_emb, | |
| ) | |
| if STREAM is not None: | |
| torch.cuda.current_stream().wait_stream(STREAM) | |
| attn_mlp_hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) | |
| proj_out = self.proj_out(attn_mlp_hidden_states) | |
| hidden_states = torch.addcmul(hidden_states, gate, proj_out) | |
| return hidden_states | |
| class FluxTransformerBlock(torch.nn.Module): | |
| def __init__( | |
| self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6 | |
| ): | |
| super().__init__() | |
| self.norm1 = AdaLayerNormZero(dim) | |
| self.norm1_context = AdaLayerNormZero(dim) | |
| self.attn = Attention( | |
| query_dim=dim, | |
| added_kv_proj_dim=dim, | |
| dim_head=attention_head_dim, | |
| heads=num_attention_heads, | |
| out_dim=dim, | |
| context_pre_only=False, | |
| bias=True, | |
| qk_norm=qk_norm, | |
| eps=eps, | |
| ) | |
| self.norm2 = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) | |
| self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") | |
| self.norm2_context = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) | |
| self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") | |
| def forward( | |
| self, | |
| hidden_states: torch.Tensor, | |
| encoder_hidden_states: torch.Tensor, | |
| temb: torch.Tensor, | |
| image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, | |
| ) -> Tuple[torch.Tensor, torch.Tensor]: | |
| temb, temb_context = temb.chunk(2, dim=-1) | |
| if STREAM is not None: | |
| STREAM.wait_stream(torch.cuda.current_stream()) | |
| with torch.cuda.stream(STREAM): | |
| norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( | |
| encoder_hidden_states, emb=temb_context | |
| ) | |
| else: | |
| norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( | |
| encoder_hidden_states, emb=temb_context | |
| ) | |
| norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) | |
| if STREAM is not None: | |
| torch.cuda.current_stream().wait_stream(STREAM) | |
| attn_output, context_attn_output = self.attn( | |
| hidden_states=norm_hidden_states, | |
| encoder_hidden_states=norm_encoder_hidden_states, | |
| image_rotary_emb=image_rotary_emb, | |
| ) | |
| if STREAM is not None: | |
| STREAM.wait_stream(torch.cuda.current_stream()) | |
| with torch.cuda.stream(STREAM): | |
| encoder_hidden_states = torch.addcmul(encoder_hidden_states, c_gate_msa, context_attn_output) | |
| norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) | |
| norm_encoder_hidden_states = torch.addcmul(c_shift_mlp, norm_encoder_hidden_states, 1 + c_scale_mlp) | |
| context_ff_output = self.ff_context(norm_encoder_hidden_states) | |
| encoder_hidden_states = torch.addcmul(encoder_hidden_states, c_gate_mlp, context_ff_output) | |
| hidden_states = torch.addcmul(hidden_states, gate_msa, attn_output) | |
| norm_hidden_states = self.norm2(hidden_states) | |
| norm_hidden_states = torch.addcmul(shift_mlp, norm_hidden_states, 1 + scale_mlp) | |
| ff_output = self.ff(norm_hidden_states) | |
| hidden_states = torch.addcmul(hidden_states, gate_mlp, ff_output) | |
| torch.cuda.current_stream().wait_stream(STREAM) | |
| else: | |
| hidden_states = torch.addcmul(hidden_states, gate_msa, attn_output) | |
| encoder_hidden_states = torch.addcmul(encoder_hidden_states, c_gate_msa, context_attn_output) | |
| norm_hidden_states = self.norm2(hidden_states) | |
| norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) | |
| norm_hidden_states = torch.addcmul(shift_mlp, norm_hidden_states, 1 + scale_mlp) | |
| norm_encoder_hidden_states = torch.addcmul(c_shift_mlp, norm_encoder_hidden_states, 1 + c_scale_mlp) | |
| ff_output = self.ff(norm_hidden_states) | |
| context_ff_output = self.ff_context(norm_encoder_hidden_states) | |
| hidden_states = torch.addcmul(hidden_states, gate_mlp, ff_output) | |
| encoder_hidden_states = torch.addcmul(encoder_hidden_states, c_gate_mlp, context_ff_output) | |
| return encoder_hidden_states, hidden_states | |
| class FluxTransformer2DModel( | |
| ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin | |
| ): | |
| @register_to_config | |
| def __init__( | |
| self, | |
| patch_size: int = 1, | |
| in_channels: int = 64, | |
| out_channels: Optional[int] = None, | |
| num_layers: int = 19, | |
| num_single_layers: int = 38, | |
| attention_head_dim: int = 128, | |
| num_attention_heads: int = 24, | |
| joint_attention_dim: int = 4096, | |
| pooled_projection_dim: int = 768, | |
| guidance_embeds: bool = False, | |
| axes_dims_rope: Tuple[int, int, int] = (16, 56, 56), | |
| ): | |
| super().__init__() | |
| self.out_channels = out_channels or in_channels | |
| self.inner_dim = num_attention_heads * attention_head_dim | |
| self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope) | |
| text_time_guidance_cls = ( | |
| CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings | |
| ) | |
| self.time_text_embed = text_time_guidance_cls( | |
| embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim | |
| ) | |
| self.context_embedder = torch.nn.Linear(joint_attention_dim, self.inner_dim) | |
| self.x_embedder = torch.nn.Linear(in_channels, self.inner_dim) | |
| self.transformer_blocks = torch.nn.ModuleList( | |
| [ | |
| FluxTransformerBlock( | |
| dim=self.inner_dim, | |
| num_attention_heads=num_attention_heads, | |
| attention_head_dim=attention_head_dim, | |
| ) | |
| for _ in range(num_layers) | |
| ] | |
| ) | |
| self.single_transformer_blocks = torch.nn.ModuleList( | |
| [ | |
| FluxSingleTransformerBlock( | |
| dim=self.inner_dim, | |
| num_attention_heads=num_attention_heads, | |
| attention_head_dim=attention_head_dim, | |
| ) | |
| for _ in range(num_single_layers) | |
| ] | |
| ) | |
| self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) | |
| self.proj_out = torch.nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) | |
| def forward( | |
| self, | |
| hidden_states: torch.Tensor, | |
| encoder_hidden_states: torch.Tensor, | |
| conditioning: torch.Tensor, | |
| image_rotary_emb: Tuple[torch.Tensor, torch.Tensor], | |
| dt: torch.Tensor, | |
| ) -> torch.Tensor: | |
| x_t = hidden_states | |
| hidden_states = self.x_embedder(hidden_states) | |
| if STREAM is not None: | |
| STREAM.wait_stream(torch.cuda.current_stream()) | |
| with torch.cuda.stream(STREAM): | |
| adaln_linear_dual_stream_states = self.adaln_linear(conditioning).unsqueeze(1).chunk(self.config.num_layers, dim=-1) | |
| else: | |
| adaln_linear_dual_stream_states = self.adaln_linear(conditioning).unsqueeze(1).chunk(self.config.num_layers, dim=-1) | |
| adaln_linear_single_stream_states = self.adaln_linear_single(conditioning).unsqueeze(1).chunk(self.config.num_single_layers, dim=-1) | |
| if STREAM is not None: | |
| torch.cuda.current_stream().wait_stream(STREAM) | |
| for i, block in enumerate(self.transformer_blocks): | |
| encoder_hidden_states, hidden_states = block( | |
| hidden_states=hidden_states, | |
| encoder_hidden_states=encoder_hidden_states, | |
| temb=adaln_linear_dual_stream_states[i], | |
| image_rotary_emb=image_rotary_emb, | |
| ) | |
| hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) | |
| for i, block in enumerate(self.single_transformer_blocks): | |
| hidden_states = block( | |
| hidden_states=hidden_states, | |
| temb=adaln_linear_single_stream_states[i], | |
| image_rotary_emb=image_rotary_emb, | |
| ) | |
| hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...] | |
| hidden_states = self.norm_out(hidden_states, conditioning) | |
| velocity = self.proj_out(hidden_states) | |
| x = x_t + dt * velocity | |
| return x | |
| @torch.no_grad() | |
def fuse_qkv_(model: FluxTransformer2DModel) -> None:
| for submodule in model.modules(): | |
| if not isinstance(submodule, Attention): | |
| continue | |
| submodule.fuse_projections() | |
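# fuse_adaln_linear_ concatenates the per-block AdaLN modulation linears (norm1/norm1_context for the
# dual-stream blocks, norm for the single-stream blocks) into two large linears, so all modulation
# parameters for a denoising step are produced by two matmuls on the conditioning vector.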
| @torch.no_grad() | |
def fuse_adaln_linear_(model: FluxTransformer2DModel) -> None:
| adaln_linear_weights = [] | |
| adaln_linear_biases = [] | |
| for block in model.transformer_blocks: | |
| adaln_linear_weights.append(block.norm1.linear.weight.data.clone()) | |
| adaln_linear_weights.append(block.norm1_context.linear.weight.data.clone()) | |
| adaln_linear_biases.append(block.norm1.linear.bias.data.clone()) | |
| adaln_linear_biases.append(block.norm1_context.linear.bias.data.clone()) | |
| block.norm1.linear.to("meta") | |
| block.norm1_context.linear.to("meta") | |
| del block.norm1.linear, block.norm1_context.linear | |
| adaln_linear_weights = torch.cat(adaln_linear_weights, dim=0) | |
| adaln_linear_biases = torch.cat(adaln_linear_biases, dim=0) | |
| in_features = adaln_linear_weights.shape[1] | |
| out_features = adaln_linear_weights.shape[0] | |
| model.adaln_linear = torch.nn.Linear( | |
| in_features, out_features, bias=True, device=adaln_linear_weights.device, dtype=adaln_linear_weights.dtype | |
| ) | |
| model.adaln_linear.weight.copy_(adaln_linear_weights) | |
| model.adaln_linear.bias.copy_(adaln_linear_biases) | |
| adaln_linear_weights = [] | |
| adaln_linear_biases = [] | |
| for block in model.single_transformer_blocks: | |
| adaln_linear_weights.append(block.norm.linear.weight.data.clone()) | |
| adaln_linear_biases.append(block.norm.linear.bias.data.clone()) | |
| block.norm.linear.to("meta") | |
| del block.norm.linear | |
| adaln_linear_weights = torch.cat(adaln_linear_weights, dim=0) | |
| adaln_linear_biases = torch.cat(adaln_linear_biases, dim=0) | |
| in_features = adaln_linear_weights.shape[1] | |
| out_features = adaln_linear_weights.shape[0] | |
| model.adaln_linear_single = torch.nn.Linear( | |
| in_features, out_features, bias=True, device=adaln_linear_weights.device, dtype=adaln_linear_weights.dtype | |
| ) | |
| model.adaln_linear_single.weight.copy_(adaln_linear_weights) | |
| model.adaln_linear_single.bias.copy_(adaln_linear_biases) | |
| def prepare_clip_embeddings( | |
| text_encoder: CLIPTextModel, | |
| tokenizer: CLIPTokenizer, | |
| prompt: str, | |
| device: torch.device, | |
| dtype: torch.dtype, | |
| max_length: int = 77, | |
| ) -> torch.Tensor: | |
| prompt = [prompt] | |
| text_inputs = tokenizer( | |
| prompt, | |
| padding="max_length", | |
| max_length=max_length, | |
| truncation=True, | |
| return_overflowing_tokens=False, | |
| return_length=False, | |
| return_tensors="pt", | |
| ) | |
| text_input_ids = text_inputs.input_ids | |
| prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False) | |
| prompt_embeds = prompt_embeds.pooler_output.to(dtype) | |
| return prompt_embeds | |
| def prepare_t5_embeddings( | |
| text_encoder: T5EncoderModel, | |
| tokenizer: T5TokenizerFast, | |
| prompt: str, | |
| device: torch.device, | |
| dtype: torch.dtype, | |
| max_length: int = 512, | |
| enable_prompt_length_bucketing: bool = False, | |
| ) -> torch.Tensor: | |
| prompt = [prompt] | |
| text_inputs = tokenizer( | |
| prompt, | |
| padding="max_length", | |
| max_length=max_length, | |
| truncation=True, | |
| return_length=False, | |
| return_overflowing_tokens=False, | |
| return_tensors="pt", | |
| ) | |
| text_input_ids = text_inputs.input_ids | |
| if enable_prompt_length_bucketing: | |
| attention_mask = text_inputs.attention_mask | |
| num_text_tokens = attention_mask.sum(dim=1).max().item() | |
| max_length = min( | |
| SUPPORTED_BUCKET_LENGTHS, key=lambda x: abs(x - num_text_tokens) if x >= num_text_tokens else float("inf") | |
| ) | |
| text_input_ids = text_input_ids[:, :max_length] | |
| prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)[0] | |
| prompt_embeds = prompt_embeds.to(dtype) | |
| return prompt_embeds, max_length | |
| @functools.lru_cache(maxsize=8) | |
| def prepare_latent_image_ids(height: int, width: int, device: torch.device, dtype: torch.dtype): | |
| latent_image_ids = torch.zeros(height, width, 3) | |
| latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] | |
| latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] | |
| latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape | |
| latent_image_ids = latent_image_ids.reshape( | |
| latent_image_id_height * latent_image_id_width, latent_image_id_channels | |
| ).contiguous() | |
| return latent_image_ids.to(device=device, dtype=dtype) | |
| def precompute_guidance_embeds(transformer: FluxTransformer2DModel, device: torch.device, dtype: torch.dtype): | |
| embeds = {} | |
| for guidance_scale in SUPPORTED_GUIDANCE_SCALES: | |
| guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) | |
| guidance = transformer.time_text_embed.guidance_embedder( | |
| transformer.time_text_embed.time_proj(guidance * 1000.0).to(dtype) | |
| ) | |
| embeds[f"{guidance_scale:.1f}"] = guidance | |
| return embeds | |
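# Precompute the time-shifted flow-matching sigma schedule for every supported step count,
# sigma' = exp(mu) / (exp(mu) + 1 / sigma - 1) with mu = B + image_seq_len * M, together with the
# corresponding timestep embeddings.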
| def precompute_timestep_embeds(transformer: FluxTransformer2DModel, device: torch.device, dtype: torch.dtype): | |
| embeds = {} | |
| image_seq_len = LATENT_HEIGHT * LATENT_WIDTH | |
| for num_inference_steps in range(MIN_INFERENCE_STEPS, MAX_INFERENCE_STEPS + 1): | |
| sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) | |
| mu = B + image_seq_len * M | |
| sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1) ** 1.0) | |
| sigmas = torch.from_numpy(sigmas).to(device, dtype=torch.float32) | |
| sigmas = torch.cat([sigmas, sigmas.new_zeros(1)]) | |
| timesteps = (sigmas * 1000.0).to(dtype) | |
| temb = transformer.time_text_embed.time_proj(timesteps) | |
| temb = transformer.time_text_embed.timestep_embedder(temb.to(dtype)) | |
| embeds[num_inference_steps] = (sigmas, temb) | |
| return embeds | |
| def precompute_embeds(transformer: FluxTransformer2DModel, device: torch.device, dtype: torch.dtype, save_dir: str): | |
| save_dir = pathlib.Path(save_dir) | |
| save_dir.mkdir(parents=True, exist_ok=True) | |
| guidance_path = save_dir / "guidance_embeds.pt" | |
| timestep_path = save_dir / "timestep_embeds.pt" | |
| if guidance_path.exists(): | |
| guidance_embeds = torch.load(guidance_path, map_location=device, weights_only=True) | |
| print(f'Loaded precomputed guidance embeddings from "{guidance_path.as_posix()}"') | |
| else: | |
| guidance_embeds = precompute_guidance_embeds(transformer, device, dtype) | |
| if cp_options.mesh is None or cp_options.mesh._flatten().get_local_rank() == 0: | |
| torch.save(guidance_embeds, guidance_path.as_posix()) | |
| print(f'Precomputed guidance embeddings saved to "{save_dir.as_posix()}"') | |
| if timestep_path.exists(): | |
| timestep_embeds = torch.load(timestep_path, map_location=device, weights_only=True) | |
| print(f'Loaded precomputed timestep embeddings from "{timestep_path.as_posix()}"') | |
| else: | |
| timestep_embeds = precompute_timestep_embeds(transformer, device, dtype) | |
| if cp_options.mesh is None or cp_options.mesh._flatten().get_local_rank() == 0: | |
| torch.save(timestep_embeds, timestep_path.as_posix()) | |
| print(f'Precomputed timestep embeddings saved to "{save_dir.as_posix()}"') | |
| return guidance_embeds, timestep_embeds | |
| @torch.compile | |
| def pointwise_add3_silu(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor) -> torch.Tensor: | |
| return torch.nn.functional.silu(x + y + z) | |
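# CUDA graph path: warm up for two iterations on a side stream, record one full denoising step into a
# torch.cuda.CUDAGraph, then replay it each step after copying the new latents/conditioning/dt into the
# captured static tensors.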
| def capture_cudagraph( | |
| model: FluxTransformer2DModel, | |
| latents: torch.Tensor, | |
| encoder_hidden_states: torch.Tensor, | |
| conditioning: torch.Tensor, | |
| image_rotary_emb: Tuple[torch.Tensor, torch.Tensor], | |
| dt: torch.Tensor, | |
| ): | |
| print("Warming up CUDAGraph capture") | |
| s = torch.cuda.Stream() | |
| s.wait_stream(torch.cuda.current_stream()) | |
| with torch.cuda.stream(s): | |
| for _ in range(2): | |
| _ = model( | |
| hidden_states=latents, | |
| encoder_hidden_states=encoder_hidden_states, | |
| conditioning=conditioning, | |
| image_rotary_emb=image_rotary_emb, | |
| dt=dt, | |
| ) | |
| torch.cuda.current_stream().wait_stream(s) | |
| print("Capturing CUDAGraph") | |
| static_latents = latents.clone() | |
| static_conditioning = conditioning.clone() | |
| static_dt = dt.clone() | |
| graph = torch.cuda.CUDAGraph() | |
| with torch.cuda.graph(graph): | |
| static_x = model( | |
| hidden_states=static_latents, | |
| encoder_hidden_states=encoder_hidden_states, | |
| conditioning=static_conditioning, | |
| image_rotary_emb=image_rotary_emb, | |
| dt=static_dt, | |
| ) | |
| return graph, static_latents, static_conditioning, static_dt, static_x | |
| @torch.inference_mode() | |
| def main( | |
| model_id: str, | |
| prompt: str, | |
| height: int, | |
| width: int, | |
| num_inference_steps: int, | |
| guidance_scale: float, | |
| compile_mode: str, | |
| output_file: str, | |
| cache_dir: Optional[str], | |
| enable_cudagraph: bool, | |
| enable_prompt_length_bucketing: bool, | |
| enable_profiling: bool, | |
| working_dir: str, | |
| seed: int, | |
| ): | |
| device = "cuda" | |
| dtype = torch.bfloat16 | |
| # Load the model components | |
| transformer = FluxTransformer2DModel.from_pretrained( | |
| model_id, subfolder="transformer", cache_dir=cache_dir, torch_dtype=dtype | |
| ) | |
| text_encoder = CLIPTextModel.from_pretrained( | |
| model_id, subfolder="text_encoder", cache_dir=cache_dir, torch_dtype=dtype | |
| ) | |
| text_encoder_2 = T5EncoderModel.from_pretrained( | |
| model_id, subfolder="text_encoder_2", cache_dir=cache_dir, torch_dtype=dtype | |
| ) | |
| tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer", cache_dir=cache_dir) | |
| tokenizer_2 = T5TokenizerFast.from_pretrained(model_id, subfolder="tokenizer_2", cache_dir=cache_dir) | |
| vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", cache_dir=cache_dir, torch_dtype=dtype) | |
| image_processor = VaeImageProcessor( | |
| vae_scale_factor=SPATIAL_COMPRESSION_RATIO * PIXEL_UNSHUFFLING_DOWNSAMPLING_FACTOR | |
| ) | |
| fuse_qkv_(transformer) | |
| fuse_adaln_linear_(transformer) | |
for module in (transformer, text_encoder, text_encoder_2, vae):
    module.to(device)
| vae.to(memory_format=torch.channels_last) | |
| if compile_mode != "none": | |
| transformer = torch.compile(transformer, mode=compile_mode, fullgraph=True, dynamic=True) | |
| # text_encoder = torch.compile(text_encoder, mode="default", fullgraph=True, dynamic=True) | |
| # text_encoder_2 = torch.compile(text_encoder_2, mode="default", fullgraph=True, dynamic=True) | |
| # We don't compile the VAE due to the implementation calling into non-traceable code paths | |
| # vae.decode = torch.compile(vae.decode, mode="default", fullgraph=True, dynamic=True) | |
| # Latent, text, guidance and timestep conditioning preparation | |
| batch_size = 1 | |
| patch_size = transformer.config.patch_size | |
| latent_height = height // (SPATIAL_COMPRESSION_RATIO * patch_size) // PIXEL_UNSHUFFLING_DOWNSAMPLING_FACTOR | |
| latent_width = width // (SPATIAL_COMPRESSION_RATIO * patch_size) // PIXEL_UNSHUFFLING_DOWNSAMPLING_FACTOR | |
| generator = torch.Generator(device=device).manual_seed(seed) | |
| guidance_embeds, timestep_embeds = precompute_embeds(transformer, device, dtype, working_dir) | |
| latents = torch.randn( | |
| (batch_size, latent_height * latent_width, transformer.config.in_channels), | |
| dtype=dtype, | |
| device=device, | |
| generator=generator, | |
| ) | |
| pooled_projections = prepare_clip_embeddings(text_encoder, tokenizer, prompt, device, dtype) | |
| encoder_hidden_states, num_text_tokens = prepare_t5_embeddings( | |
| text_encoder_2, tokenizer_2, prompt, device, dtype, T5_SEQUENCE_LENGTH, enable_prompt_length_bucketing | |
| ) | |
| # <precompute> | |
| guidance_conditioning = guidance_embeds[f"{guidance_scale:.1f}"] | |
| sigmas, timestep_conditioning = timestep_embeds[num_inference_steps] | |
| pooled_projections = transformer.time_text_embed.text_embedder(pooled_projections) | |
| encoder_hidden_states = transformer.context_embedder(encoder_hidden_states) | |
| # </precompute> | |
| img_ids = prepare_latent_image_ids(latent_height, latent_width, device=device, dtype=dtype) | |
| txt_ids = torch.zeros(num_text_tokens, 3).to(device=device, dtype=dtype) | |
| if cp_options.mesh is not None: | |
# Note: the clone() appears to be required here; without it there is a recompilation related to
# storage offsets (whose error message suggests using torch._dynamo.decorators.mark_unbacked).
| img_ids = EquipartitionSharder.shard(img_ids, dim=0, mesh=cp_options._flattened_mesh).clone() | |
| txt_ids = EquipartitionSharder.shard(txt_ids, dim=0, mesh=cp_options._flattened_mesh).clone() | |
| latents = EquipartitionSharder.shard(latents, dim=1, mesh=cp_options._flattened_mesh).clone() | |
| encoder_hidden_states = EquipartitionSharder.shard(encoder_hidden_states, dim=1, mesh=cp_options._flattened_mesh).clone() | |
| ids = torch.cat([txt_ids, img_ids], dim=0).float() | |
| image_rotary_emb = transformer.pos_embed(ids) | |
| dt = sigmas[1:] - sigmas[:-1] | |
| print("Warming up the model") | |
| for _ in range(2): | |
| conditioning = pointwise_add3_silu(timestep_conditioning[0, :], guidance_conditioning, pooled_projections) | |
| _ = transformer( | |
| hidden_states=latents, | |
| encoder_hidden_states=encoder_hidden_states, | |
| conditioning=conditioning, | |
| image_rotary_emb=image_rotary_emb, | |
| dt=dt[0], | |
| ) | |
| torch.cuda.synchronize() | |
| start_event = torch.cuda.Event(enable_timing=True) | |
| end_event = torch.cuda.Event(enable_timing=True) | |
| context = ( | |
| profile(activities=[ProfilerActivity.CUDA, ProfilerActivity.CPU], with_flops=True) | |
| if enable_profiling | |
| else contextlib.nullcontext() | |
| ) | |
| if not enable_cudagraph: | |
| with context as ctx: | |
| start_event.record() | |
| for i in range(num_inference_steps): | |
| conditioning = pointwise_add3_silu(timestep_conditioning[i, :], guidance_conditioning, pooled_projections) | |
| latents = transformer( | |
| hidden_states=latents, | |
| encoder_hidden_states=encoder_hidden_states, | |
| conditioning=conditioning, | |
| image_rotary_emb=image_rotary_emb, | |
| dt=dt[i], | |
| ) | |
| end_event.record() | |
| torch.cuda.synchronize() | |
| else: | |
| conditioning = pointwise_add3_silu(timestep_conditioning[0, :], guidance_conditioning, pooled_projections) | |
| graph, static_latents, static_conditioning, static_dt, static_x = capture_cudagraph( | |
| transformer, | |
| latents, | |
| encoder_hidden_states, | |
| conditioning, | |
| image_rotary_emb, | |
| dt[0], | |
| ) | |
| with context as ctx: | |
| start_event.record() | |
| static_x.copy_(latents) | |
| for i in range(num_inference_steps): | |
| conditioning = pointwise_add3_silu(timestep_conditioning[i, :], guidance_conditioning, pooled_projections) | |
| torch._foreach_copy_( | |
| (static_latents, static_conditioning, static_dt), | |
| (static_x, conditioning, dt[i]), | |
| non_blocking=True, | |
| ) | |
| graph.replay() | |
| end_event.record() | |
| torch.cuda.synchronize() | |
| latents = static_x | |
| total_time = start_event.elapsed_time(end_event) / 1000.0 | |
| if cp_options.mesh is not None: | |
| latents = EquipartitionSharder.unshard(latents, dim=1, mesh=cp_options._flattened_mesh) | |
| if cp_options.mesh is None or cp_options.mesh._flatten().get_local_rank() == 0: | |
| print(f"time: {total_time:.2f}s") | |
| if enable_profiling: | |
| print(ctx.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1)) | |
| ctx.export_chrome_trace("dump_benchmark_flux.json") | |
| latents = latents.reshape(batch_size, latent_height, latent_width, -1, 2, 2) | |
| latents = latents.permute(0, 3, 1, 4, 2, 5) | |
| latents = latents.flatten(4, 5).flatten(2, 3) | |
| latents = latents.to(memory_format=torch.channels_last) | |
| latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor | |
| image = vae.decode(latents, return_dict=False)[0] | |
| image = image_processor.postprocess(image, output_type="pil")[0] | |
| image.save(output_file) | |
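# EquipartitionSharder splits a tensor into equal contiguous chunks along a dimension (one chunk per
# rank) for context parallelism, and all-gathers the chunks back after denoising.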
| class EquipartitionSharder: | |
| @classmethod | |
| def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor: | |
| assert tensor.size()[dim] % mesh.size() == 0 | |
| # The following is not fullgraph compatible with Dynamo (fails in DeviceMesh.get_rank) | |
| # return tensor.chunk(mesh.size(), dim=dim)[mesh.get_rank()] | |
| return tensor.chunk(mesh.size(), dim=dim)[torch.distributed.get_rank(mesh.get_group())] | |
| @classmethod | |
| def unshard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor: | |
| tensor = tensor.contiguous() | |
| tensor = funcol.all_gather_tensor(tensor, dim, group=mesh.get_group()) | |
| return tensor | |
| # Reference: | |
| # - https://github.com/pytorch/pytorch/blob/f58a680d09e13658a52c6ba05c63c15759846bcc/torch/distributed/_functional_collectives.py#L827 | |
| # - https://github.com/pytorch/pytorch/blob/f58a680d09e13658a52c6ba05c63c15759846bcc/torch/distributed/_functional_collectives.py#L246 | |
| # For fullgraph=True tracing compatibility (since FakeTensor does not have a `wait` method): | |
| def _wait_tensor(tensor): | |
| if isinstance(tensor, funcol.AsyncCollectiveTensor): | |
| tensor = tensor.wait() | |
| return tensor | |
| def _all_to_all_single(x: torch.Tensor, group) -> torch.Tensor: | |
| shape = x.shape | |
# HACK: flatten before the collective. Despite making the tensors contiguous, extracting the Inductor
# output into a single file to benchmark the Triton codegen fails somewhere with:
#   buf25 = torch.ops._c10d_functional.all_to_all_single.default(buf24, [1, 1], [1, 1], '3')
#   ValueError: Tensors must be contiguous
| x = x.flatten() | |
| x = funcol.all_to_all_single(x, None, None, group) | |
| x = x.reshape(shape) | |
| x = _wait_tensor(x) | |
| return x | |
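# Ring attention (all-gather variant): every rank gathers the flattened K/V of all other ranks, then
# iterates over them locally, merging the partial attention outputs with the log-sum-exp correction below.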
| def _templated_ring_attention(query, key, value): | |
| ring_mesh = cp_options.mesh["ring"] | |
| rank = cp_options._ring_local_rank | |
| world_size = cp_options.ring_degree | |
| if world_size == 1: | |
| return cp_options.attention_op(query, key, value) | |
| next_rank = (rank + 1) % world_size | |
| prev_out = prev_lse = None | |
| kv_buffer = torch.cat([key.flatten(), value.flatten()]).contiguous() | |
| kv_buffer = funcol.all_gather_tensor(kv_buffer, gather_dim=0, group=ring_mesh.get_group()) | |
| kv_buffer = kv_buffer.chunk(world_size) | |
| for i in range(world_size): | |
| if i > 0: | |
| kv = kv_buffer[next_rank] | |
| key = kv[:key.numel()].reshape_as(key) | |
| value = kv[key.numel():].reshape_as(value) | |
| next_rank = (next_rank + 1) % world_size | |
| out, lse = cp_options.attention_op(query, key, value) | |
| if cp_options.convert_to_fp32: | |
| out = out.to(torch.float32) | |
| lse = lse.to(torch.float32) | |
| lse = lse.unsqueeze(-1) | |
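# Numerically stable online-softmax merge of partial results:
# out <- out_prev - sigmoid(lse - lse_prev) * (out_prev - out); lse <- lse_prev - logsigmoid(lse_prev - lse)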
| if prev_out is not None: | |
| out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out) | |
| lse = prev_lse - torch.nn.functional.logsigmoid(prev_lse - lse) | |
| prev_out = out | |
| prev_lse = lse | |
| out = out.to(query.dtype) | |
| lse = lse.squeeze(-1) | |
| return out, lse | |
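# Ulysses attention: an all-to-all reshards Q/K/V from (all heads, local sequence) to (local heads,
# full sequence), attention runs over the full sequence on the local heads, and a second all-to-all
# restores the original layout.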
| def _templated_ulysses_attention(query, key, value, *, return_lse: bool = False): | |
| ulysses_mesh = cp_options.mesh["ulysses"] | |
| world_size = cp_options.ulysses_degree | |
| group = ulysses_mesh.get_group() | |
| if world_size == 1: | |
| return cp_options.attention_op(query, key, value) | |
| B, S_LOCAL, H, D = query.shape | |
| H_LOCAL = H // world_size | |
| query, key, value = ( | |
| x.reshape(B, S_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).clone() | |
| for x in (query, key, value) | |
| ) | |
| query, key, value = ( | |
| _all_to_all_single(x, group) | |
| for x in (query, key, value) | |
| ) | |
| query, key, value = ( | |
| x.flatten(0, 1).permute(1, 0, 2, 3).contiguous() | |
| for x in (query, key, value) | |
| ) | |
| out, lse = cp_options.attention_op(query, key, value) | |
| out = out.reshape(B, world_size, S_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous() | |
| out = _all_to_all_single(out, group) | |
| out = out.flatten(0, 1).permute(1, 2, 0, 3).contiguous() | |
| if return_lse: | |
| lse = lse.reshape(B, world_size, S_LOCAL, H_LOCAL).permute(1, 3, 0, 2).contiguous() | |
| lse = _all_to_all_single(lse, group) | |
| lse = lse.flatten(0, 1).permute(1, 2, 0).contiguous() | |
| else: | |
| lse = None | |
| return out, lse | |
# TODO: currently produces incorrect results (for example, with CP=4, ring=2, ulysses=2, half of the
# output is the expected image and the other half is something completely different)
| # def _templated_unified_attention(query, key, value): | |
| # ulysses_mesh = cp_options.mesh["ulysses"] | |
| # ulysses_size = ulysses_mesh.size() | |
| # ulysses_group = ulysses_mesh.get_group() | |
| # B, S_LOCAL, H, D = query.shape | |
| # H_LOCAL = H // ulysses_size | |
| # query, key, value = ( | |
| # x.reshape(B, S_LOCAL, ulysses_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous() | |
| # for x in (query, key, value) | |
| # ) | |
| # query, key, value = ( | |
| # wait_tensor(funcol.all_to_all_single(x, None, None, group=ulysses_group)) | |
| # for x in (query, key, value) | |
| # ) | |
| # query, key, value = ( | |
| # x.flatten(0, 1).permute(1, 0, 2, 3).contiguous() | |
| # for x in (query, key, value) | |
| # ) | |
| # out, lse = _templated_ring_attention(query, key, value) | |
| # out = out.reshape(B, ulysses_size, S_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous() | |
| # lse = lse.reshape(B, ulysses_size, S_LOCAL, H_LOCAL).permute(1, 3, 0, 2).contiguous() | |
| # out = wait_tensor(funcol.all_to_all_single(out, None, None, group=ulysses_group)) | |
| # lse = wait_tensor(funcol.all_to_all_single(lse, None, None, group=ulysses_group)) | |
| # out = out.flatten(0, 1).permute(1, 2, 0, 3).contiguous() | |
| # lse = lse.flatten(0, 1).permute(1, 2, 0).contiguous() | |
| # return out, lse | |
| # For fullgraph=True tracing to be compatible | |
| @torch.library.custom_op("flash_attn_3::_flash_attn_forward_original", mutates_args=(), device_types="cuda") | |
| def _wrapped_flash_attn_3_original(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: | |
| out, lse = flash_attn_3_func(query, key, value) | |
| lse = lse.permute(0, 2, 1) | |
| return out, lse | |
| @torch.library.register_fake("flash_attn_3::_flash_attn_forward_original") | |
| def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: | |
| batch_size, seq_len, num_heads, head_dim = query.shape | |
| lse_shape = (batch_size, seq_len, num_heads) | |
| return torch.empty_like(query), query.new_empty(lse_shape) | |
| @torch.library.custom_op("flash_attn_3::_flash_attn_forward_hf", mutates_args=(), device_types="cuda") | |
| def _wrapped_flash_attn_3_hf(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor: | |
| out, lse = flash_attn_3_hf.flash_attn_func(query, key, value, causal=False) | |
| return out | |
| @torch.library.register_fake("flash_attn_3::_flash_attn_forward_hf") | |
| def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor: | |
| return torch.empty_like(query) | |
| def _attention_torch_cudnn(query, key, value): | |
| query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value)) | |
| out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = ( | |
| torch.ops.aten._scaled_dot_product_cudnn_attention( | |
| query=query, | |
| key=key, | |
| value=value, | |
| attn_bias=None, | |
| compute_log_sumexp=True, | |
| ) | |
| ) | |
| out = out.transpose(1, 2).contiguous() | |
| lse = lse.transpose(1, 2).contiguous() | |
| return out, lse | |
| def _attention_flash_attn_2(query, key, value): | |
| out, lse, _ = flash_attn_func(query, key, value, return_attn_probs=True) | |
| lse = lse.permute(0, 2, 1) | |
| return out, lse | |
| def _attention_flash_attn_3_original(query, key, value): | |
| out = _wrapped_flash_attn_3_original(query, key, value) | |
| return out | |
| def _attention_flash_attn_3_hf(query, key, value): | |
| out = _wrapped_flash_attn_3_hf(query, key, value) | |
| return out | |
| def _download_hf_flash_attn_3(): | |
| global flash_attn_3_hf | |
| flash_attn_3_hf = get_kernel("kernels-community/flash-attn3") | |
| def get_args(): | |
| DEFAULT_MODEL_ID = "black-forest-labs/FLUX.1-dev" | |
| DEFAULT_PROMPT = "The King of Hearts card transforms into a 3D hologram that appears to be made of cosmic energy. As the King emerges, stars and galaxies swirl around him, creating a sense of traveling through the universe. The King's attire is adorned with celestial patterns, and his crown is a glowing star cluster. The hologram floats in front of you, with the background shifting through different cosmic scenes, from nebulae to black holes. Atmosphere: Perfect for space-themed events, science fiction conventions, or futuristic tech expos." | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument("--model_id", type=str, default=DEFAULT_MODEL_ID) | |
| parser.add_argument("--prompt", type=str, default=DEFAULT_PROMPT) | |
| parser.add_argument("--height", type=int, default=1024) | |
| parser.add_argument("--width", type=int, default=1024) | |
| parser.add_argument("--num_inference_steps", type=int, default=28) | |
| parser.add_argument("--guidance_scale", type=float, default=4.0) | |
| parser.add_argument( | |
| "--compile_mode", | |
| type=str, | |
| default="none", | |
| choices=["none", "default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"], | |
| ) | |
| parser.add_argument("--output_file", type=str, default="output.png") | |
| parser.add_argument("--cache_dir", type=str, default=None) | |
| parser.add_argument("--enable_fp32_rope", action="store_true") | |
| parser.add_argument("--enable_cudagraph", action="store_true") | |
| parser.add_argument("--enable_prompt_length_bucketing", action="store_true") | |
| parser.add_argument("--attention_provider", type=str, default="cudnn", choices=["cudnn", "fa2", "fa3", "fa3_original"]) | |
| parser.add_argument("--ring_degree", type=int, default=1) | |
| parser.add_argument("--ulysses_degree", type=int, default=1) | |
| parser.add_argument("--enable_profiling", action="store_true") | |
| parser.add_argument("--enable_cuda_stream", action="store_true") | |
| parser.add_argument("--disable_tf32", action="store_true") | |
| parser.add_argument("--disable_flags", action="store_true") | |
| parser.add_argument("--working_dir", type=str, default="/tmp/flux_precomputation") | |
| parser.add_argument("--seed", type=int, default=31337) | |
| args = parser.parse_args() | |
| return args | |
| def setup_config(args): | |
| torch.manual_seed(args.seed) | |
| if args.enable_fp32_rope: | |
| global ROPE_PRECISION | |
| ROPE_PRECISION = torch.float32 | |
| if args.enable_cudagraph and args.compile_mode not in ["none", "default", "max-autotune-no-cudagraphs"]: | |
| raise ValueError( | |
| "Only compiled modes 'none', 'default', and 'max-autotune-no-cudagraphs' are supported with CUDAGraphs." | |
| ) | |
| global ATTENTION_OP | |
| if args.attention_provider == "cudnn": | |
| ATTENTION_OP = _attention_torch_cudnn | |
| elif args.attention_provider == "fa2": | |
| ATTENTION_OP = _attention_flash_attn_2 | |
| elif args.attention_provider == "fa3": | |
| _download_hf_flash_attn_3() | |
| ATTENTION_OP = _attention_flash_attn_3_hf | |
| elif args.attention_provider == "fa3_original": | |
| ATTENTION_OP = _attention_flash_attn_3_original | |
else:
    raise ValueError(f"Unsupported attention provider: {args.attention_provider}")
| if args.enable_profiling and args.enable_cudagraph: | |
| torch.profiler._utils._init_for_cuda_graphs() | |
| if not args.disable_tf32: | |
| torch.backends.cuda.matmul.allow_tf32 = True | |
| torch.backends.cudnn.allow_tf32 = True | |
| if not args.disable_flags: | |
| apply_flags() | |
| global MIN_INFERENCE_STEPS, MAX_INFERENCE_STEPS | |
| if args.num_inference_steps < MIN_INFERENCE_STEPS or args.num_inference_steps > MAX_INFERENCE_STEPS: | |
| raise ValueError(f"`num_inference_steps` must be equal or between {MIN_INFERENCE_STEPS} and {MAX_INFERENCE_STEPS}.") | |
| def setup_distributed(ring_degree: int, ulysses_degree: int, compile_mode: str, enable_cuda_stream: bool): | |
| global ATTENTION_OP | |
| from datetime import timedelta | |
| dist.init_process_group("nccl", timeout=timedelta(seconds=60)) | |
| torch.cuda.set_device(torch.device("cuda", dist.get_rank())) | |
global STREAM
if enable_cuda_stream:
    STREAM = torch.cuda.Stream()
| if ring_degree * ulysses_degree != dist.get_world_size(): | |
| raise ValueError(f"`ring_degree * ulysses_degree` must equal the world size {dist.get_world_size()}.") | |
| mesh_names = ["ring", "ulysses"] | |
| mesh_dims = [ring_degree, ulysses_degree] | |
| mesh = dist.device_mesh.init_device_mesh("cuda", mesh_dims, mesh_dim_names=mesh_names) | |
| cp_options.ring_degree = ring_degree | |
| cp_options.ulysses_degree = ulysses_degree | |
| cp_options.mesh = mesh | |
| cp_options.convert_to_fp32 = True | |
| cp_options.attention_op = ATTENTION_OP | |
| cp_options._flattened_mesh = mesh._flatten() | |
| cp_options._ring_mesh = mesh["ring"] | |
| cp_options._ulysses_mesh = mesh["ulysses"] | |
| cp_options._ring_local_rank = cp_options._ring_mesh.get_local_rank() | |
| cp_options._ulysses_local_rank = cp_options._ulysses_mesh.get_local_rank() | |
| if ring_degree > 1 and ulysses_degree > 1: | |
| raise ValueError("The current implementation is incorrect for unified attention and needs to be fixed.") | |
| # cp_options.mode = "unified" | |
| # ATTENTION_OP = _templated_unified_attention | |
| elif ulysses_degree > 1: | |
| cp_options.mode = "ulysses" | |
| ATTENTION_OP = _templated_ulysses_attention | |
| else: | |
| cp_options.mode = "ring" | |
| ATTENTION_OP = _templated_ring_attention | |
| if compile_mode != "none": | |
| torch._dynamo.config.suppress_errors = True | |
| torch._inductor.config.reorder_for_compute_comm_overlap = True | |
| if __name__ == "__main__": | |
| args = get_args() | |
| try: | |
| setup_config(args) | |
| setup_distributed(args.ring_degree, args.ulysses_degree, args.compile_mode, args.enable_cuda_stream) | |
| main( | |
| model_id=args.model_id, | |
| prompt=args.prompt, | |
| height=args.height, | |
| width=args.width, | |
| num_inference_steps=args.num_inference_steps, | |
| guidance_scale=args.guidance_scale, | |
| compile_mode=args.compile_mode, | |
| output_file=args.output_file, | |
| cache_dir=args.cache_dir, | |
| enable_cudagraph=args.enable_cudagraph, | |
| enable_prompt_length_bucketing=args.enable_prompt_length_bucketing, | |
| enable_profiling=args.enable_profiling, | |
| working_dir=args.working_dir, | |
| seed=args.seed, | |
| ) | |
| except Exception as e: | |
| print(f"An error occurred: {e}") | |
| if dist.is_initialized(): | |
| torch.distributed.breakpoint() | |
| raise | |
| finally: | |
| if dist.is_initialized(): | |
| dist.destroy_process_group() |
H100: