Wan 2.2 5B T2V benchmarks
#!/bin/bash
set -xe
export TORCH_LOGS="recompiles,inductor"
export CUDA_VISIBLE_DEVICES="3,2,1,0"
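# Pick a flash-attention provider for the context-parallel runs based on GPU compute
# capability: FA3 ("fa3_original") on Hopper (SM90+), FA2 on Ampere/Ada (SM80+).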
set_fa_op() {
  COMPUTE_CAPABILITY=$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader | head -n 1 | tr -d '.')
  if [[ ! "$COMPUTE_CAPABILITY" =~ ^[0-9]+$ ]]; then
    echo "Error: Could not determine GPU compute capability"
    exit 1
  fi
  if [[ "$COMPUTE_CAPABILITY" -ge 90 ]]; then
    export FA_OP="fa3_original"
  elif [[ "$COMPUTE_CAPABILITY" -ge 80 ]]; then
    export FA_OP="fa2"
  else
    echo "Error: GPU compute capability $COMPUTE_CAPABILITY is below SM80, unsupported"
    exit 1
  fi
}
set_fa_op
# NOTE: For context-parallel runs we do not use CUDA graphs due to OOM issues; further investigation is needed
# MODEL_ID="Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
# MODEL_ID="Wan-AI/Wan2.1-T2V-14B-Diffusers"
MODEL_ID="Wan-AI/Wan2.2-TI2V-5B-Diffusers"
PROMPT="A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
NEGATIVE_PROMPT="Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
HEIGHT=704
WIDTH=1280
NUM_INFERENCE_STEPS=30
GUIDANCE_SCALE=5.0
OUTPUT_FILE="output.mp4"
# Eager baseline (single GPU)
torchrun --nnodes 1 --nproc_per_node 1 dump_wan_5b.py \
--model_id "${MODEL_ID}" \
--prompt "${PROMPT}" \
--height "${HEIGHT}" \
--width "${WIDTH}" \
--num_inference_steps "${NUM_INFERENCE_STEPS}" \
--guidance_scale "${GUIDANCE_SCALE}" \
--output_file "${OUTPUT_FILE}"
# Cudagraph (single GPU)
torchrun --nnodes 1 --nproc_per_node 1 dump_wan_5b.py \
--model_id "${MODEL_ID}" \
--prompt "${PROMPT}" \
--height "${HEIGHT}" \
--width "${WIDTH}" \
--num_inference_steps "${NUM_INFERENCE_STEPS}" \
--guidance_scale "${GUIDANCE_SCALE}" \
--enable_cudagraph \
--output_file "${OUTPUT_FILE}"
# + context parallel (ring=2)
torchrun --nnodes 1 --nproc_per_node 2 dump_wan_5b.py \
--model_id "${MODEL_ID}" \
--prompt "${PROMPT}" \
--height "${HEIGHT}" \
--width "${WIDTH}" \
--num_inference_steps "${NUM_INFERENCE_STEPS}" \
--guidance_scale "${GUIDANCE_SCALE}" \
--attention_provider "${FA_OP}" \
--ring_degree 2 \
--output_file "${OUTPUT_FILE}"
# + context parallel (ulysses=2)
torchrun --nnodes 1 --nproc_per_node 2 dump_wan_5b.py \
--model_id "${MODEL_ID}" \
--prompt "${PROMPT}" \
--height "${HEIGHT}" \
--width "${WIDTH}" \
--num_inference_steps "${NUM_INFERENCE_STEPS}" \
--guidance_scale "${GUIDANCE_SCALE}" \
--attention_provider "${FA_OP}" \
--ulysses_degree 2 \
--output_file "${OUTPUT_FILE}"
# + context parallel (ring=4)
torchrun --nnodes 1 --nproc_per_node 4 dump_wan_5b.py \
--model_id "${MODEL_ID}" \
--prompt "${PROMPT}" \
--height "${HEIGHT}" \
--width "${WIDTH}" \
--num_inference_steps "${NUM_INFERENCE_STEPS}" \
--guidance_scale "${GUIDANCE_SCALE}" \
--attention_provider "${FA_OP}" \
--ring_degree 4 \
--output_file "${OUTPUT_FILE}"
# + context parallel (ulysses=4)
torchrun --nnodes 1 --nproc_per_node 4 dump_wan_5b.py \
--model_id "${MODEL_ID}" \
--prompt "${PROMPT}" \
--height "${HEIGHT}" \
--width "${WIDTH}" \
--num_inference_steps "${NUM_INFERENCE_STEPS}" \
--guidance_scale "${GUIDANCE_SCALE}" \
--attention_provider "${FA_OP}" \
--ulysses_degree 4 \
--output_file "${OUTPUT_FILE}"
# Cudagraph + compile (single GPU)
torchrun --nnodes 1 --nproc_per_node 1 dump_wan_5b.py \
--model_id "${MODEL_ID}" \
--prompt "${PROMPT}" \
--height "${HEIGHT}" \
--width "${WIDTH}" \
--num_inference_steps "${NUM_INFERENCE_STEPS}" \
--guidance_scale "${GUIDANCE_SCALE}" \
--compile_mode default \
--enable_cudagraph \
--output_file "${OUTPUT_FILE}"
# + context parallel + compile (ring=2)
torchrun --nnodes 1 --nproc_per_node 2 dump_wan_5b.py \
--model_id "${MODEL_ID}" \
--prompt "${PROMPT}" \
--height "${HEIGHT}" \
--width "${WIDTH}" \
--num_inference_steps "${NUM_INFERENCE_STEPS}" \
--guidance_scale "${GUIDANCE_SCALE}" \
--attention_provider "${FA_OP}" \
--ring_degree 2 \
--compile_mode default \
--output_file "${OUTPUT_FILE}"
# + context parallel + compile (ulysses=2)
torchrun --nnodes 1 --nproc_per_node 2 dump_wan_5b.py \
--model_id "${MODEL_ID}" \
--prompt "${PROMPT}" \
--height "${HEIGHT}" \
--width "${WIDTH}" \
--num_inference_steps "${NUM_INFERENCE_STEPS}" \
--guidance_scale "${GUIDANCE_SCALE}" \
--attention_provider "${FA_OP}" \
--ulysses_degree 2 \
--compile_mode default \
--output_file "${OUTPUT_FILE}"
# + context parallel + compile (ring=4)
torchrun --nnodes 1 --nproc_per_node 4 dump_wan_5b.py \
--model_id "${MODEL_ID}" \
--prompt "${PROMPT}" \
--height "${HEIGHT}" \
--width "${WIDTH}" \
--num_inference_steps "${NUM_INFERENCE_STEPS}" \
--guidance_scale "${GUIDANCE_SCALE}" \
--attention_provider "${FA_OP}" \
--ring_degree 4 \
--compile_mode default \
--output_file "${OUTPUT_FILE}"
# + context parallel + compile (ulysses=4)
torchrun --nnodes 1 --nproc_per_node 4 dump_wan_5b.py \
--model_id "${MODEL_ID}" \
--prompt "${PROMPT}" \
--height "${HEIGHT}" \
--width "${WIDTH}" \
--num_inference_steps "${NUM_INFERENCE_STEPS}" \
--guidance_scale "${GUIDANCE_SCALE}" \
--attention_provider "${FA_OP}" \
--ulysses_degree 4 \
--compile_mode default \
--output_file "${OUTPUT_FILE}"
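# dump_wan_5b.py -- benchmark driver invoked by the torchrun commands above.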
import argparse
import contextlib
import math
from dataclasses import dataclass
from typing import Callable, Literal, Optional, Tuple
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol
import torch.profiler._utils
import torch._dynamo.config
import torch._inductor.config
import torch._higher_order_ops.auto_functionalize as af
from torch.profiler import profile, record_function, ProfilerActivity
from diffusers import AutoencoderKLWan, UniPCMultistepScheduler
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
from diffusers.models.attention import FeedForward
from diffusers.models.cache_utils import CacheMixin
from diffusers.models.embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import FP32LayerNorm
from diffusers.video_processor import VideoProcessor
from diffusers.utils import export_to_video
from kernels import get_kernel
from transformers import T5EncoderModel, T5TokenizerFast
try:
from flash_attn import flash_attn_func
except ImportError:
print("Flash Attention 2 not found.")
try:
from flash_attn_interface import flash_attn_func as flash_attn_3_func
except ImportError:
print("Flash Attention 3 not found.")
ATTENTION_OP = None
@dataclass
class ContextParallelOptions:
    ring_degree: int | None = None
    ulysses_degree: int | None = None
mode: Literal["ring", "ulysses", "unified"] = "ring"
mesh: dist.DeviceMesh | None = None
convert_to_fp32: bool = True
attention_op: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]] | None = None
    _flattened_mesh: dist.DeviceMesh | None = None
    _ring_mesh: dist.DeviceMesh | None = None
    _ulysses_mesh: dist.DeviceMesh | None = None
    _ring_local_rank: int | None = None
    _ulysses_local_rank: int | None = None
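# Module-level singleton filled in by setup_distributed() and read by the model,
# the sharder, and the templated attention ops below.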
cp_options = ContextParallelOptions()
def apply_flags():
torch._dynamo.config.inline_inbuilt_nn_modules = False
torch._dynamo.config.cache_size_limit = 128
torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_check_all_directions = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.disable_progress = False
torch._inductor.config.fx_graph_cache = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.aggressive_fusion = True
# torch._inductor.config.reorder_for_compute_comm_overlap = True
torch._inductor.config.shape_padding = True
torch._inductor.config.triton.enable_persistent_tma_matmul = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True
torch.backends.cudnn.allow_tf32 = True
af.auto_functionalized_v2._cacheable = True
af.auto_functionalized._cacheable = True
class WanAttention(torch.nn.Module):
def __init__(
self,
dim: int,
heads: int = 8,
dim_head: int = 64,
eps: float = 1e-5,
dropout: float = 0.0,
added_kv_proj_dim: Optional[int] = None,
is_cross_attention: bool = False,
):
super().__init__()
self.inner_dim = dim_head * heads
self.heads = heads
self.added_kv_proj_dim = added_kv_proj_dim
self.is_cross_attention = is_cross_attention
self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=True)
self.to_k = torch.nn.Linear(dim, self.inner_dim, bias=True)
self.to_v = torch.nn.Linear(dim, self.inner_dim, bias=True)
self.to_out = torch.nn.ModuleList([
torch.nn.Linear(self.inner_dim, self.inner_dim, bias=True),
torch.nn.Dropout(dropout)
])
self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
self.norm_k = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
self.add_k_proj = self.add_v_proj = None
if added_kv_proj_dim is not None:
self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)
self.fused_projections = False
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor],
rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
encoder_hidden_states_img = None
if self.add_k_proj is not None:
image_context_length = encoder_hidden_states.shape[1] - 512
encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
encoder_hidden_states = encoder_hidden_states[:, image_context_length:]
query, key, value = self._get_qkv_projections(hidden_states, encoder_hidden_states)
query = self.norm_q(query)
key = self.norm_k(key)
query, key, value = (x.unflatten(2, (self.heads, -1)) for x in (query, key, value))
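        # Apply rotary embeddings to query/key using interleaved (even, odd) channel pairs:
        # out[0::2] = x1 * cos - x2 * sin, out[1::2] = x1 * sin + x2 * cos.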
if rotary_emb is not None:
def apply_rotary_emb(hidden_states: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor):
x = hidden_states.unflatten(-1, (-1, 2))
x1, x2 = x[..., 0], x[..., 1]
cos = freqs_cos[..., 0::2]
sin = freqs_sin[..., 1::2]
out = torch.empty_like(hidden_states)
out[..., 0::2] = x1 * cos - x2 * sin
out[..., 1::2] = x1 * sin + x2 * cos
return out.type_as(hidden_states)
query = apply_rotary_emb(query, rotary_emb[0], rotary_emb[1])
key = apply_rotary_emb(key, rotary_emb[0], rotary_emb[1])
hidden_states_img = None
if encoder_hidden_states_img is not None:
key_img, value_img = self._get_added_kv_projections(encoder_hidden_states_img)
key_img = self.norm_added_k(key_img)
key_img, value_img = (x.unflatten(2, (self.heads, -1)) for x in (key_img, value_img))
hidden_states_img, _ = ATTENTION_OP(query, key_img, value_img)
hidden_states_img = hidden_states_img.flatten(2, 3).type_as(query)
hidden_states, _ = ATTENTION_OP(query, key, value)
hidden_states = hidden_states.flatten(2, 3).type_as(query)
if hidden_states_img is not None:
hidden_states = hidden_states + hidden_states_img
hidden_states = self.to_out[0](hidden_states)
return hidden_states
@torch.no_grad()
def fuse_projections(self):
if not self.is_cross_attention:
concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
out_features, in_features = concatenated_weights.shape
with torch.device("meta"):
self.to_qkv = torch.nn.Linear(in_features, out_features, bias=True)
self.to_qkv.load_state_dict(
{"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
)
else:
concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
out_features, in_features = concatenated_weights.shape
with torch.device("meta"):
self.to_kv = torch.nn.Linear(in_features, out_features, bias=True)
self.to_kv.load_state_dict(
{"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
)
if self.added_kv_proj_dim is not None:
concatenated_weights = torch.cat([self.add_k_proj.weight.data, self.add_v_proj.weight.data])
concatenated_bias = torch.cat([self.add_k_proj.bias.data, self.add_v_proj.bias.data])
out_features, in_features = concatenated_weights.shape
with torch.device("meta"):
self.to_added_kv = torch.nn.Linear(in_features, out_features, bias=True)
self.to_added_kv.load_state_dict(
{"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
)
self.fused_projections = True
def _get_qkv_projections(self, hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor]):
if self.fused_projections:
if not self.is_cross_attention:
query, key, value = self.to_qkv(hidden_states).chunk(3, dim=-1)
else:
query = self.to_q(hidden_states)
key, value = self.to_kv(encoder_hidden_states).chunk(2, dim=-1)
else:
query = self.to_q(hidden_states)
if self.added_kv_proj_dim is None:
key = self.to_k(hidden_states)
value = self.to_v(hidden_states)
else:
key = self.to_k(encoder_hidden_states)
value = self.to_v(encoder_hidden_states)
return query, key, value
def _get_added_kv_projections(self, encoder_hidden_states_img: torch.Tensor):
if self.fused_projections:
key_img, value_img = self.to_added_kv(encoder_hidden_states_img).chunk(2, dim=-1)
else:
key_img = self.add_k_proj(encoder_hidden_states_img)
value_img = self.add_v_proj(encoder_hidden_states_img)
return key_img, value_img
class WanImageEmbedding(torch.nn.Module):
def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None):
super().__init__()
self.norm1 = FP32LayerNorm(in_features)
self.ff = FeedForward(in_features, out_features, mult=1, activation_fn="gelu")
self.norm2 = FP32LayerNorm(out_features)
if pos_embed_seq_len is not None:
self.pos_embed = torch.nn.Parameter(torch.zeros(1, pos_embed_seq_len, in_features))
else:
self.pos_embed = None
def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor:
if self.pos_embed is not None:
batch_size, seq_len, embed_dim = encoder_hidden_states_image.shape
encoder_hidden_states_image = encoder_hidden_states_image.view(-1, 2 * seq_len, embed_dim)
encoder_hidden_states_image = encoder_hidden_states_image + self.pos_embed
hidden_states = self.norm1(encoder_hidden_states_image)
hidden_states = self.ff(hidden_states)
hidden_states = self.norm2(hidden_states)
return hidden_states
class WanTimeTextImageEmbedding(torch.nn.Module):
def __init__(
self,
dim: int,
time_freq_dim: int,
time_proj_dim: int,
text_embed_dim: int,
image_embed_dim: Optional[int] = None,
pos_embed_seq_len: Optional[int] = None,
):
super().__init__()
self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
self.act_fn = torch.nn.SiLU()
self.time_proj = torch.nn.Linear(dim, time_proj_dim)
self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh")
self.image_embedder = None
if image_embed_dim is not None:
self.image_embedder = WanImageEmbedding(image_embed_dim, dim, pos_embed_seq_len=pos_embed_seq_len)
def forward(
self,
timestep: torch.Tensor,
encoder_hidden_states: torch.Tensor,
encoder_hidden_states_image: Optional[torch.Tensor] = None,
):
timestep = self.timesteps_proj(timestep).type_as(encoder_hidden_states)
temb = self.time_embedder(timestep)
timestep_proj = self.time_proj(self.act_fn(temb))
encoder_hidden_states = self.text_embedder(encoder_hidden_states)
if encoder_hidden_states_image is not None:
encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image)
return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image
class WanRotaryPosEmbed(torch.nn.Module):
def __init__(
self,
attention_head_dim: int,
patch_size: Tuple[int, int, int],
max_seq_len: int,
theta: float = 10000.0,
):
super().__init__()
self.attention_head_dim = attention_head_dim
self.patch_size = patch_size
self.max_seq_len = max_seq_len
h_dim = w_dim = 2 * (attention_head_dim // 6)
t_dim = attention_head_dim - h_dim - w_dim
freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
freqs_cos = []
freqs_sin = []
for dim in [t_dim, h_dim, w_dim]:
freq_cos, freq_sin = get_1d_rotary_pos_embed(
dim,
max_seq_len,
theta,
use_real=True,
repeat_interleave_real=True,
freqs_dtype=freqs_dtype,
)
freqs_cos.append(freq_cos)
freqs_sin.append(freq_sin)
self.register_buffer("freqs_cos", torch.cat(freqs_cos, dim=1), persistent=False)
self.register_buffer("freqs_sin", torch.cat(freqs_sin, dim=1), persistent=False)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p_t, p_h, p_w = self.patch_size
ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
split_sizes = [
self.attention_head_dim - 2 * (self.attention_head_dim // 3),
self.attention_head_dim // 3,
self.attention_head_dim // 3,
]
freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
if cp_options.mesh is not None:
freqs_cos = EquipartitionSharder.shard(freqs_cos, dim=1, mesh=cp_options._flattened_mesh)
freqs_sin = EquipartitionSharder.shard(freqs_sin, dim=1, mesh=cp_options._flattened_mesh)
return freqs_cos, freqs_sin
class WanTransformerBlock(torch.nn.Module):
def __init__(
self,
dim: int,
ffn_dim: int,
num_heads: int,
qk_norm: str = "rms_norm_across_heads",
cross_attn_norm: bool = False,
eps: float = 1e-6,
added_kv_proj_dim: Optional[int] = None,
):
super().__init__()
self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
self.attn1 = WanAttention(
dim=dim, heads=num_heads, dim_head=dim // num_heads, eps=eps, is_cross_attention=False,
)
self.attn2 = WanAttention(
dim=dim, heads=num_heads, dim_head=dim // num_heads, eps=eps, added_kv_proj_dim=added_kv_proj_dim, is_cross_attention=True
)
self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else torch.nn.Identity()
self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate")
self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)
self.scale_shift_table = torch.nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
rotary_emb: torch.Tensor,
) -> torch.Tensor:
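        # AdaLN-style modulation: the per-block scale_shift_table plus the projected timestep
        # embedding yields six vectors (shift/scale/gate for self-attention and for the FFN).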
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
self.scale_shift_table + temb.float()
).chunk(6, dim=1)
scale_msa.add_(1)
c_scale_msa.add_(1)
norm_hidden_states = torch.addcmul(shift_msa, self.norm1(hidden_states.float()), scale_msa).type_as(hidden_states)
attn_output = self.attn1(norm_hidden_states, None, rotary_emb)
hidden_states = torch.addcmul(hidden_states.float(), attn_output, gate_msa).type_as(hidden_states)
norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
attn_output = self.attn2(norm_hidden_states, encoder_hidden_states, None)
hidden_states.add_(attn_output)
norm_hidden_states = torch.addcmul(c_shift_msa, self.norm3(hidden_states.float()), c_scale_msa).type_as(hidden_states)
ff_output = self.ffn(norm_hidden_states)
hidden_states = torch.addcmul(hidden_states.float(), ff_output, c_gate_msa).type_as(hidden_states)
return hidden_states
class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
_keys_to_ignore_on_load_unexpected = ["norm_added_q"]
@register_to_config
def __init__(
self,
patch_size: Tuple[int] = (1, 2, 2),
num_attention_heads: int = 40,
attention_head_dim: int = 128,
in_channels: int = 16,
out_channels: int = 16,
text_dim: int = 4096,
freq_dim: int = 256,
ffn_dim: int = 13824,
num_layers: int = 40,
cross_attn_norm: bool = True,
qk_norm: Optional[str] = "rms_norm_across_heads",
eps: float = 1e-6,
image_dim: Optional[int] = None,
added_kv_proj_dim: Optional[int] = None,
rope_max_seq_len: int = 1024,
pos_embed_seq_len: Optional[int] = None,
) -> None:
super().__init__()
inner_dim = num_attention_heads * attention_head_dim
out_channels = out_channels or in_channels
self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
self.patch_embedding = torch.nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
self.condition_embedder = WanTimeTextImageEmbedding(
dim=inner_dim,
time_freq_dim=freq_dim,
time_proj_dim=inner_dim * 6,
text_embed_dim=text_dim,
image_embed_dim=image_dim,
pos_embed_seq_len=pos_embed_seq_len,
)
self.blocks = torch.nn.ModuleList(
[
WanTransformerBlock(
inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim
)
for _ in range(num_layers)
]
)
self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
self.proj_out = torch.nn.Linear(inner_dim, out_channels * math.prod(patch_size))
self.scale_shift_table = torch.nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)
def forward(
self,
hidden_states: torch.Tensor,
timestep: torch.LongTensor,
encoder_hidden_states: torch.Tensor,
encoder_hidden_states_image: Optional[torch.Tensor] = None,
rotary_emb: Tuple[torch.Tensor, torch.Tensor] = None,
) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p_t, p_h, p_w = self.config.patch_size
post_patch_num_frames = num_frames // p_t
post_patch_height = height // p_h
post_patch_width = width // p_w
hidden_states = self.patch_embedding(hidden_states)
hidden_states = hidden_states.flatten(2).transpose(1, 2)
temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
timestep, encoder_hidden_states, encoder_hidden_states_image
)
timestep_proj = timestep_proj.unflatten(1, (6, -1))
if cp_options.mesh is not None:
hidden_states = EquipartitionSharder.shard(hidden_states, dim=1, mesh=cp_options._flattened_mesh)
encoder_hidden_states = EquipartitionSharder.shard(encoder_hidden_states, dim=1, mesh=cp_options._flattened_mesh)
if encoder_hidden_states_image is not None:
if cp_options.mesh is not None:
encoder_hidden_states_image = EquipartitionSharder.shard(encoder_hidden_states_image, dim=1, mesh=cp_options._flattened_mesh)
encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
for block in self.blocks:
hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
scale.add_(1)
hidden_states = torch.addcmul(shift, self.norm_out(hidden_states.float()), scale).type_as(hidden_states)
hidden_states = self.proj_out(hidden_states)
if cp_options.mesh is not None:
hidden_states = EquipartitionSharder.unshard(hidden_states, dim=1, mesh=cp_options._flattened_mesh)
hidden_states = hidden_states.reshape(
batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
)
hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
return output
def fuse_qkv_(model: WanTransformer3DModel) -> None:
for submodule in model.modules():
if not isinstance(submodule, WanAttention):
continue
submodule.fuse_projections()
def prepare_t5_embeddings(
text_encoder: T5EncoderModel,
tokenizer: T5TokenizerFast,
prompt: str,
device: torch.device,
dtype: torch.dtype,
max_length: int = 512,
) -> torch.Tensor:
prompt = [prompt]
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=max_length,
truncation=True,
add_special_tokens=True,
return_attention_mask=True,
return_tensors="pt",
)
text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
seq_lens = mask.gt(0).sum(dim=1).long()
prompt_embeds = text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
prompt_embeds = prompt_embeds.to(dtype)
prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
prompt_embeds = torch.stack(
[torch.cat([u, u.new_zeros(max_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
)
return prompt_embeds
@torch.inference_mode()
def capture_cudagraph(
model: WanTransformer3DModel,
latents: torch.Tensor,
encoder_hidden_states: torch.Tensor,
timestep: torch.Tensor,
rotary_emb: Tuple[torch.Tensor, torch.Tensor],
):
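    # Warm up on a side stream (required before capture), then record one forward pass into a
    # CUDA graph. Callers must update the returned static tensors in place and replay the graph.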
print("Warming up CUDAGraph capture")
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
for _ in range(2):
_ = model(
hidden_states=latents,
encoder_hidden_states=encoder_hidden_states,
timestep=timestep,
)
torch.cuda.current_stream().wait_stream(s)
print("Capturing CUDAGraph")
static_latents = latents.clone()
static_timestep = timestep.clone()
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
static_x = model(
hidden_states=static_latents,
encoder_hidden_states=encoder_hidden_states,
timestep=static_timestep,
rotary_emb=rotary_emb,
)
return graph, static_latents, static_timestep, static_x
@torch.inference_mode()
def main(
model_id: str,
prompt: str,
negative_prompt: str,
height: int,
width: int,
num_frames: int,
num_inference_steps: int,
guidance_scale: float,
compile_mode: str,
output_file: str,
cache_dir: Optional[str],
enable_cudagraph: bool,
enable_profiling: bool,
seed: int,
):
# Constants/defaults and device/dtype preparation
device = "cuda"
dtype = torch.bfloat16
SPATIAL_COMPRESSION_RATIO = 16
TEMPORAL_COMPRESSION_RATIO = 4
T5_SEQUENCE_LENGTH = 512
# Load the model components
transformer = WanTransformer3DModel.from_pretrained(
model_id, subfolder="transformer", cache_dir=cache_dir, torch_dtype=dtype
)
text_encoder = T5EncoderModel.from_pretrained(
model_id, subfolder="text_encoder", cache_dir=cache_dir, torch_dtype=dtype
)
tokenizer = T5TokenizerFast.from_pretrained(model_id, subfolder="tokenizer", cache_dir=cache_dir)
vae = AutoencoderKLWan.from_pretrained(
model_id, subfolder="vae", cache_dir=cache_dir, torch_dtype=dtype
)
scheduler = UniPCMultistepScheduler.from_pretrained(model_id, subfolder="scheduler", cache_dir=cache_dir)
video_processor = VideoProcessor(vae_scale_factor=SPATIAL_COMPRESSION_RATIO)
fuse_qkv_(transformer)
    for module in (transformer, text_encoder, vae):
        module.to(device, dtype=dtype)
if compile_mode != "none":
transformer = torch.compile(transformer, mode=compile_mode, fullgraph=True, dynamic=True)
# text_encoder = torch.compile(text_encoder, mode="default", fullgraph=True, dynamic=True)
# We don't compile the VAE due to the implementation calling into non-traceable code paths
# vae.decode = torch.compile(vae.decode, mode="default", fullgraph=True, dynamic=True)
scheduler.set_timesteps = torch.compile(scheduler.set_timesteps, fullgraph=True, dynamic=False)
# Fails due to CPU tensor usage somewhere in the scheduler.step implementation, which Triton does not support
# scheduler.step = torch.compile(scheduler.step, fullgraph=True, dynamic=False)
# Latent, text and timestep conditioning preparation
batch_size = 1
latent_height = height // SPATIAL_COMPRESSION_RATIO
latent_width = width // SPATIAL_COMPRESSION_RATIO
latent_num_frames = (num_frames - 1) // TEMPORAL_COMPRESSION_RATIO + 1
# print("num tokens:", latent_num_frames * latent_height * latent_width // 4)
generator = torch.Generator(device=device).manual_seed(seed)
latents = torch.randn(
(batch_size, transformer.config.in_channels, latent_num_frames, latent_height, latent_width),
dtype=dtype,
device=device,
generator=generator,
)
prompt_embeds = prepare_t5_embeddings(text_encoder, tokenizer, prompt, device, dtype, T5_SEQUENCE_LENGTH)
negative_prompt_embeds = prepare_t5_embeddings(text_encoder, tokenizer, negative_prompt, device, dtype, T5_SEQUENCE_LENGTH)
encoder_hidden_states = torch.cat([prompt_embeds, negative_prompt_embeds], dim=0)
scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = scheduler.timesteps.to(dtype=dtype)
print("Warmup step...")
for _ in range(2):
latent_model_input = torch.cat([latents, latents], dim=0)
rotary_emb = transformer.rope(latent_model_input)
_ = transformer(
hidden_states=latent_model_input,
timestep=timesteps[0].expand(latent_model_input.size(0)).clone(),
encoder_hidden_states=encoder_hidden_states,
rotary_emb=rotary_emb,
)
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
context = (
profile(activities=[ProfilerActivity.CUDA, ProfilerActivity.CPU], with_flops=True)
if enable_profiling
else contextlib.nullcontext()
)
if not enable_cudagraph:
scheduler.set_begin_index(0)
with context as ctx:
start_event.record()
rotary_emb = None
for i in range(num_inference_steps):
print("step:", i + 1, "/", num_inference_steps)
latent_model_input = torch.cat([latents, latents], dim=0)
if rotary_emb is None:
rotary_emb = transformer.rope(latent_model_input)
v_cond_uncond = transformer(
hidden_states=latent_model_input,
timestep=timesteps[i].expand(latent_model_input.size(0)).clone(),
encoder_hidden_states=encoder_hidden_states,
rotary_emb=rotary_emb,
)
v_cond, v_uncond = v_cond_uncond.chunk(2, dim=0)
v_pred = v_uncond + guidance_scale * (v_cond - v_uncond)
latents = scheduler.step(v_pred, timesteps[i], latents).prev_sample
end_event.record()
torch.cuda.synchronize()
else:
latent_model_input = torch.cat([latents, latents], dim=0)
rotary_emb = transformer.rope(latent_model_input)
graph, static_latents, static_timestep, static_x = capture_cudagraph(
transformer, latent_model_input, encoder_hidden_states, timesteps[0].expand(latent_model_input.size(0)).clone(), rotary_emb
)
scheduler.set_begin_index(0)
with context as ctx:
start_event.record()
static_x.copy_(static_latents)
for i in range(num_inference_steps):
torch._foreach_copy_(
(static_latents, static_timestep),
(static_x, timesteps[i].expand(static_latents.size(0)).clone()),
non_blocking=True,
)
graph.replay()
v_cond, v_uncond = static_x.chunk(2, dim=0)
v_pred = v_uncond + guidance_scale * (v_cond - v_uncond)
latents = scheduler.step(v_pred, timesteps[i], static_latents).prev_sample
latent_model_input = torch.cat([latents, latents], dim=0)
static_x.copy_(latents)
end_event.record()
torch.cuda.synchronize()
latents = static_x
total_time = start_event.elapsed_time(end_event) / 1000.0
if enable_profiling:
print(ctx.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))
ctx.export_chrome_trace("dump_benchmark_wan.json")
if cp_options.mesh is None or cp_options._flattened_mesh.get_local_rank() == 0:
print(f"time: {total_time:.2f}s")
latents_mean = torch.tensor(vae.config.latents_mean, dtype=dtype, device=device).view(1, vae.config.z_dim, 1, 1, 1)
latents_std = 1.0 / torch.tensor(vae.config.latents_std, dtype=dtype, device=device).view(1, vae.config.z_dim, 1, 1, 1)
latents = latents / latents_std + latents_mean
video = vae.decode(latents, return_dict=False)[0]
video = video_processor.postprocess_video(video, output_type="pil")[0]
export_to_video(video, output_file, fps=16)
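# Splits a tensor equally along `dim` across the flattened context-parallel mesh;
# `unshard` all-gathers the shards back in rank order.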
class EquipartitionSharder:
@classmethod
def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
assert tensor.size()[dim] % mesh.size() == 0
# The following is not fullgraph compatible with Dynamo (fails in DeviceMesh.get_rank)
# return tensor.chunk(mesh.size(), dim=dim)[mesh.get_rank()]
return tensor.chunk(mesh.size(), dim=dim)[torch.distributed.get_rank(mesh.get_group())]
@classmethod
def unshard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
tensor = tensor.contiguous()
tensor = funcol.all_gather_tensor(tensor, dim, group=mesh.get_group())
return tensor
# Reference:
# - https://github.com/pytorch/pytorch/blob/f58a680d09e13658a52c6ba05c63c15759846bcc/torch/distributed/_functional_collectives.py#L827
# - https://github.com/pytorch/pytorch/blob/f58a680d09e13658a52c6ba05c63c15759846bcc/torch/distributed/_functional_collectives.py#L246
# For fullgraph=True tracing compatibility (since FakeTensor does not have a `wait` method):
def _wait_tensor(tensor):
if isinstance(tensor, funcol.AsyncCollectiveTensor):
tensor = tensor.wait()
return tensor
def _all_to_all_single(x: torch.Tensor, group) -> torch.Tensor:
shape = x.shape
    # HACK: flatten before the collective. Despite making tensors contiguous, torch's
    # single-file repro used to benchmark Triton codegen fails somewhere with:
    #   buf25 = torch.ops._c10d_functional.all_to_all_single.default(buf24, [1, 1], [1, 1], '3')
    #   ValueError: Tensors must be contiguous
x = x.flatten()
x = funcol.all_to_all_single(x, None, None, group)
x = x.reshape(shape)
x = _wait_tensor(x)
return x
def _templated_ring_attention(query, key, value):
ring_mesh = cp_options.mesh["ring"]
rank = cp_options._ring_local_rank
world_size = cp_options.ring_degree
if world_size == 1:
return cp_options.attention_op(query, key, value)
next_rank = (rank + 1) % world_size
prev_out = prev_lse = None
kv_buffer = torch.cat([key.flatten(), value.flatten()]).contiguous()
kv_buffer = funcol.all_gather_tensor(kv_buffer, gather_dim=0, group=ring_mesh.get_group())
kv_buffer = kv_buffer.chunk(world_size)
for i in range(world_size):
if i > 0:
kv = kv_buffer[next_rank]
key = kv[:key.numel()].reshape_as(key)
value = kv[key.numel():].reshape_as(value)
next_rank = (next_rank + 1) % world_size
out, lse = cp_options.attention_op(query, key, value)
if cp_options.convert_to_fp32:
out = out.to(torch.float32)
lse = lse.to(torch.float32)
lse = lse.unsqueeze(-1)
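        # Numerically stable merge of partial attention results across ring steps:
        # out_new = (exp(prev_lse) * prev_out + exp(lse) * out) / (exp(prev_lse) + exp(lse)),
        # expressed via sigmoid, and lse_new = logaddexp(prev_lse, lse) expressed via logsigmoid.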
if prev_out is not None:
out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out)
lse = prev_lse - torch.nn.functional.logsigmoid(prev_lse - lse)
prev_out = out
prev_lse = lse
out = out.to(query.dtype)
lse = lse.squeeze(-1)
return out, lse
def _templated_ulysses_attention(query, key, value, *, return_lse: bool = False):
ulysses_mesh = cp_options.mesh["ulysses"]
world_size = cp_options.ulysses_degree
group = ulysses_mesh.get_group()
if world_size == 1:
return cp_options.attention_op(query, key, value)
B, S_Q_LOCAL, H, D = query.shape
_, S_KV_LOCAL, _, _ = key.shape
H_LOCAL = H // world_size
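    # Ulysses-style sequence parallelism: an all-to-all swaps the sharded dimension from
    # sequence to heads, so each rank attends over the full sequence for H // world_size heads;
    # a second all-to-all afterwards restores sequence sharding.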
query = query.reshape(B, S_Q_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
key = key.reshape(B, S_KV_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
value = value.reshape(B, S_KV_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
query, key, value = (
_all_to_all_single(x, group)
for x in (query, key, value)
)
query, key, value = (
x.flatten(0, 1).permute(1, 0, 2, 3).contiguous()
for x in (query, key, value)
)
out, lse = cp_options.attention_op(query, key, value)
out = out.reshape(B, world_size, S_Q_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous()
out = _all_to_all_single(out, group)
out = out.flatten(0, 1).permute(1, 2, 0, 3).contiguous()
if return_lse:
lse = lse.reshape(B, world_size, S_Q_LOCAL, H_LOCAL).permute(1, 3, 0, 2).contiguous()
lse = _all_to_all_single(lse, group)
lse = lse.flatten(0, 1).permute(1, 2, 0).contiguous()
else:
lse = None
return out, lse
# TODO: currently produces incorrect results (for example, with CP=4, ring=2, ulysses=2,
# half of the output matches the expected image and the other half is completely different)
# def _templated_unified_attention(query, key, value):
# ulysses_mesh = cp_options.mesh["ulysses"]
# ulysses_size = ulysses_mesh.size()
# ulysses_group = ulysses_mesh.get_group()
# B, S_LOCAL, H, D = query.shape
# H_LOCAL = H // ulysses_size
# query, key, value = (
# x.reshape(B, S_LOCAL, ulysses_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
# for x in (query, key, value)
# )
# query, key, value = (
# wait_tensor(funcol.all_to_all_single(x, None, None, group=ulysses_group))
# for x in (query, key, value)
# )
# query, key, value = (
# x.flatten(0, 1).permute(1, 0, 2, 3).contiguous()
# for x in (query, key, value)
# )
# out, lse = _templated_ring_attention(query, key, value)
# out = out.reshape(B, ulysses_size, S_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous()
# lse = lse.reshape(B, ulysses_size, S_LOCAL, H_LOCAL).permute(1, 3, 0, 2).contiguous()
# out = wait_tensor(funcol.all_to_all_single(out, None, None, group=ulysses_group))
# lse = wait_tensor(funcol.all_to_all_single(lse, None, None, group=ulysses_group))
# out = out.flatten(0, 1).permute(1, 2, 0, 3).contiguous()
# lse = lse.flatten(0, 1).permute(1, 2, 0).contiguous()
# return out, lse
# Wrap the FA3 kernels as custom ops for fullgraph=True tracing compatibility
@torch.library.custom_op("flash_attn_3::_flash_attn_forward_original", mutates_args=(), device_types="cuda")
def _wrapped_flash_attn_3_original(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
out, lse = flash_attn_3_func(query, key, value)
lse = lse.permute(0, 2, 1)
return out, lse
@torch.library.register_fake("flash_attn_3::_flash_attn_forward_original")
def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size, seq_len, num_heads, head_dim = query.shape
lse_shape = (batch_size, seq_len, num_heads)
return torch.empty_like(query), query.new_empty(lse_shape)
@torch.library.custom_op("flash_attn_3::_flash_attn_forward_hf", mutates_args=(), device_types="cuda")
def _wrapped_flash_attn_3_hf(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
out, lse = flash_attn_3_hf.flash_attn_func(query, key, value, causal=False)
return out
@torch.library.register_fake("flash_attn_3::_flash_attn_forward_hf")
def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
return torch.empty_like(query)
def _attention_torch_cudnn(query, key, value):
query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = (
torch.ops.aten._scaled_dot_product_cudnn_attention(
query=query,
key=key,
value=value,
attn_bias=None,
compute_log_sumexp=True,
)
)
out = out.transpose(1, 2).contiguous()
lse = lse.transpose(1, 2).contiguous()
return out, lse
def _attention_flash_attn_2(query, key, value):
out, lse, _ = flash_attn_func(query, key, value, return_attn_probs=True)
lse = lse.permute(0, 2, 1)
return out, lse
def _attention_flash_attn_3_original(query, key, value):
out = _wrapped_flash_attn_3_original(query, key, value)
return out
def _attention_flash_attn_3_hf(query, key, value):
    out = _wrapped_flash_attn_3_hf(query, key, value)
    # The custom op does not expose the LSE, so return None; ring attention cannot use this provider.
    return out, None
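# All attention providers return (out, lse) with out shaped (B, S, H, D) and lse (B, S, H);
# lse may be None for providers that do not expose it, in which case ring attention is unavailable.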
def _download_hf_flash_attn_3():
global flash_attn_3_hf
flash_attn_3_hf = get_kernel("kernels-community/flash-attn3")
def get_args():
DEFAULT_MODEL_ID = "Wan-AI/Wan2.2-TI2V-5B-Diffusers"
DEFAULT_PROMPT = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
DEFAULT_NEGATIVE_PROMPT = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
parser = argparse.ArgumentParser()
parser.add_argument("--model_id", type=str, default=DEFAULT_MODEL_ID)
parser.add_argument("--prompt", type=str, default=DEFAULT_PROMPT)
parser.add_argument("--negative_prompt", type=str, default=DEFAULT_NEGATIVE_PROMPT)
parser.add_argument("--height", type=int, default=704)
parser.add_argument("--width", type=int, default=1280)
parser.add_argument("--num_frames", type=int, default=121)
parser.add_argument("--num_inference_steps", type=int, default=30)
parser.add_argument("--guidance_scale", type=float, default=5.0)
parser.add_argument(
"--compile_mode",
type=str,
default="none",
choices=["none", "default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"],
)
parser.add_argument("--output_file", type=str, default="output.mp4")
parser.add_argument("--cache_dir", type=str, default=None)
parser.add_argument("--enable_fp32_rope", action="store_true")
parser.add_argument("--enable_profiling", action="store_true")
parser.add_argument("--enable_cudagraph", action="store_true")
parser.add_argument("--attention_provider", type=str, default="cudnn", choices=["cudnn", "fa2", "fa3", "fa3_original"])
parser.add_argument("--ring_degree", type=int, default=1)
parser.add_argument("--ulysses_degree", type=int, default=1)
parser.add_argument("--disable_tf32", action="store_true")
parser.add_argument("--disable_flags", action="store_true")
parser.add_argument("--seed", type=int, default=31337)
args = parser.parse_args()
return args
def config(args):
torch.manual_seed(args.seed)
if args.enable_fp32_rope:
global ROPE_PRECISION
ROPE_PRECISION = torch.float32
if args.enable_cudagraph and args.compile_mode not in ["none", "default", "max-autotune-no-cudagraphs"]:
raise ValueError(
"Only compiled modes 'none', 'default', and 'max-autotune-no-cudagraphs' are supported with CUDAGraphs."
)
global ATTENTION_OP
if args.attention_provider == "cudnn":
ATTENTION_OP = _attention_torch_cudnn
elif args.attention_provider == "fa2":
ATTENTION_OP = _attention_flash_attn_2
elif args.attention_provider == "fa3":
_download_hf_flash_attn_3()
ATTENTION_OP = _attention_flash_attn_3_hf
elif args.attention_provider == "fa3_original":
ATTENTION_OP = _attention_flash_attn_3_original
    else:
        raise ValueError(f"Unknown attention provider: {args.attention_provider}")
if args.enable_profiling and args.enable_cudagraph:
torch.profiler._utils._init_for_cuda_graphs()
if not args.disable_tf32:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
if not args.disable_flags:
apply_flags()
def setup_distributed(ring_degree: int, ulysses_degree: int, compile_mode: str):
global ATTENTION_OP
dist.init_process_group("nccl")
torch.cuda.set_device(torch.device("cuda", dist.get_rank()))
if ring_degree * ulysses_degree != dist.get_world_size():
raise ValueError(f"`ring_degree * ulysses_degree` must equal the world size {dist.get_world_size()}.")
mesh_names = ["ring", "ulysses"]
mesh_dims = [ring_degree, ulysses_degree]
mesh = dist.device_mesh.init_device_mesh("cuda", mesh_dims, mesh_dim_names=mesh_names)
cp_options.ring_degree = ring_degree
cp_options.ulysses_degree = ulysses_degree
cp_options.mesh = mesh
cp_options.convert_to_fp32 = True
cp_options.attention_op = ATTENTION_OP
cp_options._flattened_mesh = mesh._flatten()
cp_options._ring_mesh = mesh["ring"]
cp_options._ulysses_mesh = mesh["ulysses"]
cp_options._ring_local_rank = cp_options._ring_mesh.get_local_rank()
cp_options._ulysses_local_rank = cp_options._ulysses_mesh.get_local_rank()
if ring_degree > 1 and ulysses_degree > 1:
raise ValueError("The current implementation is incorrect for unified attention and needs to be fixed.")
# cp_options.mode = "unified"
# ATTENTION_OP = _templated_unified_attention
elif ulysses_degree > 1:
cp_options.mode = "ulysses"
ATTENTION_OP = _templated_ulysses_attention
else:
cp_options.mode = "ring"
ATTENTION_OP = _templated_ring_attention
if compile_mode != "none":
torch._dynamo.config.suppress_errors = True
torch._inductor.config.reorder_for_compute_comm_overlap = True
if __name__ == "__main__":
args = get_args()
try:
config(args)
setup_distributed(args.ring_degree, args.ulysses_degree, args.compile_mode)
main(
model_id=args.model_id,
prompt=args.prompt,
negative_prompt=args.negative_prompt,
height=args.height,
width=args.width,
num_frames=args.num_frames,
num_inference_steps=args.num_inference_steps,
guidance_scale=args.guidance_scale,
compile_mode=args.compile_mode,
output_file=args.output_file,
cache_dir=args.cache_dir,
enable_cudagraph=args.enable_cudagraph,
enable_profiling=args.enable_profiling,
seed=args.seed,
)
except Exception as e:
print(f"An error occurred: {e}")
if dist.is_initialized():
torch.distributed.breakpoint()
raise
finally:
if dist.is_initialized():
dist.destroy_process_group()
Author comment (H100): benchmark results attached as an image.