@a-r-r-o-w · Created August 18, 2025 09:40
Flux with CUDA stream
import argparse
import contextlib
import functools
import pathlib
import math
from dataclasses import dataclass
from typing import Callable, List, Literal, Optional, Tuple
import numpy as np
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol
import torch.profiler._utils
import torch._dynamo.config
import torch._inductor.config
import torch._higher_order_ops.auto_functionalize as af
from torch.profiler import profile, record_function, ProfilerActivity
from diffusers import AutoencoderKL
from diffusers.image_processor import VaeImageProcessor
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
from diffusers.models.attention import FeedForward
from diffusers.models.embeddings import get_1d_rotary_pos_embed
from diffusers.models.cache_utils import CacheMixin
from diffusers.models.embeddings import (
CombinedTimestepGuidanceTextProjEmbeddings,
CombinedTimestepTextProjEmbeddings,
)
from diffusers.models.modeling_utils import ModelMixin
from kernels import get_kernel
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
try:
    from flash_attn import flash_attn_func
except ImportError:
    flash_attn_func = None
    print("Flash Attention 2 not found.")
try:
    from flash_attn_interface import flash_attn_func as flash_attn_3_func
except ImportError:
    flash_attn_3_func = None
    print("Flash Attention 3 not found.")
def apply_flags():
torch._dynamo.config.inline_inbuilt_nn_modules = False
torch._dynamo.config.cache_size_limit = 128
torch._dynamo.config.error_on_recompile = True
torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_check_all_directions = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.disable_progress = False
torch._inductor.config.fx_graph_cache = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.aggressive_fusion = True
torch._inductor.config.shape_padding = True
torch._inductor.config.triton.enable_persistent_tma_matmul = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True
torch.backends.cudnn.allow_tf32 = True
af.auto_functionalized_v2._cacheable = True
af.auto_functionalized._cacheable = True
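# Global configuration. ROPE_PRECISION is the dtype used when applying rotary embeddings;
# LATENT_HEIGHT/WIDTH describe the packed (pixel-unshuffled) latent grid at the default
# resolution; M and B are the slope and intercept used to derive the dynamic scheduler
# shift `mu` from the image sequence length (see precompute_timestep_embeds).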
ROPE_PRECISION = torch.bfloat16
SPATIAL_COMPRESSION_RATIO = 8
PIXEL_UNSHUFFLING_DOWNSAMPLING_FACTOR = 2
T5_SEQUENCE_LENGTH = 512
DEFAULT_HEIGHT = 1024
DEFAULT_WIDTH = 1024
PATCH_SIZE = 1
LATENT_HEIGHT = DEFAULT_HEIGHT // (SPATIAL_COMPRESSION_RATIO * PATCH_SIZE) // 2
LATENT_WIDTH = DEFAULT_WIDTH // (SPATIAL_COMPRESSION_RATIO * PATCH_SIZE) // 2
SUPPORTED_BUCKET_LENGTHS = list(range(128, 512 + 1, 64))
SUPPORTED_GUIDANCE_SCALES = [i / 2 for i in range(41)] # 0, 0.5, 1.0, ..., 20.0
MIN_INFERENCE_STEPS = 2
MAX_INFERENCE_STEPS = 50
ATTENTION_OP = None
STREAM = None
BASE_IMAGE_SEQ_LEN = 256
MAX_IMAGE_SEQ_LEN = 4096
BASE_SHIFT = 0.5
MAX_SHIFT = 1.15
M = (MAX_SHIFT - BASE_SHIFT) / (MAX_IMAGE_SEQ_LEN - BASE_IMAGE_SEQ_LEN)
B = BASE_SHIFT - M * BASE_IMAGE_SEQ_LEN
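# Process-wide context-parallel state (ring / Ulysses degrees, device meshes, local ranks),
# populated by setup_distributed() and read by the attention wrappers further below.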
@dataclass
class ContextParallelOptions:
    ring_degree: Optional[int] = None
    ulysses_degree: Optional[int] = None
    mode: Literal["ring", "ulysses", "unified"] = "ring"
    mesh: Optional[dist.DeviceMesh] = None
    convert_to_fp32: bool = True
    attention_op: Optional[Callable[[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]] = None
    _flattened_mesh: Optional[dist.DeviceMesh] = None
    _ring_mesh: Optional[dist.DeviceMesh] = None
    _ulysses_mesh: Optional[dist.DeviceMesh] = None
    _ring_local_rank: Optional[int] = None
    _ulysses_local_rank: Optional[int] = None
cp_options = ContextParallelOptions()
class AdaLayerNormContinuous(torch.nn.Module):
def __init__(
self, embedding_dim: int, conditioning_embedding_dim: int, elementwise_affine=True, eps=1e-5, bias=True
):
super().__init__()
self.linear = torch.nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
self.norm = torch.nn.LayerNorm(embedding_dim, eps, elementwise_affine, bias)
def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
emb = self.linear(emb)
scale, shift = emb.unsqueeze(1).chunk(2, dim=-1)
x = self.norm(x)
x = torch.addcmul(shift, x, 1 + scale)
return x
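# Note: AdaLayerNormZeroSingle and AdaLayerNormZero define a `linear` in __init__ but never
# apply it in forward(); fuse_adaln_linear_() concatenates those per-block linears into a
# single model-level projection whose output is chunked and passed in directly as `emb`.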
class AdaLayerNormZeroSingle(torch.nn.Module):
def __init__(self, embedding_dim: int, bias=True):
super().__init__()
self.linear = torch.nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias)
self.norm = torch.nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
def forward(self, x: torch.Tensor, emb: torch.Tensor):
shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=-1)
x = self.norm(x)
x = torch.addcmul(shift_msa, x, 1 + scale_msa)
return x, gate_msa
class AdaLayerNormZero(torch.nn.Module):
def __init__(self, embedding_dim: int, bias=True):
super().__init__()
self.linear = torch.nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
self.norm = torch.nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
def forward(self, x: torch.Tensor, emb: torch.Tensor):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=-1)
x = self.norm(x)
x = torch.addcmul(shift_msa, x, 1 + scale_msa)
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
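# Attention with optional dual-CUDA-stream execution: when STREAM is set, the text
# (encoder_hidden_states) QKV projections and norms run on the side stream while the image
# QKV projections run on the default stream; the streams synchronize before RoPE is applied
# and ATTENTION_OP is called (which returns both the output and its log-sum-exp).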
class Attention(torch.nn.Module):
def __init__(
self,
query_dim: int,
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
bias: bool = False,
qk_norm: Optional[str] = None,
added_kv_proj_dim: Optional[int] = None,
added_proj_bias: Optional[bool] = True,
out_bias: bool = True,
eps: float = 1e-5,
out_dim: int = None,
context_pre_only=None,
pre_only=False,
elementwise_affine: bool = True,
):
super().__init__()
assert qk_norm == "rms_norm", "Flux uses RMSNorm"
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
self.query_dim = query_dim
self.use_bias = bias
self.dropout = dropout
self.out_dim = out_dim if out_dim is not None else query_dim
self.context_pre_only = context_pre_only
self.pre_only = pre_only
self.heads = out_dim // dim_head if out_dim is not None else heads
self.added_proj_bias = added_proj_bias
self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
if not self.pre_only:
self.to_out = torch.nn.ModuleList([])
self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
if added_kv_proj_dim is not None:
self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
cos, sin = image_rotary_emb if image_rotary_emb is not None else (None, None)
if encoder_hidden_states is not None:
if STREAM is not None:
STREAM.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(STREAM):
if self.fused_projections:
query_c, key_c, value_c = self.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1)
else:
query_c = self.add_q_proj(encoder_hidden_states)
key_c = self.add_k_proj(encoder_hidden_states)
value_c = self.add_v_proj(encoder_hidden_states)
query_c, key_c, value_c = (
x.unflatten(2, (self.heads, -1)) for x in (query_c, key_c, value_c)
)
query_c = self.norm_added_q(query_c)
key_c = self.norm_added_k(key_c)
else:
if self.fused_projections:
query_c, key_c, value_c = self.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1)
else:
query_c = self.add_q_proj(encoder_hidden_states)
key_c = self.add_k_proj(encoder_hidden_states)
value_c = self.add_v_proj(encoder_hidden_states)
query_c, key_c, value_c = (
x.unflatten(2, (self.heads, -1)) for x in (query_c, key_c, value_c)
)
query_c = self.norm_added_q(query_c)
key_c = self.norm_added_k(key_c)
if self.fused_projections:
query, key, value = self.to_qkv(hidden_states).chunk(3, dim=-1)
else:
query = self.to_q(hidden_states)
key = self.to_k(hidden_states)
value = self.to_v(hidden_states)
query, key, value = (x.unflatten(2, (self.heads, -1)) for x in (query, key, value))
query = self.norm_q(query)
key = self.norm_k(key)
if STREAM is not None:
torch.cuda.current_stream().wait_stream(STREAM)
if encoder_hidden_states is not None:
query = torch.cat([query_c, query], dim=1)
key = torch.cat([key_c, key], dim=1)
value = torch.cat([value_c, value], dim=1)
if image_rotary_emb is not None:
x_real, x_imag = query.unflatten(-1, (-1, 2)).unbind(-1)
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
query = (query.to(ROPE_PRECISION) * cos + x_rotated.to(ROPE_PRECISION) * sin).type_as(query)
x_real, x_imag = key.unflatten(-1, (-1, 2)).unbind(-1)
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
key = (key.to(ROPE_PRECISION) * cos + x_rotated.to(ROPE_PRECISION) * sin).type_as(key)
hidden_states, lse = ATTENTION_OP(query, key, value)
hidden_states = hidden_states.flatten(2)
if encoder_hidden_states is not None:
encoder_hidden_states, hidden_states = torch.split_with_sizes(
hidden_states,
[encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]],
dim=1,
)
hidden_states = self.to_out[0](hidden_states)
encoder_hidden_states = self.to_add_out(encoder_hidden_states)
return hidden_states, encoder_hidden_states
return hidden_states
@torch.no_grad()
def fuse_projections(self):
device = self.to_q.weight.data.device
dtype = self.to_q.weight.data.dtype
concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
in_features = concatenated_weights.shape[1]
out_features = concatenated_weights.shape[0]
self.to_qkv = torch.nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
self.to_qkv.weight.copy_(concatenated_weights)
if self.use_bias:
concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
self.to_qkv.bias.copy_(concatenated_bias)
if (
getattr(self, "add_q_proj", None) is not None
and getattr(self, "add_k_proj", None) is not None
and getattr(self, "add_v_proj", None) is not None
):
concatenated_weights = torch.cat(
[self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data]
)
in_features = concatenated_weights.shape[1]
out_features = concatenated_weights.shape[0]
self.to_added_qkv = torch.nn.Linear(
in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype
)
self.to_added_qkv.weight.copy_(concatenated_weights)
if self.added_proj_bias:
concatenated_bias = torch.cat(
[self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data]
)
self.to_added_qkv.bias.copy_(concatenated_bias)
for layer in ("to_q", "to_k", "to_v", "to_added_q", "to_added_k", "to_added_v"):
if hasattr(self, layer):
module = getattr(self, layer)
module.to("meta")
delattr(self, layer)
self.fused_projections = True
class FluxPosEmbed(torch.nn.Module):
def __init__(self, theta: int, axes_dim: List[int]):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim
def forward(self, ids: torch.Tensor) -> torch.Tensor:
n_axes = ids.shape[-1]
cos_out = []
sin_out = []
for i in range(n_axes):
cos, sin = get_1d_rotary_pos_embed(
self.axes_dim[i],
ids[:, i],
theta=self.theta,
repeat_interleave_real=True,
use_real=True,
freqs_dtype=ROPE_PRECISION,
)
cos_out.append(cos)
sin_out.append(sin)
freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device, dtype=ROPE_PRECISION)[None, :, None]
freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device, dtype=ROPE_PRECISION)[None, :, None]
return freqs_cos, freqs_sin
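# Single-stream (DiT-style) block: when STREAM is set, the MLP branch (proj_mlp + GELU) runs
# on the side stream concurrently with attention on the default stream; the two results are
# concatenated and projected by proj_out before the gated residual add.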
class FluxSingleTransformerBlock(torch.nn.Module):
def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
super().__init__()
self.mlp_hidden_dim = int(dim * mlp_ratio)
self.norm = AdaLayerNormZeroSingle(dim)
self.proj_mlp = torch.nn.Linear(dim, self.mlp_hidden_dim)
self.act_mlp = torch.nn.GELU(approximate="tanh")
self.attn = Attention(
query_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
bias=True,
qk_norm="rms_norm",
eps=1e-6,
pre_only=True,
)
self.proj_out = torch.nn.Linear(dim + self.mlp_hidden_dim, dim)
def forward(
self,
hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
if STREAM is not None:
STREAM.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(STREAM):
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
else:
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
attn_output = self.attn(
hidden_states=norm_hidden_states,
image_rotary_emb=image_rotary_emb,
)
if STREAM is not None:
torch.cuda.current_stream().wait_stream(STREAM)
attn_mlp_hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
proj_out = self.proj_out(attn_mlp_hidden_states)
hidden_states = torch.addcmul(hidden_states, gate, proj_out)
return hidden_states
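# Dual-stream (MMDiT) block: with STREAM set, the context (text) branch (AdaLN before
# attention, then the gated residual, norm2_context and ff_context afterwards) runs on the
# side stream while the image branch runs on the default stream, with a sync at the end.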
class FluxTransformerBlock(torch.nn.Module):
def __init__(
self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
):
super().__init__()
self.norm1 = AdaLayerNormZero(dim)
self.norm1_context = AdaLayerNormZero(dim)
self.attn = Attention(
query_dim=dim,
added_kv_proj_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
context_pre_only=False,
bias=True,
qk_norm=qk_norm,
eps=eps,
)
self.norm2 = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
self.norm2_context = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
temb, temb_context = temb.chunk(2, dim=-1)
if STREAM is not None:
STREAM.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(STREAM):
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
encoder_hidden_states, emb=temb_context
)
else:
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
encoder_hidden_states, emb=temb_context
)
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
if STREAM is not None:
torch.cuda.current_stream().wait_stream(STREAM)
attn_output, context_attn_output = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
)
if STREAM is not None:
STREAM.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(STREAM):
encoder_hidden_states = torch.addcmul(encoder_hidden_states, c_gate_msa, context_attn_output)
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
norm_encoder_hidden_states = torch.addcmul(c_shift_mlp, norm_encoder_hidden_states, 1 + c_scale_mlp)
context_ff_output = self.ff_context(norm_encoder_hidden_states)
encoder_hidden_states = torch.addcmul(encoder_hidden_states, c_gate_mlp, context_ff_output)
hidden_states = torch.addcmul(hidden_states, gate_msa, attn_output)
norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = torch.addcmul(shift_mlp, norm_hidden_states, 1 + scale_mlp)
ff_output = self.ff(norm_hidden_states)
hidden_states = torch.addcmul(hidden_states, gate_mlp, ff_output)
torch.cuda.current_stream().wait_stream(STREAM)
else:
hidden_states = torch.addcmul(hidden_states, gate_msa, attn_output)
encoder_hidden_states = torch.addcmul(encoder_hidden_states, c_gate_msa, context_attn_output)
norm_hidden_states = self.norm2(hidden_states)
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
norm_hidden_states = torch.addcmul(shift_mlp, norm_hidden_states, 1 + scale_mlp)
norm_encoder_hidden_states = torch.addcmul(c_shift_mlp, norm_encoder_hidden_states, 1 + c_scale_mlp)
ff_output = self.ff(norm_hidden_states)
context_ff_output = self.ff_context(norm_encoder_hidden_states)
hidden_states = torch.addcmul(hidden_states, gate_mlp, ff_output)
encoder_hidden_states = torch.addcmul(encoder_hidden_states, c_gate_mlp, context_ff_output)
return encoder_hidden_states, hidden_states
class FluxTransformer2DModel(
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin
):
@register_to_config
def __init__(
self,
patch_size: int = 1,
in_channels: int = 64,
out_channels: Optional[int] = None,
num_layers: int = 19,
num_single_layers: int = 38,
attention_head_dim: int = 128,
num_attention_heads: int = 24,
joint_attention_dim: int = 4096,
pooled_projection_dim: int = 768,
guidance_embeds: bool = False,
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
):
super().__init__()
self.out_channels = out_channels or in_channels
self.inner_dim = num_attention_heads * attention_head_dim
self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
text_time_guidance_cls = (
CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
)
self.time_text_embed = text_time_guidance_cls(
embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
)
self.context_embedder = torch.nn.Linear(joint_attention_dim, self.inner_dim)
self.x_embedder = torch.nn.Linear(in_channels, self.inner_dim)
self.transformer_blocks = torch.nn.ModuleList(
[
FluxTransformerBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
)
for _ in range(num_layers)
]
)
self.single_transformer_blocks = torch.nn.ModuleList(
[
FluxSingleTransformerBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
)
for _ in range(num_single_layers)
]
)
self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out = torch.nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
conditioning: torch.Tensor,
image_rotary_emb: Tuple[torch.Tensor, torch.Tensor],
dt: torch.Tensor,
) -> torch.Tensor:
x_t = hidden_states
hidden_states = self.x_embedder(hidden_states)
if STREAM is not None:
STREAM.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(STREAM):
adaln_linear_dual_stream_states = self.adaln_linear(conditioning).unsqueeze(1).chunk(self.config.num_layers, dim=-1)
else:
adaln_linear_dual_stream_states = self.adaln_linear(conditioning).unsqueeze(1).chunk(self.config.num_layers, dim=-1)
adaln_linear_single_stream_states = self.adaln_linear_single(conditioning).unsqueeze(1).chunk(self.config.num_single_layers, dim=-1)
if STREAM is not None:
torch.cuda.current_stream().wait_stream(STREAM)
for i, block in enumerate(self.transformer_blocks):
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=adaln_linear_dual_stream_states[i],
image_rotary_emb=image_rotary_emb,
)
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
for i, block in enumerate(self.single_transformer_blocks):
hidden_states = block(
hidden_states=hidden_states,
temb=adaln_linear_single_stream_states[i],
image_rotary_emb=image_rotary_emb,
)
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
hidden_states = self.norm_out(hidden_states, conditioning)
velocity = self.proj_out(hidden_states)
x = x_t + dt * velocity
return x
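# Weight-fusion helpers: fuse_qkv_ concatenates the Q/K/V (and added Q/K/V) projections of
# every Attention module into single GEMMs; fuse_adaln_linear_ concatenates the per-block
# AdaLN linears into two model-level projections (adaln_linear / adaln_linear_single) so all
# modulation parameters for a step come from two matmuls.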
@torch.no_grad()
def fuse_qkv_(model: FluxTransformer2DModel) -> None:
for submodule in model.modules():
if not isinstance(submodule, Attention):
continue
submodule.fuse_projections()
@torch.no_grad()
def fuse_adaln_linear_(model: FluxTransformer2DModel) -> None:
adaln_linear_weights = []
adaln_linear_biases = []
for block in model.transformer_blocks:
adaln_linear_weights.append(block.norm1.linear.weight.data.clone())
adaln_linear_weights.append(block.norm1_context.linear.weight.data.clone())
adaln_linear_biases.append(block.norm1.linear.bias.data.clone())
adaln_linear_biases.append(block.norm1_context.linear.bias.data.clone())
block.norm1.linear.to("meta")
block.norm1_context.linear.to("meta")
del block.norm1.linear, block.norm1_context.linear
adaln_linear_weights = torch.cat(adaln_linear_weights, dim=0)
adaln_linear_biases = torch.cat(adaln_linear_biases, dim=0)
in_features = adaln_linear_weights.shape[1]
out_features = adaln_linear_weights.shape[0]
model.adaln_linear = torch.nn.Linear(
in_features, out_features, bias=True, device=adaln_linear_weights.device, dtype=adaln_linear_weights.dtype
)
model.adaln_linear.weight.copy_(adaln_linear_weights)
model.adaln_linear.bias.copy_(adaln_linear_biases)
adaln_linear_weights = []
adaln_linear_biases = []
for block in model.single_transformer_blocks:
adaln_linear_weights.append(block.norm.linear.weight.data.clone())
adaln_linear_biases.append(block.norm.linear.bias.data.clone())
block.norm.linear.to("meta")
del block.norm.linear
adaln_linear_weights = torch.cat(adaln_linear_weights, dim=0)
adaln_linear_biases = torch.cat(adaln_linear_biases, dim=0)
in_features = adaln_linear_weights.shape[1]
out_features = adaln_linear_weights.shape[0]
model.adaln_linear_single = torch.nn.Linear(
in_features, out_features, bias=True, device=adaln_linear_weights.device, dtype=adaln_linear_weights.dtype
)
model.adaln_linear_single.weight.copy_(adaln_linear_weights)
model.adaln_linear_single.bias.copy_(adaln_linear_biases)
def prepare_clip_embeddings(
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
prompt: str,
device: torch.device,
dtype: torch.dtype,
max_length: int = 77,
) -> torch.Tensor:
prompt = [prompt]
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=max_length,
truncation=True,
return_overflowing_tokens=False,
return_length=False,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
prompt_embeds = prompt_embeds.pooler_output.to(dtype)
return prompt_embeds
def prepare_t5_embeddings(
text_encoder: T5EncoderModel,
tokenizer: T5TokenizerFast,
prompt: str,
device: torch.device,
dtype: torch.dtype,
max_length: int = 512,
enable_prompt_length_bucketing: bool = False,
) -> torch.Tensor:
prompt = [prompt]
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=max_length,
truncation=True,
return_length=False,
return_overflowing_tokens=False,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
if enable_prompt_length_bucketing:
attention_mask = text_inputs.attention_mask
num_text_tokens = attention_mask.sum(dim=1).max().item()
max_length = min(
SUPPORTED_BUCKET_LENGTHS, key=lambda x: abs(x - num_text_tokens) if x >= num_text_tokens else float("inf")
)
text_input_ids = text_input_ids[:, :max_length]
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)[0]
prompt_embeds = prompt_embeds.to(dtype)
return prompt_embeds, max_length
@functools.lru_cache(maxsize=8)
def prepare_latent_image_ids(height: int, width: int, device: torch.device, dtype: torch.dtype):
latent_image_ids = torch.zeros(height, width, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
latent_image_ids = latent_image_ids.reshape(
latent_image_id_height * latent_image_id_width, latent_image_id_channels
).contiguous()
return latent_image_ids.to(device=device, dtype=dtype)
def precompute_guidance_embeds(transformer: FluxTransformer2DModel, device: torch.device, dtype: torch.dtype):
embeds = {}
for guidance_scale in SUPPORTED_GUIDANCE_SCALES:
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
guidance = transformer.time_text_embed.guidance_embedder(
transformer.time_text_embed.time_proj(guidance * 1000.0).to(dtype)
)
embeds[f"{guidance_scale:.1f}"] = guidance
return embeds
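# Pre-compute the flow-matching schedule for every supported step count: the dynamic shift is
# interpolated from the image sequence length (mu = B + image_seq_len * M) and the sigmas are
# shifted via exp(mu) / (exp(mu) + (1 / sigma - 1)) before being embedded.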
def precompute_timestep_embeds(transformer: FluxTransformer2DModel, device: torch.device, dtype: torch.dtype):
embeds = {}
image_seq_len = LATENT_HEIGHT * LATENT_WIDTH
for num_inference_steps in range(MIN_INFERENCE_STEPS, MAX_INFERENCE_STEPS + 1):
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
mu = B + image_seq_len * M
sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1) ** 1.0)
sigmas = torch.from_numpy(sigmas).to(device, dtype=torch.float32)
sigmas = torch.cat([sigmas, sigmas.new_zeros(1)])
timesteps = (sigmas * 1000.0).to(dtype)
temb = transformer.time_text_embed.time_proj(timesteps)
temb = transformer.time_text_embed.timestep_embedder(temb.to(dtype))
embeds[num_inference_steps] = (sigmas, temb)
return embeds
def precompute_embeds(transformer: FluxTransformer2DModel, device: torch.device, dtype: torch.dtype, save_dir: str):
save_dir = pathlib.Path(save_dir)
save_dir.mkdir(parents=True, exist_ok=True)
guidance_path = save_dir / "guidance_embeds.pt"
timestep_path = save_dir / "timestep_embeds.pt"
if guidance_path.exists():
guidance_embeds = torch.load(guidance_path, map_location=device, weights_only=True)
print(f'Loaded precomputed guidance embeddings from "{guidance_path.as_posix()}"')
else:
guidance_embeds = precompute_guidance_embeds(transformer, device, dtype)
if cp_options.mesh is None or cp_options.mesh._flatten().get_local_rank() == 0:
torch.save(guidance_embeds, guidance_path.as_posix())
print(f'Precomputed guidance embeddings saved to "{save_dir.as_posix()}"')
if timestep_path.exists():
timestep_embeds = torch.load(timestep_path, map_location=device, weights_only=True)
print(f'Loaded precomputed timestep embeddings from "{timestep_path.as_posix()}"')
else:
timestep_embeds = precompute_timestep_embeds(transformer, device, dtype)
if cp_options.mesh is None or cp_options.mesh._flatten().get_local_rank() == 0:
torch.save(timestep_embeds, timestep_path.as_posix())
print(f'Precomputed timestep embeddings saved to "{save_dir.as_posix()}"')
return guidance_embeds, timestep_embeds
@torch.compile
def pointwise_add3_silu(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
return torch.nn.functional.silu(x + y + z)
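# CUDA graph capture: warm up on a side stream, then record one full transformer forward into
# a graph whose static input tensors are re-filled and replayed on every denoising step.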
def capture_cudagraph(
model: FluxTransformer2DModel,
latents: torch.Tensor,
encoder_hidden_states: torch.Tensor,
conditioning: torch.Tensor,
image_rotary_emb: Tuple[torch.Tensor, torch.Tensor],
dt: torch.Tensor,
):
print("Warming up CUDAGraph capture")
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
for _ in range(2):
_ = model(
hidden_states=latents,
encoder_hidden_states=encoder_hidden_states,
conditioning=conditioning,
image_rotary_emb=image_rotary_emb,
dt=dt,
)
torch.cuda.current_stream().wait_stream(s)
print("Capturing CUDAGraph")
static_latents = latents.clone()
static_conditioning = conditioning.clone()
static_dt = dt.clone()
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
static_x = model(
hidden_states=static_latents,
encoder_hidden_states=encoder_hidden_states,
conditioning=static_conditioning,
image_rotary_emb=image_rotary_emb,
dt=static_dt,
)
return graph, static_latents, static_conditioning, static_dt, static_x
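# End-to-end pipeline: load the Flux components, fuse weights, optionally compile, pre-compute
# text/guidance/timestep conditioning, run the denoising loop (optionally under CUDA graphs
# and/or context parallelism), then unpack the latents and decode with the VAE.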
@torch.inference_mode()
def main(
model_id: str,
prompt: str,
height: int,
width: int,
num_inference_steps: int,
guidance_scale: float,
compile_mode: str,
output_file: str,
cache_dir: Optional[str],
enable_cudagraph: bool,
enable_prompt_length_bucketing: bool,
enable_profiling: bool,
working_dir: str,
seed: int,
):
device = "cuda"
dtype = torch.bfloat16
# Load the model components
transformer = FluxTransformer2DModel.from_pretrained(
model_id, subfolder="transformer", cache_dir=cache_dir, torch_dtype=dtype
)
text_encoder = CLIPTextModel.from_pretrained(
model_id, subfolder="text_encoder", cache_dir=cache_dir, torch_dtype=dtype
)
text_encoder_2 = T5EncoderModel.from_pretrained(
model_id, subfolder="text_encoder_2", cache_dir=cache_dir, torch_dtype=dtype
)
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer", cache_dir=cache_dir)
tokenizer_2 = T5TokenizerFast.from_pretrained(model_id, subfolder="tokenizer_2", cache_dir=cache_dir)
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", cache_dir=cache_dir, torch_dtype=dtype)
image_processor = VaeImageProcessor(
vae_scale_factor=SPATIAL_COMPRESSION_RATIO * PIXEL_UNSHUFFLING_DOWNSAMPLING_FACTOR
)
fuse_qkv_(transformer)
fuse_adaln_linear_(transformer)
    for module in (transformer, text_encoder, text_encoder_2, vae):
        module.to(device)
vae.to(memory_format=torch.channels_last)
if compile_mode != "none":
transformer = torch.compile(transformer, mode=compile_mode, fullgraph=True, dynamic=True)
# text_encoder = torch.compile(text_encoder, mode="default", fullgraph=True, dynamic=True)
# text_encoder_2 = torch.compile(text_encoder_2, mode="default", fullgraph=True, dynamic=True)
# We don't compile the VAE due to the implementation calling into non-traceable code paths
# vae.decode = torch.compile(vae.decode, mode="default", fullgraph=True, dynamic=True)
# Latent, text, guidance and timestep conditioning preparation
batch_size = 1
patch_size = transformer.config.patch_size
latent_height = height // (SPATIAL_COMPRESSION_RATIO * patch_size) // PIXEL_UNSHUFFLING_DOWNSAMPLING_FACTOR
latent_width = width // (SPATIAL_COMPRESSION_RATIO * patch_size) // PIXEL_UNSHUFFLING_DOWNSAMPLING_FACTOR
generator = torch.Generator(device=device).manual_seed(seed)
guidance_embeds, timestep_embeds = precompute_embeds(transformer, device, dtype, working_dir)
latents = torch.randn(
(batch_size, latent_height * latent_width, transformer.config.in_channels),
dtype=dtype,
device=device,
generator=generator,
)
pooled_projections = prepare_clip_embeddings(text_encoder, tokenizer, prompt, device, dtype)
encoder_hidden_states, num_text_tokens = prepare_t5_embeddings(
text_encoder_2, tokenizer_2, prompt, device, dtype, T5_SEQUENCE_LENGTH, enable_prompt_length_bucketing
)
# <precompute>
guidance_conditioning = guidance_embeds[f"{guidance_scale:.1f}"]
sigmas, timestep_conditioning = timestep_embeds[num_inference_steps]
pooled_projections = transformer.time_text_embed.text_embedder(pooled_projections)
encoder_hidden_states = transformer.context_embedder(encoder_hidden_states)
# </precompute>
img_ids = prepare_latent_image_ids(latent_height, latent_width, device=device, dtype=dtype)
txt_ids = torch.zeros(num_text_tokens, 3).to(device=device, dtype=dtype)
if cp_options.mesh is not None:
# Note: clone seems to be a must here otherwise there is a recompilation related to storage offsets (which
# tells you to use torch._dynamo.decorators.mark_unbacked) /shrug
img_ids = EquipartitionSharder.shard(img_ids, dim=0, mesh=cp_options._flattened_mesh).clone()
txt_ids = EquipartitionSharder.shard(txt_ids, dim=0, mesh=cp_options._flattened_mesh).clone()
latents = EquipartitionSharder.shard(latents, dim=1, mesh=cp_options._flattened_mesh).clone()
encoder_hidden_states = EquipartitionSharder.shard(encoder_hidden_states, dim=1, mesh=cp_options._flattened_mesh).clone()
ids = torch.cat([txt_ids, img_ids], dim=0).float()
image_rotary_emb = transformer.pos_embed(ids)
dt = sigmas[1:] - sigmas[:-1]
print("Warming up the model")
for _ in range(2):
conditioning = pointwise_add3_silu(timestep_conditioning[0, :], guidance_conditioning, pooled_projections)
_ = transformer(
hidden_states=latents,
encoder_hidden_states=encoder_hidden_states,
conditioning=conditioning,
image_rotary_emb=image_rotary_emb,
dt=dt[0],
)
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
context = (
profile(activities=[ProfilerActivity.CUDA, ProfilerActivity.CPU], with_flops=True)
if enable_profiling
else contextlib.nullcontext()
)
if not enable_cudagraph:
with context as ctx:
start_event.record()
for i in range(num_inference_steps):
conditioning = pointwise_add3_silu(timestep_conditioning[i, :], guidance_conditioning, pooled_projections)
latents = transformer(
hidden_states=latents,
encoder_hidden_states=encoder_hidden_states,
conditioning=conditioning,
image_rotary_emb=image_rotary_emb,
dt=dt[i],
)
end_event.record()
torch.cuda.synchronize()
else:
conditioning = pointwise_add3_silu(timestep_conditioning[0, :], guidance_conditioning, pooled_projections)
graph, static_latents, static_conditioning, static_dt, static_x = capture_cudagraph(
transformer,
latents,
encoder_hidden_states,
conditioning,
image_rotary_emb,
dt[0],
)
with context as ctx:
start_event.record()
static_x.copy_(latents)
for i in range(num_inference_steps):
conditioning = pointwise_add3_silu(timestep_conditioning[i, :], guidance_conditioning, pooled_projections)
torch._foreach_copy_(
(static_latents, static_conditioning, static_dt),
(static_x, conditioning, dt[i]),
non_blocking=True,
)
graph.replay()
end_event.record()
torch.cuda.synchronize()
latents = static_x
total_time = start_event.elapsed_time(end_event) / 1000.0
if cp_options.mesh is not None:
latents = EquipartitionSharder.unshard(latents, dim=1, mesh=cp_options._flattened_mesh)
if cp_options.mesh is None or cp_options.mesh._flatten().get_local_rank() == 0:
print(f"time: {total_time:.2f}s")
if enable_profiling:
print(ctx.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))
ctx.export_chrome_trace("dump_benchmark_flux.json")
latents = latents.reshape(batch_size, latent_height, latent_width, -1, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5)
latents = latents.flatten(4, 5).flatten(2, 3)
latents = latents.to(memory_format=torch.channels_last)
latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
image = vae.decode(latents, return_dict=False)[0]
image = image_processor.postprocess(image, output_type="pil")[0]
image.save(output_file)
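# Context-parallel sharding helpers: shard() takes this rank's equal chunk along `dim`,
# unshard() all-gathers the chunks back along the same dimension over the flattened mesh.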
class EquipartitionSharder:
@classmethod
def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
assert tensor.size()[dim] % mesh.size() == 0
# The following is not fullgraph compatible with Dynamo (fails in DeviceMesh.get_rank)
# return tensor.chunk(mesh.size(), dim=dim)[mesh.get_rank()]
return tensor.chunk(mesh.size(), dim=dim)[torch.distributed.get_rank(mesh.get_group())]
@classmethod
def unshard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
tensor = tensor.contiguous()
tensor = funcol.all_gather_tensor(tensor, dim, group=mesh.get_group())
return tensor
# Reference:
# - https://github.com/pytorch/pytorch/blob/f58a680d09e13658a52c6ba05c63c15759846bcc/torch/distributed/_functional_collectives.py#L827
# - https://github.com/pytorch/pytorch/blob/f58a680d09e13658a52c6ba05c63c15759846bcc/torch/distributed/_functional_collectives.py#L246
# For fullgraph=True tracing compatibility (since FakeTensor does not have a `wait` method):
def _wait_tensor(tensor):
if isinstance(tensor, funcol.AsyncCollectiveTensor):
tensor = tensor.wait()
return tensor
def _all_to_all_single(x: torch.Tensor, group) -> torch.Tensor:
shape = x.shape
# HACK: We need to flatten because despite making tensors contiguous, torch single-file-ization
# to benchmark triton codegen fails somewhere:
# buf25 = torch.ops._c10d_functional.all_to_all_single.default(buf24, [1, 1], [1, 1], '3')
# ValueError: Tensors must be contiguous
x = x.flatten()
x = funcol.all_to_all_single(x, None, None, group)
x = x.reshape(shape)
x = _wait_tensor(x)
return x
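# Ring attention: all-gather the local K/V shards once, then walk the ring, merging partial
# results with the numerically stable log-sum-exp combine used below:
#   out = prev_out - sigmoid(lse - prev_lse) * (prev_out - out)
#   lse = prev_lse - logsigmoid(prev_lse - lse)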
def _templated_ring_attention(query, key, value):
ring_mesh = cp_options.mesh["ring"]
rank = cp_options._ring_local_rank
world_size = cp_options.ring_degree
if world_size == 1:
return cp_options.attention_op(query, key, value)
next_rank = (rank + 1) % world_size
prev_out = prev_lse = None
kv_buffer = torch.cat([key.flatten(), value.flatten()]).contiguous()
kv_buffer = funcol.all_gather_tensor(kv_buffer, gather_dim=0, group=ring_mesh.get_group())
kv_buffer = kv_buffer.chunk(world_size)
for i in range(world_size):
if i > 0:
kv = kv_buffer[next_rank]
key = kv[:key.numel()].reshape_as(key)
value = kv[key.numel():].reshape_as(value)
next_rank = (next_rank + 1) % world_size
out, lse = cp_options.attention_op(query, key, value)
if cp_options.convert_to_fp32:
out = out.to(torch.float32)
lse = lse.to(torch.float32)
lse = lse.unsqueeze(-1)
if prev_out is not None:
out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out)
lse = prev_lse - torch.nn.functional.logsigmoid(prev_lse - lse)
prev_out = out
prev_lse = lse
out = out.to(query.dtype)
lse = lse.squeeze(-1)
return out, lse
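# Ulysses attention: all-to-all so each rank holds all tokens for a subset of heads, run full
# attention locally, then all-to-all back to the sequence-sharded layout (optionally doing the
# same for the LSE so it could feed a ring merge).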
def _templated_ulysses_attention(query, key, value, *, return_lse: bool = False):
ulysses_mesh = cp_options.mesh["ulysses"]
world_size = cp_options.ulysses_degree
group = ulysses_mesh.get_group()
if world_size == 1:
return cp_options.attention_op(query, key, value)
B, S_LOCAL, H, D = query.shape
H_LOCAL = H // world_size
query, key, value = (
x.reshape(B, S_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).clone()
for x in (query, key, value)
)
query, key, value = (
_all_to_all_single(x, group)
for x in (query, key, value)
)
query, key, value = (
x.flatten(0, 1).permute(1, 0, 2, 3).contiguous()
for x in (query, key, value)
)
out, lse = cp_options.attention_op(query, key, value)
out = out.reshape(B, world_size, S_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous()
out = _all_to_all_single(out, group)
out = out.flatten(0, 1).permute(1, 2, 0, 3).contiguous()
if return_lse:
lse = lse.reshape(B, world_size, S_LOCAL, H_LOCAL).permute(1, 3, 0, 2).contiguous()
lse = _all_to_all_single(lse, group)
lse = lse.flatten(0, 1).permute(1, 2, 0).contiguous()
else:
lse = None
return out, lse
# TODO: currently produces incorrect results (for example, with CP=4, ring=2, ulysses=2, half output is our expected image,
# and other half is some completely different)
# def _templated_unified_attention(query, key, value):
# ulysses_mesh = cp_options.mesh["ulysses"]
# ulysses_size = ulysses_mesh.size()
# ulysses_group = ulysses_mesh.get_group()
# B, S_LOCAL, H, D = query.shape
# H_LOCAL = H // ulysses_size
# query, key, value = (
# x.reshape(B, S_LOCAL, ulysses_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
# for x in (query, key, value)
# )
# query, key, value = (
# wait_tensor(funcol.all_to_all_single(x, None, None, group=ulysses_group))
# for x in (query, key, value)
# )
# query, key, value = (
# x.flatten(0, 1).permute(1, 0, 2, 3).contiguous()
# for x in (query, key, value)
# )
# out, lse = _templated_ring_attention(query, key, value)
# out = out.reshape(B, ulysses_size, S_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous()
# lse = lse.reshape(B, ulysses_size, S_LOCAL, H_LOCAL).permute(1, 3, 0, 2).contiguous()
# out = wait_tensor(funcol.all_to_all_single(out, None, None, group=ulysses_group))
# lse = wait_tensor(funcol.all_to_all_single(lse, None, None, group=ulysses_group))
# out = out.flatten(0, 1).permute(1, 2, 0, 3).contiguous()
# lse = lse.flatten(0, 1).permute(1, 2, 0).contiguous()
# return out, lse
# Wrap the FA3 kernels in custom ops with registered fakes so fullgraph=True tracing works
@torch.library.custom_op("flash_attn_3::_flash_attn_forward_original", mutates_args=(), device_types="cuda")
def _wrapped_flash_attn_3_original(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
out, lse = flash_attn_3_func(query, key, value)
lse = lse.permute(0, 2, 1)
return out, lse
@torch.library.register_fake("flash_attn_3::_flash_attn_forward_original")
def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size, seq_len, num_heads, head_dim = query.shape
lse_shape = (batch_size, seq_len, num_heads)
return torch.empty_like(query), query.new_empty(lse_shape)
@torch.library.custom_op("flash_attn_3::_flash_attn_forward_hf", mutates_args=(), device_types="cuda")
def _wrapped_flash_attn_3_hf(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
out, lse = flash_attn_3_hf.flash_attn_func(query, key, value, causal=False)
return out
@torch.library.register_fake("flash_attn_3::_flash_attn_forward_hf")
def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
return torch.empty_like(query)
def _attention_torch_cudnn(query, key, value):
query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = (
torch.ops.aten._scaled_dot_product_cudnn_attention(
query=query,
key=key,
value=value,
attn_bias=None,
compute_log_sumexp=True,
)
)
out = out.transpose(1, 2).contiguous()
lse = lse.transpose(1, 2).contiguous()
return out, lse
def _attention_flash_attn_2(query, key, value):
out, lse, _ = flash_attn_func(query, key, value, return_attn_probs=True)
lse = lse.permute(0, 2, 1)
return out, lse
def _attention_flash_attn_3_original(query, key, value):
out = _wrapped_flash_attn_3_original(query, key, value)
return out
def _attention_flash_attn_3_hf(query, key, value):
    # The HF kernels FA3 wrapper only returns the attention output, so return None for the LSE
    # to satisfy the (out, lse) unpacking in Attention.forward (not usable with ring attention).
    out = _wrapped_flash_attn_3_hf(query, key, value)
    return out, None
def _download_hf_flash_attn_3():
global flash_attn_3_hf
flash_attn_3_hf = get_kernel("kernels-community/flash-attn3")
def get_args():
DEFAULT_MODEL_ID = "black-forest-labs/FLUX.1-dev"
DEFAULT_PROMPT = "The King of Hearts card transforms into a 3D hologram that appears to be made of cosmic energy. As the King emerges, stars and galaxies swirl around him, creating a sense of traveling through the universe. The King's attire is adorned with celestial patterns, and his crown is a glowing star cluster. The hologram floats in front of you, with the background shifting through different cosmic scenes, from nebulae to black holes. Atmosphere: Perfect for space-themed events, science fiction conventions, or futuristic tech expos."
parser = argparse.ArgumentParser()
parser.add_argument("--model_id", type=str, default=DEFAULT_MODEL_ID)
parser.add_argument("--prompt", type=str, default=DEFAULT_PROMPT)
parser.add_argument("--height", type=int, default=1024)
parser.add_argument("--width", type=int, default=1024)
parser.add_argument("--num_inference_steps", type=int, default=28)
parser.add_argument("--guidance_scale", type=float, default=4.0)
parser.add_argument(
"--compile_mode",
type=str,
default="none",
choices=["none", "default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"],
)
parser.add_argument("--output_file", type=str, default="output.png")
parser.add_argument("--cache_dir", type=str, default=None)
parser.add_argument("--enable_fp32_rope", action="store_true")
parser.add_argument("--enable_cudagraph", action="store_true")
parser.add_argument("--enable_prompt_length_bucketing", action="store_true")
parser.add_argument("--attention_provider", type=str, default="cudnn", choices=["cudnn", "fa2", "fa3", "fa3_original"])
parser.add_argument("--ring_degree", type=int, default=1)
parser.add_argument("--ulysses_degree", type=int, default=1)
parser.add_argument("--enable_profiling", action="store_true")
parser.add_argument("--enable_cuda_stream", action="store_true")
parser.add_argument("--disable_tf32", action="store_true")
parser.add_argument("--disable_flags", action="store_true")
parser.add_argument("--working_dir", type=str, default="/tmp/flux_precomputation")
parser.add_argument("--seed", type=int, default=31337)
args = parser.parse_args()
return args
def setup_config(args):
torch.manual_seed(args.seed)
if args.enable_fp32_rope:
global ROPE_PRECISION
ROPE_PRECISION = torch.float32
if args.enable_cudagraph and args.compile_mode not in ["none", "default", "max-autotune-no-cudagraphs"]:
raise ValueError(
"Only compiled modes 'none', 'default', and 'max-autotune-no-cudagraphs' are supported with CUDAGraphs."
)
global ATTENTION_OP
if args.attention_provider == "cudnn":
ATTENTION_OP = _attention_torch_cudnn
elif args.attention_provider == "fa2":
ATTENTION_OP = _attention_flash_attn_2
elif args.attention_provider == "fa3":
_download_hf_flash_attn_3()
ATTENTION_OP = _attention_flash_attn_3_hf
elif args.attention_provider == "fa3_original":
ATTENTION_OP = _attention_flash_attn_3_original
else:
        raise ValueError(f"Unsupported attention provider: {args.attention_provider}")
if args.enable_profiling and args.enable_cudagraph:
torch.profiler._utils._init_for_cuda_graphs()
if not args.disable_tf32:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
if not args.disable_flags:
apply_flags()
global MIN_INFERENCE_STEPS, MAX_INFERENCE_STEPS
if args.num_inference_steps < MIN_INFERENCE_STEPS or args.num_inference_steps > MAX_INFERENCE_STEPS:
raise ValueError(f"`num_inference_steps` must be equal or between {MIN_INFERENCE_STEPS} and {MAX_INFERENCE_STEPS}.")
def setup_distributed(ring_degree: int, ulysses_degree: int, compile_mode: str, enable_cuda_stream: bool):
global ATTENTION_OP
from datetime import timedelta
dist.init_process_group("nccl", timeout=timedelta(seconds=60))
torch.cuda.set_device(torch.device("cuda", dist.get_rank()))
global STREAM
    if enable_cuda_stream:
STREAM = torch.cuda.Stream()
if ring_degree * ulysses_degree != dist.get_world_size():
raise ValueError(f"`ring_degree * ulysses_degree` must equal the world size {dist.get_world_size()}.")
mesh_names = ["ring", "ulysses"]
mesh_dims = [ring_degree, ulysses_degree]
mesh = dist.device_mesh.init_device_mesh("cuda", mesh_dims, mesh_dim_names=mesh_names)
cp_options.ring_degree = ring_degree
cp_options.ulysses_degree = ulysses_degree
cp_options.mesh = mesh
cp_options.convert_to_fp32 = True
cp_options.attention_op = ATTENTION_OP
cp_options._flattened_mesh = mesh._flatten()
cp_options._ring_mesh = mesh["ring"]
cp_options._ulysses_mesh = mesh["ulysses"]
cp_options._ring_local_rank = cp_options._ring_mesh.get_local_rank()
cp_options._ulysses_local_rank = cp_options._ulysses_mesh.get_local_rank()
if ring_degree > 1 and ulysses_degree > 1:
raise ValueError("The current implementation is incorrect for unified attention and needs to be fixed.")
# cp_options.mode = "unified"
# ATTENTION_OP = _templated_unified_attention
elif ulysses_degree > 1:
cp_options.mode = "ulysses"
ATTENTION_OP = _templated_ulysses_attention
else:
cp_options.mode = "ring"
ATTENTION_OP = _templated_ring_attention
if compile_mode != "none":
torch._dynamo.config.suppress_errors = True
torch._inductor.config.reorder_for_compute_comm_overlap = True
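# Example invocations (hypothetical filename; the script always calls init_process_group, so
# launch it with torchrun even on a single GPU):
#   torchrun --nproc-per-node 1 flux_cuda_stream.py --compile_mode max-autotune-no-cudagraphs --enable_cudagraph
#   torchrun --nproc-per-node 2 flux_cuda_stream.py --ulysses_degree 2 --attention_provider fa2 --enable_cuda_stream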
if __name__ == "__main__":
args = get_args()
try:
setup_config(args)
setup_distributed(args.ring_degree, args.ulysses_degree, args.compile_mode, args.enable_cuda_stream)
main(
model_id=args.model_id,
prompt=args.prompt,
height=args.height,
width=args.width,
num_inference_steps=args.num_inference_steps,
guidance_scale=args.guidance_scale,
compile_mode=args.compile_mode,
output_file=args.output_file,
cache_dir=args.cache_dir,
enable_cudagraph=args.enable_cudagraph,
enable_prompt_length_bucketing=args.enable_prompt_length_bucketing,
enable_profiling=args.enable_profiling,
working_dir=args.working_dir,
seed=args.seed,
)
except Exception as e:
print(f"An error occurred: {e}")
if dist.is_initialized():
torch.distributed.breakpoint()
raise
finally:
if dist.is_initialized():
dist.destroy_process_group()
@a-r-r-o-w (Author) commented:

H100:

[image: H100 benchmark results]