Wan 2.2 5B T2V benchmarks
| #!/bin/bash | |
| set -xe | |
| export TORCH_LOGS="recompiles,inductor" | |
| export CUDA_VISIBLE_DEVICES="3,2,1,0" | |
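| # Select the attention backend from the GPU compute capability: SM90+ -> FlashAttention-3 (fa3_original), SM80-89 -> FlashAttention-2 (fa2). | |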
| set_fa_op() { | |
| COMPUTE_CAPABILITY=$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader | head -n 1 | tr -d '.') | |
| if [[ ! "$COMPUTE_CAPABILITY" =~ ^[0-9]+$ ]]; then | |
| echo "Error: Could not determine GPU compute capability" | |
| exit 1 | |
| fi | |
| if [[ "$COMPUTE_CAPABILITY" -ge 90 ]]; then | |
| export FA_OP="fa3_original" | |
| elif [[ "$COMPUTE_CAPABILITY" -ge 80 ]]; then | |
| export FA_OP="fa2" | |
| else | |
| echo "Error: GPU compute capability $COMPUTE_CAPABILITY is below SM80, unsupported" | |
| exit 1 | |
| fi | |
| } | |
| set_fa_op | |
| # NOTE: for context parallel, we do not use cudagraphs due to OOM issues. Further investigation is needed. | |
| # MODEL_ID="Wan-AI/Wan2.1-T2V-1.3B-Diffusers" | |
| # MODEL_ID="Wan-AI/Wan2.1-T2V-14B-Diffusers" | |
| MODEL_ID="Wan-AI/Wan2.2-TI2V-5B-Diffusers" | |
| PROMPT="A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." | |
| NEGATIVE_PROMPT="Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" | |
| HEIGHT=704 | |
| WIDTH=1280 | |
| NUM_INFERENCE_STEPS=30 | |
| GUIDANCE_SCALE=5.0 | |
| OUTPUT_FILE="output.mp4" | |
| # Eager baseline (single GPU) | |
| torchrun --nnodes 1 --nproc_per_node 1 dump_wan_5b.py \ | |
| --model_id "${MODEL_ID}" \ | |
| --prompt "${PROMPT}" \ | |
| --height "${HEIGHT}" \ | |
| --width "${WIDTH}" \ | |
| --num_inference_steps "${NUM_INFERENCE_STEPS}" \ | |
| --guidance_scale "${GUIDANCE_SCALE}" \ | |
| --output_file "${OUTPUT_FILE}" | |
| # Cudagraph (single GPU) | |
| torchrun --nnodes 1 --nproc_per_node 1 dump_wan_5b.py \ | |
| --model_id "${MODEL_ID}" \ | |
| --prompt "${PROMPT}" \ | |
| --height "${HEIGHT}" \ | |
| --width "${WIDTH}" \ | |
| --num_inference_steps "${NUM_INFERENCE_STEPS}" \ | |
| --guidance_scale "${GUIDANCE_SCALE}" \ | |
| --enable_cudagraph \ | |
| --output_file "${OUTPUT_FILE}" | |
| # + context parallel (ring=2) | |
| torchrun --nnodes 1 --nproc_per_node 2 dump_wan_5b.py \ | |
| --model_id "${MODEL_ID}" \ | |
| --prompt "${PROMPT}" \ | |
| --height "${HEIGHT}" \ | |
| --width "${WIDTH}" \ | |
| --num_inference_steps "${NUM_INFERENCE_STEPS}" \ | |
| --guidance_scale "${GUIDANCE_SCALE}" \ | |
| --attention_provider "${FA_OP}" \ | |
| --ring_degree 2 \ | |
| --output_file "${OUTPUT_FILE}" | |
| # + context parallel (ulysses=2) | |
| torchrun --nnodes 1 --nproc_per_node 2 dump_wan_5b.py \ | |
| --model_id "${MODEL_ID}" \ | |
| --prompt "${PROMPT}" \ | |
| --height "${HEIGHT}" \ | |
| --width "${WIDTH}" \ | |
| --num_inference_steps "${NUM_INFERENCE_STEPS}" \ | |
| --guidance_scale "${GUIDANCE_SCALE}" \ | |
| --attention_provider "${FA_OP}" \ | |
| --ulysses_degree 2 \ | |
| --output_file "${OUTPUT_FILE}" | |
| # + context parallel (ring=4) | |
| torchrun --nnodes 1 --nproc_per_node 4 dump_wan_5b.py \ | |
| --model_id "${MODEL_ID}" \ | |
| --prompt "${PROMPT}" \ | |
| --height "${HEIGHT}" \ | |
| --width "${WIDTH}" \ | |
| --num_inference_steps "${NUM_INFERENCE_STEPS}" \ | |
| --guidance_scale "${GUIDANCE_SCALE}" \ | |
| --attention_provider "${FA_OP}" \ | |
| --ring_degree 4 \ | |
| --output_file "${OUTPUT_FILE}" | |
| # + context parallel (ulysses=4) | |
| torchrun --nnodes 1 --nproc_per_node 4 dump_wan_5b.py \ | |
| --model_id "${MODEL_ID}" \ | |
| --prompt "${PROMPT}" \ | |
| --height "${HEIGHT}" \ | |
| --width "${WIDTH}" \ | |
| --num_inference_steps "${NUM_INFERENCE_STEPS}" \ | |
| --guidance_scale "${GUIDANCE_SCALE}" \ | |
| --attention_provider "${FA_OP}" \ | |
| --ulysses_degree 4 \ | |
| --output_file "${OUTPUT_FILE}" | |
| # Cudagraph + compile (single GPU) | |
| torchrun --nnodes 1 --nproc_per_node 1 dump_wan_5b.py \ | |
| --model_id "${MODEL_ID}" \ | |
| --prompt "${PROMPT}" \ | |
| --height "${HEIGHT}" \ | |
| --width "${WIDTH}" \ | |
| --num_inference_steps "${NUM_INFERENCE_STEPS}" \ | |
| --guidance_scale "${GUIDANCE_SCALE}" \ | |
| --compile_mode default \ | |
| --enable_cudagraph \ | |
| --output_file "${OUTPUT_FILE}" | |
| # + context parallel + compile (ring=2) | |
| torchrun --nnodes 1 --nproc_per_node 2 dump_wan_5b.py \ | |
| --model_id "${MODEL_ID}" \ | |
| --prompt "${PROMPT}" \ | |
| --height "${HEIGHT}" \ | |
| --width "${WIDTH}" \ | |
| --num_inference_steps "${NUM_INFERENCE_STEPS}" \ | |
| --guidance_scale "${GUIDANCE_SCALE}" \ | |
| --attention_provider "${FA_OP}" \ | |
| --ring_degree 2 \ | |
| --compile_mode default \ | |
| --output_file "${OUTPUT_FILE}" | |
| # + context parallel + compile (ulysses=2) | |
| torchrun --nnodes 1 --nproc_per_node 2 dump_wan_5b.py \ | |
| --model_id "${MODEL_ID}" \ | |
| --prompt "${PROMPT}" \ | |
| --height "${HEIGHT}" \ | |
| --width "${WIDTH}" \ | |
| --num_inference_steps "${NUM_INFERENCE_STEPS}" \ | |
| --guidance_scale "${GUIDANCE_SCALE}" \ | |
| --attention_provider "${FA_OP}" \ | |
| --ulysses_degree 2 \ | |
| --compile_mode default \ | |
| --output_file "${OUTPUT_FILE}" | |
| # + context parallel + compile (ring=4) | |
| torchrun --nnodes 1 --nproc_per_node 4 dump_wan_5b.py \ | |
| --model_id "${MODEL_ID}" \ | |
| --prompt "${PROMPT}" \ | |
| --height "${HEIGHT}" \ | |
| --width "${WIDTH}" \ | |
| --num_inference_steps "${NUM_INFERENCE_STEPS}" \ | |
| --guidance_scale "${GUIDANCE_SCALE}" \ | |
| --attention_provider "${FA_OP}" \ | |
| --ring_degree 4 \ | |
| --compile_mode default \ | |
| --output_file "${OUTPUT_FILE}" | |
| # + context parallel + compile (ulysses=4) | |
| torchrun --nnodes 1 --nproc_per_node 4 dump_wan_5b.py \ | |
| --model_id "${MODEL_ID}" \ | |
| --prompt "${PROMPT}" \ | |
| --height "${HEIGHT}" \ | |
| --width "${WIDTH}" \ | |
| --num_inference_steps "${NUM_INFERENCE_STEPS}" \ | |
| --guidance_scale "${GUIDANCE_SCALE}" \ | |
| --attention_provider "${FA_OP}" \ | |
| --ulysses_degree 4 \ | |
| --compile_mode default \ | |
| --output_file "${OUTPUT_FILE}" |
| import argparse | |
| import contextlib | |
| import math | |
| from dataclasses import dataclass | |
| from typing import Callable, Literal, Optional, Tuple | |
| import torch | |
| import torch.distributed as dist | |
| import torch.distributed._functional_collectives as funcol | |
| import torch.profiler._utils | |
| import torch._dynamo.config | |
| import torch._inductor.config | |
| import torch._higher_order_ops.auto_functionalize as af | |
| from torch.profiler import profile, record_function, ProfilerActivity | |
| from diffusers import AutoencoderKLWan, UniPCMultistepScheduler | |
| from diffusers.configuration_utils import ConfigMixin, register_to_config | |
| from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin | |
| from diffusers.models.attention import FeedForward | |
| from diffusers.models.cache_utils import CacheMixin | |
| from diffusers.models.embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed | |
| from diffusers.models.modeling_utils import ModelMixin | |
| from diffusers.models.normalization import FP32LayerNorm | |
| from diffusers.video_processor import VideoProcessor | |
| from diffusers.utils import export_to_video | |
| from kernels import get_kernel | |
| from transformers import T5EncoderModel, T5TokenizerFast | |
| try: | |
| from flash_attn import flash_attn_func | |
| except ImportError: | |
| print("Flash Attention 2 not found.") | |
| try: | |
| from flash_attn_interface import flash_attn_func as flash_attn_3_func | |
| except ImportError: | |
| print("Flash Attention 3 not found.") | |
| ATTENTION_OP = None | |
| @dataclass | |
| class ContextParallelOptions: | |
| ring_degree: int | None = None | |
| ulysses_degree: int | None = None | |
| mode: Literal["ring", "ulysses", "unified"] = "ring" | |
| mesh: dist.DeviceMesh | None = None | |
| convert_to_fp32: bool = True | |
| attention_op: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]] | None = None | |
| _flattened_mesh: dist.DeviceMesh | None = None | |
| _ring_mesh: dist.DeviceMesh | None = None | |
| _ulysses_mesh: dist.DeviceMesh | None = None | |
| _ring_local_rank: int | None = None | |
| _ulysses_local_rank: int | None = None | |
| cp_options = ContextParallelOptions() | |
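| # Module-level context-parallel state; populated by setup_distributed() and read by the sharding/attention helpers below. | |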
| def apply_flags(): | |
| torch._dynamo.config.inline_inbuilt_nn_modules = False | |
| torch._dynamo.config.cache_size_limit = 128 | |
| torch._inductor.config.conv_1x1_as_mm = True | |
| torch._inductor.config.coordinate_descent_check_all_directions = True | |
| torch._inductor.config.coordinate_descent_tuning = True | |
| torch._inductor.config.disable_progress = False | |
| torch._inductor.config.fx_graph_cache = True | |
| torch._inductor.config.epilogue_fusion = False | |
| torch._inductor.config.aggressive_fusion = True | |
| # torch._inductor.config.reorder_for_compute_comm_overlap = True | |
| torch._inductor.config.shape_padding = True | |
| torch._inductor.config.triton.enable_persistent_tma_matmul = True | |
| torch.backends.cuda.matmul.allow_tf32 = True | |
| torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True | |
| torch.backends.cudnn.allow_tf32 = True | |
| af.auto_functionalized_v2._cacheable = True | |
| af.auto_functionalized._cacheable = True | |
| class WanAttention(torch.nn.Module): | |
| def __init__( | |
| self, | |
| dim: int, | |
| heads: int = 8, | |
| dim_head: int = 64, | |
| eps: float = 1e-5, | |
| dropout: float = 0.0, | |
| added_kv_proj_dim: Optional[int] = None, | |
| is_cross_attention: bool = False, | |
| ): | |
| super().__init__() | |
| self.inner_dim = dim_head * heads | |
| self.heads = heads | |
| self.added_kv_proj_dim = added_kv_proj_dim | |
| self.is_cross_attention = is_cross_attention | |
| self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=True) | |
| self.to_k = torch.nn.Linear(dim, self.inner_dim, bias=True) | |
| self.to_v = torch.nn.Linear(dim, self.inner_dim, bias=True) | |
| self.to_out = torch.nn.ModuleList([ | |
| torch.nn.Linear(self.inner_dim, self.inner_dim, bias=True), | |
| torch.nn.Dropout(dropout) | |
| ]) | |
| self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True) | |
| self.norm_k = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True) | |
| self.add_k_proj = self.add_v_proj = None | |
| if added_kv_proj_dim is not None: | |
| self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True) | |
| self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True) | |
| self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps) | |
| self.fused_projections = False | |
| def forward( | |
| self, | |
| hidden_states: torch.Tensor, | |
| encoder_hidden_states: Optional[torch.Tensor], | |
| rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, | |
| ) -> torch.Tensor: | |
| encoder_hidden_states_img = None | |
| if self.add_k_proj is not None: | |
| image_context_length = encoder_hidden_states.shape[1] - 512 | |
| encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length] | |
| encoder_hidden_states = encoder_hidden_states[:, image_context_length:] | |
| query, key, value = self._get_qkv_projections(hidden_states, encoder_hidden_states) | |
| query = self.norm_q(query) | |
| key = self.norm_k(key) | |
| query, key, value = (x.unflatten(2, (self.heads, -1)) for x in (query, key, value)) | |
| if rotary_emb is not None: | |
| def apply_rotary_emb(hidden_states: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor): | |
| x = hidden_states.unflatten(-1, (-1, 2)) | |
| x1, x2 = x[..., 0], x[..., 1] | |
| cos = freqs_cos[..., 0::2] | |
| sin = freqs_sin[..., 1::2] | |
| out = torch.empty_like(hidden_states) | |
| out[..., 0::2] = x1 * cos - x2 * sin | |
| out[..., 1::2] = x1 * sin + x2 * cos | |
| return out.type_as(hidden_states) | |
| query = apply_rotary_emb(query, rotary_emb[0], rotary_emb[1]) | |
| key = apply_rotary_emb(key, rotary_emb[0], rotary_emb[1]) | |
| hidden_states_img = None | |
| if encoder_hidden_states_img is not None: | |
| key_img, value_img = self._get_added_kv_projections(encoder_hidden_states_img) | |
| key_img = self.norm_added_k(key_img) | |
| key_img, value_img = (x.unflatten(2, (self.heads, -1)) for x in (key_img, value_img)) | |
| hidden_states_img, _ = ATTENTION_OP(query, key_img, value_img) | |
| hidden_states_img = hidden_states_img.flatten(2, 3).type_as(query) | |
| hidden_states, _ = ATTENTION_OP(query, key, value) | |
| hidden_states = hidden_states.flatten(2, 3).type_as(query) | |
| if hidden_states_img is not None: | |
| hidden_states = hidden_states + hidden_states_img | |
| hidden_states = self.to_out[0](hidden_states) | |
| return hidden_states | |
| @torch.no_grad() | |
| def fuse_projections(self): | |
| if not self.is_cross_attention: | |
| concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data]) | |
| concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data]) | |
| out_features, in_features = concatenated_weights.shape | |
| with torch.device("meta"): | |
| self.to_qkv = torch.nn.Linear(in_features, out_features, bias=True) | |
| self.to_qkv.load_state_dict( | |
| {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True | |
| ) | |
| else: | |
| concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data]) | |
| concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data]) | |
| out_features, in_features = concatenated_weights.shape | |
| with torch.device("meta"): | |
| self.to_kv = torch.nn.Linear(in_features, out_features, bias=True) | |
| self.to_kv.load_state_dict( | |
| {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True | |
| ) | |
| if self.added_kv_proj_dim is not None: | |
| concatenated_weights = torch.cat([self.add_k_proj.weight.data, self.add_v_proj.weight.data]) | |
| concatenated_bias = torch.cat([self.add_k_proj.bias.data, self.add_v_proj.bias.data]) | |
| out_features, in_features = concatenated_weights.shape | |
| with torch.device("meta"): | |
| self.to_added_kv = torch.nn.Linear(in_features, out_features, bias=True) | |
| self.to_added_kv.load_state_dict( | |
| {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True | |
| ) | |
| self.fused_projections = True | |
| def _get_qkv_projections(self, hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor]): | |
| if self.fused_projections: | |
| if not self.is_cross_attention: | |
| query, key, value = self.to_qkv(hidden_states).chunk(3, dim=-1) | |
| else: | |
| query = self.to_q(hidden_states) | |
| key, value = self.to_kv(encoder_hidden_states).chunk(2, dim=-1) | |
| else: | |
| query = self.to_q(hidden_states) | |
| if self.added_kv_proj_dim is None: | |
| key = self.to_k(hidden_states) | |
| value = self.to_v(hidden_states) | |
| else: | |
| key = self.to_k(encoder_hidden_states) | |
| value = self.to_v(encoder_hidden_states) | |
| return query, key, value | |
| def _get_added_kv_projections(self, encoder_hidden_states_img: torch.Tensor): | |
| if self.fused_projections: | |
| key_img, value_img = self.to_added_kv(encoder_hidden_states_img).chunk(2, dim=-1) | |
| else: | |
| key_img = self.add_k_proj(encoder_hidden_states_img) | |
| value_img = self.add_v_proj(encoder_hidden_states_img) | |
| return key_img, value_img | |
| class WanImageEmbedding(torch.nn.Module): | |
| def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None): | |
| super().__init__() | |
| self.norm1 = FP32LayerNorm(in_features) | |
| self.ff = FeedForward(in_features, out_features, mult=1, activation_fn="gelu") | |
| self.norm2 = FP32LayerNorm(out_features) | |
| if pos_embed_seq_len is not None: | |
| self.pos_embed = torch.nn.Parameter(torch.zeros(1, pos_embed_seq_len, in_features)) | |
| else: | |
| self.pos_embed = None | |
| def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor: | |
| if self.pos_embed is not None: | |
| batch_size, seq_len, embed_dim = encoder_hidden_states_image.shape | |
| encoder_hidden_states_image = encoder_hidden_states_image.view(-1, 2 * seq_len, embed_dim) | |
| encoder_hidden_states_image = encoder_hidden_states_image + self.pos_embed | |
| hidden_states = self.norm1(encoder_hidden_states_image) | |
| hidden_states = self.ff(hidden_states) | |
| hidden_states = self.norm2(hidden_states) | |
| return hidden_states | |
| class WanTimeTextImageEmbedding(torch.nn.Module): | |
| def __init__( | |
| self, | |
| dim: int, | |
| time_freq_dim: int, | |
| time_proj_dim: int, | |
| text_embed_dim: int, | |
| image_embed_dim: Optional[int] = None, | |
| pos_embed_seq_len: Optional[int] = None, | |
| ): | |
| super().__init__() | |
| self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0) | |
| self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim) | |
| self.act_fn = torch.nn.SiLU() | |
| self.time_proj = torch.nn.Linear(dim, time_proj_dim) | |
| self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh") | |
| self.image_embedder = None | |
| if image_embed_dim is not None: | |
| self.image_embedder = WanImageEmbedding(image_embed_dim, dim, pos_embed_seq_len=pos_embed_seq_len) | |
| def forward( | |
| self, | |
| timestep: torch.Tensor, | |
| encoder_hidden_states: torch.Tensor, | |
| encoder_hidden_states_image: Optional[torch.Tensor] = None, | |
| ): | |
| timestep = self.timesteps_proj(timestep).type_as(encoder_hidden_states) | |
| temb = self.time_embedder(timestep) | |
| timestep_proj = self.time_proj(self.act_fn(temb)) | |
| encoder_hidden_states = self.text_embedder(encoder_hidden_states) | |
| if encoder_hidden_states_image is not None: | |
| encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image) | |
| return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image | |
| class WanRotaryPosEmbed(torch.nn.Module): | |
| def __init__( | |
| self, | |
| attention_head_dim: int, | |
| patch_size: Tuple[int, int, int], | |
| max_seq_len: int, | |
| theta: float = 10000.0, | |
| ): | |
| super().__init__() | |
| self.attention_head_dim = attention_head_dim | |
| self.patch_size = patch_size | |
| self.max_seq_len = max_seq_len | |
| h_dim = w_dim = 2 * (attention_head_dim // 6) | |
| t_dim = attention_head_dim - h_dim - w_dim | |
| freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64 | |
| freqs_cos = [] | |
| freqs_sin = [] | |
| for dim in [t_dim, h_dim, w_dim]: | |
| freq_cos, freq_sin = get_1d_rotary_pos_embed( | |
| dim, | |
| max_seq_len, | |
| theta, | |
| use_real=True, | |
| repeat_interleave_real=True, | |
| freqs_dtype=freqs_dtype, | |
| ) | |
| freqs_cos.append(freq_cos) | |
| freqs_sin.append(freq_sin) | |
| self.register_buffer("freqs_cos", torch.cat(freqs_cos, dim=1), persistent=False) | |
| self.register_buffer("freqs_sin", torch.cat(freqs_sin, dim=1), persistent=False) | |
| def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: | |
| batch_size, num_channels, num_frames, height, width = hidden_states.shape | |
| p_t, p_h, p_w = self.patch_size | |
| ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w | |
| split_sizes = [ | |
| self.attention_head_dim - 2 * (self.attention_head_dim // 3), | |
| self.attention_head_dim // 3, | |
| self.attention_head_dim // 3, | |
| ] | |
| freqs_cos = self.freqs_cos.split(split_sizes, dim=1) | |
| freqs_sin = self.freqs_sin.split(split_sizes, dim=1) | |
| freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1) | |
| freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1) | |
| freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1) | |
| freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1) | |
| freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1) | |
| freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1) | |
| freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1) | |
| freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1) | |
| if cp_options.mesh is not None: | |
| freqs_cos = EquipartitionSharder.shard(freqs_cos, dim=1, mesh=cp_options._flattened_mesh) | |
| freqs_sin = EquipartitionSharder.shard(freqs_sin, dim=1, mesh=cp_options._flattened_mesh) | |
| return freqs_cos, freqs_sin | |
| class WanTransformerBlock(torch.nn.Module): | |
| def __init__( | |
| self, | |
| dim: int, | |
| ffn_dim: int, | |
| num_heads: int, | |
| qk_norm: str = "rms_norm_across_heads", | |
| cross_attn_norm: bool = False, | |
| eps: float = 1e-6, | |
| added_kv_proj_dim: Optional[int] = None, | |
| ): | |
| super().__init__() | |
| self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False) | |
| self.attn1 = WanAttention( | |
| dim=dim, heads=num_heads, dim_head=dim // num_heads, eps=eps, is_cross_attention=False, | |
| ) | |
| self.attn2 = WanAttention( | |
| dim=dim, heads=num_heads, dim_head=dim // num_heads, eps=eps, added_kv_proj_dim=added_kv_proj_dim, is_cross_attention=True | |
| ) | |
| self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else torch.nn.Identity() | |
| self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate") | |
| self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False) | |
| self.scale_shift_table = torch.nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) | |
| def forward( | |
| self, | |
| hidden_states: torch.Tensor, | |
| encoder_hidden_states: torch.Tensor, | |
| temb: torch.Tensor, | |
| rotary_emb: torch.Tensor, | |
| ) -> torch.Tensor: | |
| shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = ( | |
| self.scale_shift_table + temb.float() | |
| ).chunk(6, dim=1) | |
| scale_msa.add_(1) | |
| c_scale_msa.add_(1) | |
| norm_hidden_states = torch.addcmul(shift_msa, self.norm1(hidden_states.float()), scale_msa).type_as(hidden_states) | |
| attn_output = self.attn1(norm_hidden_states, None, rotary_emb) | |
| hidden_states = torch.addcmul(hidden_states.float(), attn_output, gate_msa).type_as(hidden_states) | |
| norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states) | |
| attn_output = self.attn2(norm_hidden_states, encoder_hidden_states, None) | |
| hidden_states.add_(attn_output) | |
| norm_hidden_states = torch.addcmul(c_shift_msa, self.norm3(hidden_states.float()), c_scale_msa).type_as(hidden_states) | |
| ff_output = self.ffn(norm_hidden_states) | |
| hidden_states = torch.addcmul(hidden_states.float(), ff_output, c_gate_msa).type_as(hidden_states) | |
| return hidden_states | |
| class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin): | |
| _keys_to_ignore_on_load_unexpected = ["norm_added_q"] | |
| @register_to_config | |
| def __init__( | |
| self, | |
| patch_size: Tuple[int] = (1, 2, 2), | |
| num_attention_heads: int = 40, | |
| attention_head_dim: int = 128, | |
| in_channels: int = 16, | |
| out_channels: int = 16, | |
| text_dim: int = 4096, | |
| freq_dim: int = 256, | |
| ffn_dim: int = 13824, | |
| num_layers: int = 40, | |
| cross_attn_norm: bool = True, | |
| qk_norm: Optional[str] = "rms_norm_across_heads", | |
| eps: float = 1e-6, | |
| image_dim: Optional[int] = None, | |
| added_kv_proj_dim: Optional[int] = None, | |
| rope_max_seq_len: int = 1024, | |
| pos_embed_seq_len: Optional[int] = None, | |
| ) -> None: | |
| super().__init__() | |
| inner_dim = num_attention_heads * attention_head_dim | |
| out_channels = out_channels or in_channels | |
| self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len) | |
| self.patch_embedding = torch.nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size) | |
| self.condition_embedder = WanTimeTextImageEmbedding( | |
| dim=inner_dim, | |
| time_freq_dim=freq_dim, | |
| time_proj_dim=inner_dim * 6, | |
| text_embed_dim=text_dim, | |
| image_embed_dim=image_dim, | |
| pos_embed_seq_len=pos_embed_seq_len, | |
| ) | |
| self.blocks = torch.nn.ModuleList( | |
| [ | |
| WanTransformerBlock( | |
| inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim | |
| ) | |
| for _ in range(num_layers) | |
| ] | |
| ) | |
| self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False) | |
| self.proj_out = torch.nn.Linear(inner_dim, out_channels * math.prod(patch_size)) | |
| self.scale_shift_table = torch.nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5) | |
| def forward( | |
| self, | |
| hidden_states: torch.Tensor, | |
| timestep: torch.LongTensor, | |
| encoder_hidden_states: torch.Tensor, | |
| encoder_hidden_states_image: Optional[torch.Tensor] = None, | |
| rotary_emb: Tuple[torch.Tensor, torch.Tensor] = None, | |
| ) -> torch.Tensor: | |
| batch_size, num_channels, num_frames, height, width = hidden_states.shape | |
| p_t, p_h, p_w = self.config.patch_size | |
| post_patch_num_frames = num_frames // p_t | |
| post_patch_height = height // p_h | |
| post_patch_width = width // p_w | |
| hidden_states = self.patch_embedding(hidden_states) | |
| hidden_states = hidden_states.flatten(2).transpose(1, 2) | |
| temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder( | |
| timestep, encoder_hidden_states, encoder_hidden_states_image | |
| ) | |
| timestep_proj = timestep_proj.unflatten(1, (6, -1)) | |
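| # timestep_proj now holds the 6 per-block modulation parameters (shift/scale/gate for the self-attention branch and for the FFN branch). | |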
| if cp_options.mesh is not None: | |
| hidden_states = EquipartitionSharder.shard(hidden_states, dim=1, mesh=cp_options._flattened_mesh) | |
| encoder_hidden_states = EquipartitionSharder.shard(encoder_hidden_states, dim=1, mesh=cp_options._flattened_mesh) | |
| if encoder_hidden_states_image is not None: | |
| if cp_options.mesh is not None: | |
| encoder_hidden_states_image = EquipartitionSharder.shard(encoder_hidden_states_image, dim=1, mesh=cp_options._flattened_mesh) | |
| encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1) | |
| for block in self.blocks: | |
| hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) | |
| shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1) | |
| scale.add_(1) | |
| hidden_states = torch.addcmul(shift, self.norm_out(hidden_states.float()), scale).type_as(hidden_states) | |
| hidden_states = self.proj_out(hidden_states) | |
| if cp_options.mesh is not None: | |
| hidden_states = EquipartitionSharder.unshard(hidden_states, dim=1, mesh=cp_options._flattened_mesh) | |
| hidden_states = hidden_states.reshape( | |
| batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1 | |
| ) | |
| hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6) | |
| output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) | |
| return output | |
| def fuse_qkv_(model: WanTransformer3DModel) -> None: | |
| for submodule in model.modules(): | |
| if not isinstance(submodule, WanAttention): | |
| continue | |
| submodule.fuse_projections() | |
| def prepare_t5_embeddings( | |
| text_encoder: T5EncoderModel, | |
| tokenizer: T5TokenizerFast, | |
| prompt: str, | |
| device: torch.device, | |
| dtype: torch.dtype, | |
| max_length: int = 512, | |
| ) -> torch.Tensor: | |
| prompt = [prompt] | |
| text_inputs = tokenizer( | |
| prompt, | |
| padding="max_length", | |
| max_length=max_length, | |
| truncation=True, | |
| add_special_tokens=True, | |
| return_attention_mask=True, | |
| return_tensors="pt", | |
| ) | |
| text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask | |
| seq_lens = mask.gt(0).sum(dim=1).long() | |
| prompt_embeds = text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state | |
| prompt_embeds = prompt_embeds.to(dtype) | |
| prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] | |
| prompt_embeds = torch.stack( | |
| [torch.cat([u, u.new_zeros(max_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 | |
| ) | |
| return prompt_embeds | |
| @torch.inference_mode() | |
| def capture_cudagraph( | |
| model: WanTransformer3DModel, | |
| latents: torch.Tensor, | |
| encoder_hidden_states: torch.Tensor, | |
| timestep: torch.Tensor, | |
| rotary_emb: Tuple[torch.Tensor, torch.Tensor], | |
| ): | |
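| # CUDA graphs replay a fixed sequence of kernels on fixed memory addresses, so the graph is captured once | |
| # on static input tensors; later steps only copy new values into those tensors and call graph.replay(). | |
| # The warmup iterations on a side stream let lazy initialization (cuBLAS/cuDNN handles, autotuning) happen | |
| # outside of capture. | |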
| print("Warming up CUDAGraph capture") | |
| s = torch.cuda.Stream() | |
| s.wait_stream(torch.cuda.current_stream()) | |
| with torch.cuda.stream(s): | |
| for _ in range(2): | |
| _ = model( | |
| hidden_states=latents, | |
| encoder_hidden_states=encoder_hidden_states, | |
| timestep=timestep, | |
| rotary_emb=rotary_emb, | |
| ) | |
| torch.cuda.current_stream().wait_stream(s) | |
| print("Capturing CUDAGraph") | |
| static_latents = latents.clone() | |
| static_timestep = timestep.clone() | |
| graph = torch.cuda.CUDAGraph() | |
| with torch.cuda.graph(graph): | |
| static_x = model( | |
| hidden_states=static_latents, | |
| encoder_hidden_states=encoder_hidden_states, | |
| timestep=static_timestep, | |
| rotary_emb=rotary_emb, | |
| ) | |
| return graph, static_latents, static_timestep, static_x | |
| @torch.inference_mode() | |
| def main( | |
| model_id: str, | |
| prompt: str, | |
| negative_prompt: str, | |
| height: int, | |
| width: int, | |
| num_frames: int, | |
| num_inference_steps: int, | |
| guidance_scale: float, | |
| compile_mode: str, | |
| output_file: str, | |
| cache_dir: Optional[str], | |
| enable_cudagraph: bool, | |
| enable_profiling: bool, | |
| seed: int, | |
| ): | |
| # Constants/defaults and device/dtype preparation | |
| device = "cuda" | |
| dtype = torch.bfloat16 | |
| SPATIAL_COMPRESSION_RATIO = 16 | |
| TEMPORAL_COMPRESSION_RATIO = 4 | |
| T5_SEQUENCE_LENGTH = 512 | |
| # Load the model components | |
| transformer = WanTransformer3DModel.from_pretrained( | |
| model_id, subfolder="transformer", cache_dir=cache_dir, torch_dtype=dtype | |
| ) | |
| text_encoder = T5EncoderModel.from_pretrained( | |
| model_id, subfolder="text_encoder", cache_dir=cache_dir, torch_dtype=dtype | |
| ) | |
| tokenizer = T5TokenizerFast.from_pretrained(model_id, subfolder="tokenizer", cache_dir=cache_dir) | |
| vae = AutoencoderKLWan.from_pretrained( | |
| model_id, subfolder="vae", cache_dir=cache_dir, torch_dtype=dtype | |
| ) | |
| scheduler = UniPCMultistepScheduler.from_pretrained(model_id, subfolder="scheduler", cache_dir=cache_dir) | |
| video_processor = VideoProcessor(vae_scale_factor=SPATIAL_COMPRESSION_RATIO) | |
| fuse_qkv_(transformer) | |
| for module in (transformer, text_encoder, vae): | |
| module.to(device, dtype=dtype) | |
| if compile_mode != "none": | |
| transformer = torch.compile(transformer, mode=compile_mode, fullgraph=True, dynamic=True) | |
| # text_encoder = torch.compile(text_encoder, mode="default", fullgraph=True, dynamic=True) | |
| # We don't compile the VAE due to the implementation calling into non-traceable code paths | |
| # vae.decode = torch.compile(vae.decode, mode="default", fullgraph=True, dynamic=True) | |
| scheduler.set_timesteps = torch.compile(scheduler.set_timesteps, fullgraph=True, dynamic=False) | |
| # Fails due to CPU tensor usage somewhere in the scheduler.step implementation, which Triton does not support | |
| # scheduler.step = torch.compile(scheduler.step, fullgraph=True, dynamic=False) | |
| # Latent, text and timestep conditioning preparation | |
| batch_size = 1 | |
| latent_height = height // SPATIAL_COMPRESSION_RATIO | |
| latent_width = width // SPATIAL_COMPRESSION_RATIO | |
| latent_num_frames = (num_frames - 1) // TEMPORAL_COMPRESSION_RATIO + 1 | |
| # print("num tokens:", latent_num_frames * latent_height * latent_width // 4) | |
| generator = torch.Generator(device=device).manual_seed(seed) | |
| latents = torch.randn( | |
| (batch_size, transformer.config.in_channels, latent_num_frames, latent_height, latent_width), | |
| dtype=dtype, | |
| device=device, | |
| generator=generator, | |
| ) | |
| prompt_embeds = prepare_t5_embeddings(text_encoder, tokenizer, prompt, device, dtype, T5_SEQUENCE_LENGTH) | |
| negative_prompt_embeds = prepare_t5_embeddings(text_encoder, tokenizer, negative_prompt, device, dtype, T5_SEQUENCE_LENGTH) | |
| encoder_hidden_states = torch.cat([prompt_embeds, negative_prompt_embeds], dim=0) | |
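| # Classifier-free guidance: the conditional and unconditional prompts are stacked into a single batch of 2, | |
| # mirrored by torch.cat([latents, latents]) below, so one transformer call produces both v_cond and v_uncond. | |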
| scheduler.set_timesteps(num_inference_steps, device=device) | |
| timesteps = scheduler.timesteps.to(dtype=dtype) | |
| print("Warmup step...") | |
| for _ in range(2): | |
| latent_model_input = torch.cat([latents, latents], dim=0) | |
| rotary_emb = transformer.rope(latent_model_input) | |
| _ = transformer( | |
| hidden_states=latent_model_input, | |
| timestep=timesteps[0].expand(latent_model_input.size(0)).clone(), | |
| encoder_hidden_states=encoder_hidden_states, | |
| rotary_emb=rotary_emb, | |
| ) | |
| torch.cuda.synchronize() | |
| start_event = torch.cuda.Event(enable_timing=True) | |
| end_event = torch.cuda.Event(enable_timing=True) | |
| context = ( | |
| profile(activities=[ProfilerActivity.CUDA, ProfilerActivity.CPU], with_flops=True) | |
| if enable_profiling | |
| else contextlib.nullcontext() | |
| ) | |
| if not enable_cudagraph: | |
| scheduler.set_begin_index(0) | |
| with context as ctx: | |
| start_event.record() | |
| rotary_emb = None | |
| for i in range(num_inference_steps): | |
| print("step:", i + 1, "/", num_inference_steps) | |
| latent_model_input = torch.cat([latents, latents], dim=0) | |
| if rotary_emb is None: | |
| rotary_emb = transformer.rope(latent_model_input) | |
| v_cond_uncond = transformer( | |
| hidden_states=latent_model_input, | |
| timestep=timesteps[i].expand(latent_model_input.size(0)).clone(), | |
| encoder_hidden_states=encoder_hidden_states, | |
| rotary_emb=rotary_emb, | |
| ) | |
| v_cond, v_uncond = v_cond_uncond.chunk(2, dim=0) | |
| v_pred = v_uncond + guidance_scale * (v_cond - v_uncond) | |
| latents = scheduler.step(v_pred, timesteps[i], latents).prev_sample | |
| end_event.record() | |
| torch.cuda.synchronize() | |
| else: | |
| latent_model_input = torch.cat([latents, latents], dim=0) | |
| rotary_emb = transformer.rope(latent_model_input) | |
| graph, static_latents, static_timestep, static_x = capture_cudagraph( | |
| transformer, latent_model_input, encoder_hidden_states, timesteps[0].expand(latent_model_input.size(0)).clone(), rotary_emb | |
| ) | |
| scheduler.set_begin_index(0) | |
| with context as ctx: | |
| start_event.record() | |
| static_x.copy_(static_latents) | |
| for i in range(num_inference_steps): | |
| torch._foreach_copy_( | |
| (static_latents, static_timestep), | |
| (static_x, timesteps[i].expand(static_latents.size(0)).clone()), | |
| non_blocking=True, | |
| ) | |
| graph.replay() | |
| v_cond, v_uncond = static_x.chunk(2, dim=0) | |
| v_pred = v_uncond + guidance_scale * (v_cond - v_uncond) | |
| latents = scheduler.step(v_pred, timesteps[i], static_latents).prev_sample | |
| latent_model_input = torch.cat([latents, latents], dim=0) | |
| static_x.copy_(latents) | |
| end_event.record() | |
| torch.cuda.synchronize() | |
| latents = static_x | |
| total_time = start_event.elapsed_time(end_event) / 1000.0 | |
| if enable_profiling: | |
| print(ctx.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1)) | |
| ctx.export_chrome_trace("dump_benchmark_wan.json") | |
| if cp_options.mesh is None or cp_options._flattened_mesh.get_local_rank() == 0: | |
| print(f"time: {total_time:.2f}s") | |
| latents_mean = torch.tensor(vae.config.latents_mean, dtype=dtype, device=device).view(1, vae.config.z_dim, 1, 1, 1) | |
| latents_std = 1.0 / torch.tensor(vae.config.latents_std, dtype=dtype, device=device).view(1, vae.config.z_dim, 1, 1, 1) | |
| latents = latents / latents_std + latents_mean | |
| video = vae.decode(latents, return_dict=False)[0] | |
| video = video_processor.postprocess_video(video, output_type="pil")[0] | |
| export_to_video(video, output_file, fps=16) | |
| class EquipartitionSharder: | |
| @classmethod | |
| def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor: | |
| assert tensor.size()[dim] % mesh.size() == 0 | |
| # The following is not fullgraph compatible with Dynamo (fails in DeviceMesh.get_rank) | |
| # return tensor.chunk(mesh.size(), dim=dim)[mesh.get_rank()] | |
| return tensor.chunk(mesh.size(), dim=dim)[torch.distributed.get_rank(mesh.get_group())] | |
| @classmethod | |
| def unshard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor: | |
| tensor = tensor.contiguous() | |
| tensor = funcol.all_gather_tensor(tensor, dim, group=mesh.get_group()) | |
| return tensor | |
| # Reference: | |
| # - https://github.com/pytorch/pytorch/blob/f58a680d09e13658a52c6ba05c63c15759846bcc/torch/distributed/_functional_collectives.py#L827 | |
| # - https://github.com/pytorch/pytorch/blob/f58a680d09e13658a52c6ba05c63c15759846bcc/torch/distributed/_functional_collectives.py#L246 | |
| # For fullgraph=True tracing compatibility (since FakeTensor does not have a `wait` method): | |
| def _wait_tensor(tensor): | |
| if isinstance(tensor, funcol.AsyncCollectiveTensor): | |
| tensor = tensor.wait() | |
| return tensor | |
| def _all_to_all_single(x: torch.Tensor, group) -> torch.Tensor: | |
| shape = x.shape | |
| # HACK: we flatten to 1D because, despite making the tensors contiguous, torch's single-file-ization | |
| # for benchmarking the triton codegen fails somewhere with: | |
| # buf25 = torch.ops._c10d_functional.all_to_all_single.default(buf24, [1, 1], [1, 1], '3') | |
| # ValueError: Tensors must be contiguous | |
| x = x.flatten() | |
| x = funcol.all_to_all_single(x, None, None, group) | |
| x = x.reshape(shape) | |
| x = _wait_tensor(x) | |
| return x | |
| def _templated_ring_attention(query, key, value): | |
| ring_mesh = cp_options.mesh["ring"] | |
| rank = cp_options._ring_local_rank | |
| world_size = cp_options.ring_degree | |
| if world_size == 1: | |
| return cp_options.attention_op(query, key, value) | |
| next_rank = (rank + 1) % world_size | |
| prev_out = prev_lse = None | |
| kv_buffer = torch.cat([key.flatten(), value.flatten()]).contiguous() | |
| kv_buffer = funcol.all_gather_tensor(kv_buffer, gather_dim=0, group=ring_mesh.get_group()) | |
| kv_buffer = kv_buffer.chunk(world_size) | |
| for i in range(world_size): | |
| if i > 0: | |
| kv = kv_buffer[next_rank] | |
| key = kv[:key.numel()].reshape_as(key) | |
| value = kv[key.numel():].reshape_as(value) | |
| next_rank = (next_rank + 1) % world_size | |
| out, lse = cp_options.attention_op(query, key, value) | |
| if cp_options.convert_to_fp32: | |
| out = out.to(torch.float32) | |
| lse = lse.to(torch.float32) | |
| lse = lse.unsqueeze(-1) | |
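| # Online-softmax merge of per-chunk partial attention: with the running (prev_out, prev_lse) and the | |
| # current chunk's (out, lse), the exact combination is | |
| # merged_out = (exp(prev_lse) * prev_out + exp(lse) * out) / (exp(prev_lse) + exp(lse)) | |
| # merged_lse = log(exp(prev_lse) + exp(lse)) | |
| # which the sigmoid/logsigmoid forms below compute in a numerically stable way. | |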
| if prev_out is not None: | |
| out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out) | |
| lse = prev_lse - torch.nn.functional.logsigmoid(prev_lse - lse) | |
| prev_out = out | |
| prev_lse = lse | |
| out = out.to(query.dtype) | |
| lse = lse.squeeze(-1) | |
| return out, lse | |
| def _templated_ulysses_attention(query, key, value, *, return_lse: bool = False): | |
| ulysses_mesh = cp_options.mesh["ulysses"] | |
| world_size = cp_options.ulysses_degree | |
| group = ulysses_mesh.get_group() | |
| if world_size == 1: | |
| return cp_options.attention_op(query, key, value) | |
| B, S_Q_LOCAL, H, D = query.shape | |
| _, S_KV_LOCAL, _, _ = key.shape | |
| H_LOCAL = H // world_size | |
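| # Ulysses: each rank starts with a sequence shard and all H heads. The permute + all_to_all below | |
| # redistributes so that each rank holds the full sequence for its H_LOCAL heads, runs regular attention, | |
| # and then the inverse all_to_all restores the sequence-sharded, all-heads layout. | |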
| query = query.reshape(B, S_Q_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous() | |
| key = key.reshape(B, S_KV_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous() | |
| value = value.reshape(B, S_KV_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous() | |
| query, key, value = ( | |
| _all_to_all_single(x, group) | |
| for x in (query, key, value) | |
| ) | |
| query, key, value = ( | |
| x.flatten(0, 1).permute(1, 0, 2, 3).contiguous() | |
| for x in (query, key, value) | |
| ) | |
| out, lse = cp_options.attention_op(query, key, value) | |
| out = out.reshape(B, world_size, S_Q_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous() | |
| out = _all_to_all_single(out, group) | |
| out = out.flatten(0, 1).permute(1, 2, 0, 3).contiguous() | |
| if return_lse: | |
| lse = lse.reshape(B, world_size, S_Q_LOCAL, H_LOCAL).permute(1, 3, 0, 2).contiguous() | |
| lse = _all_to_all_single(lse, group) | |
| lse = lse.flatten(0, 1).permute(1, 2, 0).contiguous() | |
| else: | |
| lse = None | |
| return out, lse | |
| # TODO: currently produces incorrect results (for example, with CP=4 as ring=2 x ulysses=2, half of the | |
| # output is the expected image and the other half is something completely different) | |
| # def _templated_unified_attention(query, key, value): | |
| # ulysses_mesh = cp_options.mesh["ulysses"] | |
| # ulysses_size = ulysses_mesh.size() | |
| # ulysses_group = ulysses_mesh.get_group() | |
| # B, S_LOCAL, H, D = query.shape | |
| # H_LOCAL = H // ulysses_size | |
| # query, key, value = ( | |
| # x.reshape(B, S_LOCAL, ulysses_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous() | |
| # for x in (query, key, value) | |
| # ) | |
| # query, key, value = ( | |
| # wait_tensor(funcol.all_to_all_single(x, None, None, group=ulysses_group)) | |
| # for x in (query, key, value) | |
| # ) | |
| # query, key, value = ( | |
| # x.flatten(0, 1).permute(1, 0, 2, 3).contiguous() | |
| # for x in (query, key, value) | |
| # ) | |
| # out, lse = _templated_ring_attention(query, key, value) | |
| # out = out.reshape(B, ulysses_size, S_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous() | |
| # lse = lse.reshape(B, ulysses_size, S_LOCAL, H_LOCAL).permute(1, 3, 0, 2).contiguous() | |
| # out = wait_tensor(funcol.all_to_all_single(out, None, None, group=ulysses_group)) | |
| # lse = wait_tensor(funcol.all_to_all_single(lse, None, None, group=ulysses_group)) | |
| # out = out.flatten(0, 1).permute(1, 2, 0, 3).contiguous() | |
| # lse = lse.flatten(0, 1).permute(1, 2, 0).contiguous() | |
| # return out, lse | |
| # For fullgraph=True tracing to be compatible | |
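| # The flash-attention kernels are opaque to dynamo, so they are registered as torch custom ops with a | |
| # "fake" (meta) implementation that only reports output shapes/dtypes. This lets torch.compile trace | |
| # through the call with FakeTensors while the real kernel still runs at execution time. | |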
| @torch.library.custom_op("flash_attn_3::_flash_attn_forward_original", mutates_args=(), device_types="cuda") | |
| def _wrapped_flash_attn_3_original(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: | |
| out, lse = flash_attn_3_func(query, key, value) | |
| lse = lse.permute(0, 2, 1) | |
| return out, lse | |
| @torch.library.register_fake("flash_attn_3::_flash_attn_forward_original") | |
| def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: | |
| batch_size, seq_len, num_heads, head_dim = query.shape | |
| lse_shape = (batch_size, seq_len, num_heads) | |
| return torch.empty_like(query), query.new_empty(lse_shape) | |
| @torch.library.custom_op("flash_attn_3::_flash_attn_forward_hf", mutates_args=(), device_types="cuda") | |
| def _wrapped_flash_attn_3_hf(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor: | |
| out, lse = flash_attn_3_hf.flash_attn_func(query, key, value, causal=False) | |
| return out | |
| @torch.library.register_fake("flash_attn_3::_flash_attn_forward_hf") | |
| def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor: | |
| return torch.empty_like(query) | |
| def _attention_torch_cudnn(query, key, value): | |
| query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value)) | |
| out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = ( | |
| torch.ops.aten._scaled_dot_product_cudnn_attention( | |
| query=query, | |
| key=key, | |
| value=value, | |
| attn_bias=None, | |
| compute_log_sumexp=True, | |
| ) | |
| ) | |
| out = out.transpose(1, 2).contiguous() | |
| lse = lse.transpose(1, 2).contiguous() | |
| return out, lse | |
| def _attention_flash_attn_2(query, key, value): | |
| out, lse, _ = flash_attn_func(query, key, value, return_attn_probs=True) | |
| lse = lse.permute(0, 2, 1) | |
| return out, lse | |
| def _attention_flash_attn_3_original(query, key, value): | |
| out = _wrapped_flash_attn_3_original(query, key, value) | |
| return out | |
| def _attention_flash_attn_3_hf(query, key, value): | |
| out = _wrapped_flash_attn_3_hf(query, key, value) | |
| return out | |
| def _download_hf_flash_attn_3(): | |
| global flash_attn_3_hf | |
| flash_attn_3_hf = get_kernel("kernels-community/flash-attn3") | |
| def get_args(): | |
| DEFAULT_MODEL_ID = "Wan-AI/Wan2.2-TI2V-5B-Diffusers" | |
| DEFAULT_PROMPT = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." | |
| DEFAULT_NEGATIVE_PROMPT = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument("--model_id", type=str, default=DEFAULT_MODEL_ID) | |
| parser.add_argument("--prompt", type=str, default=DEFAULT_PROMPT) | |
| parser.add_argument("--negative_prompt", type=str, default=DEFAULT_NEGATIVE_PROMPT) | |
| parser.add_argument("--height", type=int, default=704) | |
| parser.add_argument("--width", type=int, default=1280) | |
| parser.add_argument("--num_frames", type=int, default=121) | |
| parser.add_argument("--num_inference_steps", type=int, default=30) | |
| parser.add_argument("--guidance_scale", type=float, default=5.0) | |
| parser.add_argument( | |
| "--compile_mode", | |
| type=str, | |
| default="none", | |
| choices=["none", "default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"], | |
| ) | |
| parser.add_argument("--output_file", type=str, default="output.mp4") | |
| parser.add_argument("--cache_dir", type=str, default=None) | |
| parser.add_argument("--enable_fp32_rope", action="store_true") | |
| parser.add_argument("--enable_profiling", action="store_true") | |
| parser.add_argument("--enable_cudagraph", action="store_true") | |
| parser.add_argument("--attention_provider", type=str, default="cudnn", choices=["cudnn", "fa2", "fa3", "fa3_original"]) | |
| parser.add_argument("--ring_degree", type=int, default=1) | |
| parser.add_argument("--ulysses_degree", type=int, default=1) | |
| parser.add_argument("--disable_tf32", action="store_true") | |
| parser.add_argument("--disable_flags", action="store_true") | |
| parser.add_argument("--seed", type=int, default=31337) | |
| args = parser.parse_args() | |
| return args | |
| def config(args): | |
| torch.manual_seed(args.seed) | |
| if args.enable_fp32_rope: | |
| global ROPE_PRECISION | |
| ROPE_PRECISION = torch.float32 | |
| if args.enable_cudagraph and args.compile_mode not in ["none", "default", "max-autotune-no-cudagraphs"]: | |
| raise ValueError( | |
| "Only compiled modes 'none', 'default', and 'max-autotune-no-cudagraphs' are supported with CUDAGraphs." | |
| ) | |
| global ATTENTION_OP | |
| if args.attention_provider == "cudnn": | |
| ATTENTION_OP = _attention_torch_cudnn | |
| elif args.attention_provider == "fa2": | |
| ATTENTION_OP = _attention_flash_attn_2 | |
| elif args.attention_provider == "fa3": | |
| _download_hf_flash_attn_3() | |
| ATTENTION_OP = _attention_flash_attn_3_hf | |
| elif args.attention_provider == "fa3_original": | |
| ATTENTION_OP = _attention_flash_attn_3_original | |
| else: | |
| assert False | |
| if args.enable_profiling and args.enable_cudagraph: | |
| torch.profiler._utils._init_for_cuda_graphs() | |
| if not args.disable_tf32: | |
| torch.backends.cuda.matmul.allow_tf32 = True | |
| torch.backends.cudnn.allow_tf32 = True | |
| if not args.disable_flags: | |
| apply_flags() | |
| def setup_distributed(ring_degree: int, ulysses_degree: int, compile_mode: str): | |
| global ATTENTION_OP | |
| dist.init_process_group("nccl") | |
| torch.cuda.set_device(torch.device("cuda", dist.get_rank())) | |
| if ring_degree * ulysses_degree != dist.get_world_size(): | |
| raise ValueError(f"`ring_degree * ulysses_degree` must equal the world size {dist.get_world_size()}.") | |
| mesh_names = ["ring", "ulysses"] | |
| mesh_dims = [ring_degree, ulysses_degree] | |
| mesh = dist.device_mesh.init_device_mesh("cuda", mesh_dims, mesh_dim_names=mesh_names) | |
| cp_options.ring_degree = ring_degree | |
| cp_options.ulysses_degree = ulysses_degree | |
| cp_options.mesh = mesh | |
| cp_options.convert_to_fp32 = True | |
| cp_options.attention_op = ATTENTION_OP | |
| cp_options._flattened_mesh = mesh._flatten() | |
| cp_options._ring_mesh = mesh["ring"] | |
| cp_options._ulysses_mesh = mesh["ulysses"] | |
| cp_options._ring_local_rank = cp_options._ring_mesh.get_local_rank() | |
| cp_options._ulysses_local_rank = cp_options._ulysses_mesh.get_local_rank() | |
| if ring_degree > 1 and ulysses_degree > 1: | |
| raise ValueError("The current implementation is incorrect for unified attention and needs to be fixed.") | |
| # cp_options.mode = "unified" | |
| # ATTENTION_OP = _templated_unified_attention | |
| elif ulysses_degree > 1: | |
| cp_options.mode = "ulysses" | |
| ATTENTION_OP = _templated_ulysses_attention | |
| else: | |
| cp_options.mode = "ring" | |
| ATTENTION_OP = _templated_ring_attention | |
| if compile_mode != "none": | |
| torch._dynamo.config.suppress_errors = True | |
| torch._inductor.config.reorder_for_compute_comm_overlap = True | |
| if __name__ == "__main__": | |
| args = get_args() | |
| try: | |
| config(args) | |
| setup_distributed(args.ring_degree, args.ulysses_degree, args.compile_mode) | |
| main( | |
| model_id=args.model_id, | |
| prompt=args.prompt, | |
| negative_prompt=args.negative_prompt, | |
| height=args.height, | |
| width=args.width, | |
| num_frames=args.num_frames, | |
| num_inference_steps=args.num_inference_steps, | |
| guidance_scale=args.guidance_scale, | |
| compile_mode=args.compile_mode, | |
| output_file=args.output_file, | |
| cache_dir=args.cache_dir, | |
| enable_cudagraph=args.enable_cudagraph, | |
| enable_profiling=args.enable_profiling, | |
| seed=args.seed, | |
| ) | |
| except Exception as e: | |
| print(f"An error occurred: {e}") | |
| if dist.is_initialized(): | |
| torch.distributed.breakpoint() | |
| raise | |
| finally: | |
| if dist.is_initialized(): | |
| dist.destroy_process_group() |
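For reference, a minimal self-contained sketch (not part of the gist; the helper sdpa_with_lse and the toy shapes are illustrative) checking that the sigmoid/logsigmoid merge used in _templated_ring_attention reproduces full attention when the key/value sequence is split into chunks:

import torch
import torch.nn.functional as F

def sdpa_with_lse(q, k, v):
    # q, k, v: (B, H, S, D); returns the attention output and the log-sum-exp over keys
    scale = q.shape[-1] ** -0.5
    scores = torch.einsum("bhqd,bhkd->bhqk", q, k) * scale
    out = torch.einsum("bhqk,bhkd->bhqd", scores.softmax(dim=-1), v)
    return out, torch.logsumexp(scores, dim=-1)

torch.manual_seed(0)
q, k, v = (torch.randn(1, 2, 8, 16) for _ in range(3))

# Reference: attention over the full key/value sequence
ref, _ = sdpa_with_lse(q, k, v)

# Ring-style: attend to each K/V chunk separately, then merge with the same
# sigmoid/logsigmoid identities as _templated_ring_attention
prev_out = prev_lse = None
for k_chunk, v_chunk in zip(k.chunk(2, dim=2), v.chunk(2, dim=2)):
    out, lse = sdpa_with_lse(q, k_chunk, v_chunk)
    lse = lse.unsqueeze(-1)
    if prev_out is not None:
        out = prev_out - torch.sigmoid(lse - prev_lse) * (prev_out - out)
        lse = prev_lse - F.logsigmoid(prev_lse - lse)
    prev_out, prev_lse = out, lse

print(torch.allclose(prev_out, ref, atol=1e-5))  # expected: True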
H100: