@a-r-r-o-w · Created May 30, 2025
Copy-pastable implementation for various attention backends
import contextlib
import functools
import inspect
import os
from enum import Enum
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
FINETRAINERS_ATTN_CHECKS = os.environ.get("FINETRAINERS_ATTN_CHECKS", "0").lower() in ("1", "true", "yes")
FINETRAINERS_ATTN_PROVIDER = os.environ.get("FINETRAINERS_ATTN_PROVIDER", "native").lower()
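# FINETRAINERS_ATTN_PROVIDER accepts any value of the `AttentionProvider` enum defined below
# (e.g. "flash", "flex", "native", "_native_cudnn", "sage"); FINETRAINERS_ATTN_CHECKS enables
# per-call constraint checks and warnings about unsupported kwargs.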
import torch
# Since `scaled_dot_product_attention` will be patched with `attention_dispatch` so that we control
# dispatching to the different attention providers, import the original function under an alias
# (`native_sdpa`) so the dispatcher can call it without going into infinite recursion.
import torch.autograd
from diffusers.utils.import_utils import OptionalDependencyNotAvailable
from torch.nn.functional import scaled_dot_product_attention as native_sdpa
from .logger import get_logger
from .utils import (
is_flash_attn_available,
is_flash_attn_version,
is_flash_attn_3_available,
is_sageattention_available,
is_sageattention_version,
is_torch_version,
is_xformers_available,
is_xformers_version,
)
if is_flash_attn_available():
if is_flash_attn_version("<", "2.6.3"):
raise OptionalDependencyNotAvailable(
"The `flash-attn` library version is too old. Please update it to at least 2.6.3."
)
from flash_attn import flash_attn_varlen_func
from flash_attn.flash_attn_interface import _flash_attn_backward, _flash_attn_forward
else:
flash_attn_varlen_func = None
_flash_attn_forward = None
_flash_attn_backward = None
if is_flash_attn_3_available():
from flash_attn_interface import _flash_attn_backward as _flash_attn_3_backward, _flash_attn_forward as _flash_attn_3_forward
else:
_flash_attn_3_forward = None
_flash_attn_3_backward = None
if is_sageattention_available():
if is_sageattention_version("<", "2.1.1"):
raise OptionalDependencyNotAvailable(
"The `sageattention` library version is too old. Please update it to at least 2.1.1."
)
from sageattention import (
sageattn,
sageattn_qk_int8_pv_fp8_cuda,
sageattn_qk_int8_pv_fp8_cuda_sm90,
sageattn_qk_int8_pv_fp16_cuda,
sageattn_qk_int8_pv_fp16_triton,
sageattn_varlen,
)
else:
sageattn = None
sageattn_qk_int8_pv_fp16_cuda = None
sageattn_qk_int8_pv_fp16_triton = None
sageattn_qk_int8_pv_fp8_cuda = None
sageattn_qk_int8_pv_fp8_cuda_sm90 = None
sageattn_varlen = None
if is_torch_version(">=", "2.5.0"):
import torch.nn.attention.flex_attention as flex_attention
if is_torch_version(">=", "2.6.0"):
from torch.distributed.tensor.experimental._attention import (
_AttentionOp,
_cp_options,
_templated_ring_attention,
_templated_ring_attention_backward,
set_rotate_method,
)
else:
_cp_options = None
_templated_ring_attention = None
_templated_ring_attention_backward = None
set_rotate_method = None
class _AttentionOp:
def __init__(self, *args, **kwargs):
raise OptionalDependencyNotAvailable(
"The `torch.distributed.tensor.experimental._attention` module is not available. Please update PyTorch to at least 2.6.0."
)
if is_xformers_available():
if is_xformers_version("<", "0.0.29"):
raise OptionalDependencyNotAvailable(
"The `xformers` library version is too old. Please update it to at least 0.0.29."
)
import xformers.ops as xops
else:
xops = None
logger = get_logger()
_SAGE_ATTENTION_PV_ACCUM_DTYPE = Literal["fp32", "fp32+fp32"]
_SAGE_ATTENTION_QK_QUANT_GRAN = Literal["per_thread", "per_warp"]
_SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"]
# ===== Custom operator implementations/wrappers =====
def _finetrainers_scaled_dot_product_efficient_attention_forward(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_bias: Optional[torch.Tensor] = None,
compute_log_sumexp: bool = False,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
# Wrapper for https://github.com/pytorch/pytorch/blob/8904ba638726f8c9a5aff5977c4aa76c9d2edfa6/aten/src/ATen/native/native_functions.yaml#L14946
# See: https://github.com/pytorch/pytorch/issues/152942
seqlen_q = query.shape[-2]
out, lse, philox_seed, philox_offset = torch.ops.aten._scaled_dot_product_efficient_attention(
query=query,
key=key,
value=value,
attn_bias=attn_bias,
compute_log_sumexp=compute_log_sumexp,
dropout_p=dropout_p,
is_causal=is_causal,
scale=scale,
)
# The kernel pads the LSE to the next multiple of 32. As a workaround, strip the padding before
# returning so that PyTorch ring attention does not error out with a shape mismatch.
if compute_log_sumexp:
assert lse.ndim == 3
lse = lse[:, :, :seqlen_q] # .contiguous()
return out, lse, philox_seed, philox_offset
# aten::_scaled_dot_product_efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor attn_bias, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, float dropout_p, bool[4] grad_input_mask, bool is_causal=False, *, float? scale=None) -> (Tensor, Tensor, Tensor, Tensor)
def _finetrainers_scaled_dot_product_efficient_attention_backward(
grad_out_: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_bias: torch.Tensor,
out: torch.Tensor,
logsumexp: torch.Tensor,
philox_seed: torch.Tensor,
philox_offset: torch.Tensor,
dropout_p: float,
grad_input_mask: List[bool],
is_causal: bool = False,
scale: Optional[float] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
assert len(grad_input_mask) == 4
# https://github.com/pytorch/pytorch/blob/bb9fbb294af385057a72e5b1386cf40f86aadbec/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h#L113
kAlignLSE = 32
# Re-apply the alignment padding stripped in the forward wrapper; the outer modulo keeps the
# pad amount at zero when the LSE length is already a multiple of kAlignLSE.
logsumexp = torch.nn.functional.pad(
logsumexp, (0, (kAlignLSE - logsumexp.shape[-1] % kAlignLSE) % kAlignLSE), value=float("inf")
)
# NOTE: cannot pass out as keyword argument because max-autotune fails due to the presence of "out" keyword.
# fails at: torch._dynamo.variables.torch.py::TorchInGraphFunctionVariable.call_function
grad_query, grad_key, grad_value, grad_attn_bias = torch.ops.aten._scaled_dot_product_efficient_attention_backward(
grad_out_,
query,
key,
value,
attn_bias,
out,
logsumexp,
philox_seed=philox_seed,
philox_offset=philox_offset,
dropout_p=dropout_p,
grad_input_mask=grad_input_mask,
is_causal=is_causal,
scale=scale,
)
return grad_query, grad_key, grad_value, grad_attn_bias
# ===== Attention provider =====
class AttentionProvider(str, Enum):
# EAGER = "eager"
# `flash-attn`
FLASH = "flash"
FLASH_VARLEN = "flash_varlen"
_FLASH_3 = "_flash_3"
# PyTorch native
FLEX = "flex"
NATIVE = "native"
_NATIVE_CUDNN = "_native_cudnn"
_NATIVE_EFFICIENT = "_native_efficient"
_NATIVE_FLASH = "_native_flash"
_NATIVE_MATH = "_native_math"
# `sageattention`
SAGE = "sage"
SAGE_VARLEN = "sage_varlen"
_SAGE_QK_INT8_PV_FP8_CUDA = "_sage_qk_int8_pv_fp8_cuda"
_SAGE_QK_INT8_PV_FP8_CUDA_SM90 = "_sage_qk_int8_pv_fp8_cuda_sm90"
_SAGE_QK_INT8_PV_FP16_CUDA = "_sage_qk_int8_pv_fp16_cuda"
_SAGE_QK_INT8_PV_FP16_TRITON = "_sage_qk_int8_pv_fp16_triton"
# TODO: Sparge Attention is not supported for now because it requires per-model tuning.
# We can look into an autotuning-style integration in the future.
# SPARGE = "sparge"
# `xformers`
XFORMERS = "xformers"
class _AttentionProviderRegistry:
_providers = {}
_constraints = {}
_supports_cp = {}
_supported_arg_names = {}
_active_provider = AttentionProvider(FINETRAINERS_ATTN_PROVIDER)
_checks_enabled = FINETRAINERS_ATTN_CHECKS
# Context parallel attributes
_mesh: torch.distributed.device_mesh.DeviceMesh = None
_convert_to_fp32: bool = None
_rotate_method: Literal["allgather", "alltoall"] = None
@classmethod
def register(
cls, provider: AttentionProvider, constraints: Optional[List[Callable]] = None, supports_cp: bool = False
):
logger.debug(f"Registering attention provider: {provider}")
def decorator(func):
cls._providers[provider] = func
cls._constraints[provider] = constraints or []
cls._supports_cp[provider] = supports_cp
cls._supported_arg_names[provider] = set(inspect.signature(func).parameters.keys())
return func
return decorator
@classmethod
def get_active_provider(cls):
return cls._active_provider, cls._providers[cls._active_provider]
@classmethod
def list_providers(cls):
return list(cls._providers.keys())
@classmethod
def supports_context_parallel(cls, provider: AttentionProvider):
if provider not in cls._providers:
raise ValueError(f"Provider {provider} is not registered.")
return cls._supports_cp.get(provider, False)
@classmethod
def context_parallel_enabled(cls):
return cls._mesh is not None
@classmethod
def _set_context_parallel(
cls,
mesh: torch.distributed.device_mesh.DeviceMesh = None,
convert_to_fp32: bool = None,
rotate_method: str = None,
*,
reset: bool = False,
):
if reset:
mesh = convert_to_fp32 = rotate_method = None
cls._mesh = mesh
cls._convert_to_fp32 = convert_to_fp32
cls._rotate_method = rotate_method
@classmethod
def _raise_cp_error_if_mesh_not_set(cls):
if cls._mesh is None:
raise ValueError(
"`_AttentionProviderRegistry._mesh` is None. It must be set before calling context parallel attention methods."
)
@contextlib.contextmanager
def attention_provider(
provider: AttentionProvider = AttentionProvider.NATIVE,
*,
mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
convert_to_fp32: bool = True,
rotate_method: str = "allgather",
):
"""Context manager to set the active attention provider and possibly enable context parallelism."""
if provider not in _AttentionProviderRegistry._providers:
raise ValueError(f"Provider {provider} is not registered.")
if mesh is not None and not _AttentionProviderRegistry.supports_context_parallel(provider):
raise ValueError(f"Provider {provider} does not support context parallelism.")
old_provider = _AttentionProviderRegistry._active_provider
_AttentionProviderRegistry._active_provider = provider
_AttentionProviderRegistry._mesh = mesh
_AttentionProviderRegistry._convert_to_fp32 = convert_to_fp32
_AttentionProviderRegistry._rotate_method = rotate_method
if mesh is not None:
_convert_to_f32 = _cp_options.convert_to_f32
_enable_load_balance = _cp_options.enable_load_balance
_rotate_method = _cp_options.rotate_method
try:
yield
finally:
_AttentionProviderRegistry._active_provider = old_provider
_AttentionProviderRegistry._mesh = None
_AttentionProviderRegistry._convert_to_fp32 = None
_AttentionProviderRegistry._rotate_method = None
if mesh is not None:
_cp_options.convert_to_f32 = _convert_to_f32
_cp_options.enable_load_balance = _enable_load_balance
_cp_options.rotate_method = _rotate_method
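# Example (sketch): enabling ring attention for a context-parallel run. Assumes torch >= 2.6,
# an already-initialized process group, and a provider registered with `supports_cp=True`
# (e.g. AttentionProvider._NATIVE_CUDNN):
#
#   mesh = torch.distributed.device_mesh.init_device_mesh("cuda", (world_size,))
#   with attention_provider(AttentionProvider._NATIVE_CUDNN, mesh=mesh, rotate_method="allgather"):
#       out = attention_dispatch(query, key, value)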
def attention_dispatch(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
attention_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
attention_kwargs = attention_kwargs or {}
provider_name, provider_fn = _AttentionProviderRegistry.get_active_provider()
kwargs = {
"query": query,
"key": key,
"value": value,
"attn_mask": attn_mask,
"dropout_p": dropout_p,
"is_causal": is_causal,
"scale": scale,
"enable_gqa": enable_gqa,
**attention_kwargs,
}
if _AttentionProviderRegistry._checks_enabled:
removed_kwargs = set(kwargs) - set(_AttentionProviderRegistry._supported_arg_names[provider_name])
if removed_kwargs:
log_freq = 512
msg = (
f"Removing unsupported arguments for attention provider {provider_name}: {removed_kwargs}. This "
f"message will be logged every {log_freq} calls."
)
logger.log_freq("WARNING", "REMOVING_ATTN_UNSUPPORTED_KWARGS", msg, log_freq)
for check in _AttentionProviderRegistry._constraints.get(provider_name):
check(**kwargs)
kwargs = {k: v for k, v in kwargs.items() if k in _AttentionProviderRegistry._supported_arg_names[provider_name]}
if _AttentionProviderRegistry.context_parallel_enabled():
_set_context_parallel_options(**kwargs)
return provider_fn(**kwargs)
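# Example (sketch): provider-specific arguments ride along in `attention_kwargs` and are filtered
# against the active provider's signature, e.g. with AttentionProvider.FLASH active:
#
#   out = attention_dispatch(query, key, value, attention_kwargs={"softcap": 30.0, "deterministic": True})
#
# With FINETRAINERS_ATTN_CHECKS=1, unsupported keys are logged (rate-limited) before being removed.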
# ===== Helper functions =====
# @torch.compiler.assume_constant_result
def _set_context_parallel_options(is_causal: bool, **kwargs):
_cp_options.enable_load_balance = is_causal
_cp_options.convert_to_f32 = _AttentionProviderRegistry._convert_to_fp32
set_rotate_method(_AttentionProviderRegistry._rotate_method)
def _check_attn_mask_is_none(attn_mask: Optional[torch.Tensor], **kwargs) -> None:
if attn_mask is not None:
raise ValueError("Attention mask must be None for this provider.")
def _check_attn_mask_or_causal(attn_mask: Optional[torch.Tensor], is_causal: bool, **kwargs) -> None:
if attn_mask is not None and is_causal:
raise ValueError("`is_causal` cannot be True when `attn_mask` is not None.")
def _check_device(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
if query.device != key.device or query.device != value.device:
raise ValueError("Query, key, and value must be on the same device.")
if query.dtype != key.dtype or query.dtype != value.dtype:
raise ValueError("Query, key, and value must have the same dtype.")
def _check_device_cuda(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
_check_device(query, key, value)
if query.device.type != "cuda":
raise ValueError("Query, key, and value must be on a CUDA device.")
def _check_device_cuda_atleast_smXY(major: int, minor: int) -> Callable:
def check_device_cuda(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
_check_device_cuda(query, key, value)
if torch.cuda.get_device_capability(query.device) < (major, minor):
raise ValueError(
f"Query, key, and value must be on a CUDA device with compute capability >= {major}.{minor}."
)
return check_device_cuda
def _check_qkv_dtype_match(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
if query.dtype != key.dtype:
raise ValueError("Query and key must have the same dtype.")
if query.dtype != value.dtype:
raise ValueError("Query and value must have the same dtype.")
def _check_qkv_dtype_bf16_or_fp16(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
_check_qkv_dtype_match(query, key, value)
if query.dtype not in (torch.bfloat16, torch.float16):
raise ValueError("Query, key, and value must be either bfloat16 or float16.")
def _check_shape(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> None:
if query.shape[-1] != key.shape[-1]:
raise ValueError("Query and key must have the same last dimension.")
if key.shape[-2] != value.shape[-2]:
raise ValueError("Key and value must have the same second to last dimension.")
if attn_mask is not None and attn_mask.shape[-1] != key.shape[-2]:
raise ValueError("Attention mask must match the key's second to last dimension.")
def _prepare_for_flash_attn_or_sage_varlen(
batch_size: int,
seq_len_q: int,
seq_len_kv: int,
attn_mask: Optional[torch.Tensor] = None,
device: Optional[torch.device] = None,
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], Tuple[int, int]]:
seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device)
if attn_mask is None:
seqlens_k = torch.full((batch_size,), seq_len_kv, dtype=torch.int32, device=device)
else:
seqlens_k = attn_mask.sum(dim=1, dtype=torch.int32)
cu_seqlens_q = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
cu_seqlens_k = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0)
max_seqlen_q = seqlens_q.max().item()
max_seqlen_k = seqlens_k.max().item()
return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
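# Worked example (sketch): for batch_size=2, seq_len_q=3, seq_len_kv=4 and no mask, this returns
#   seqlens_q    = [3, 3]       seqlens_k    = [4, 4]
#   cu_seqlens_q = [0, 3, 6]    cu_seqlens_k = [0, 4, 8]
#   max_seqlen_q = 3            max_seqlen_k = 4
# With a bool padding mask, seqlens_k instead holds the per-sample count of True entries.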
def _normalize_attn_mask(attn_mask: torch.Tensor, batch_size: int, seq_len_k: int) -> torch.Tensor:
"""
Normalize an attention mask to shape [batch_size, seq_len_k] (bool) suitable for inferring seqlens_k in
FlashAttention/Sage varlen.
Supports 1D to 4D shapes and common broadcasting patterns.
"""
if attn_mask.dtype != torch.bool:
raise ValueError(f"Attention mask must be of type bool, got {attn_mask.dtype}.")
if attn_mask.ndim == 1:
# [seq_len_k] -> broadcast across batch
attn_mask = attn_mask.unsqueeze(0).expand(batch_size, seq_len_k)
elif attn_mask.ndim == 2:
# [batch_size, seq_len_k]. Maybe broadcast across batch
if attn_mask.size(0) not in [1, batch_size]:
raise ValueError(
f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 2D attention mask."
)
attn_mask = attn_mask.expand(batch_size, seq_len_k)
elif attn_mask.ndim == 3:
# [batch_size, seq_len_q, seq_len_k] -> reduce over query dimension
if attn_mask.size(0) not in [1, batch_size]:
raise ValueError(
f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 3D attention mask."
)
attn_mask = attn_mask.any(dim=1)
attn_mask = attn_mask.expand(batch_size, seq_len_k)
elif attn_mask.ndim == 4:
# [batch_size, num_heads, seq_len_q, seq_len_k] or broadcastable versions
if attn_mask.size(0) not in [1, batch_size]:
raise ValueError(
f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 4D attention mask."
)
attn_mask = attn_mask.expand(batch_size, -1, -1, seq_len_k) # [B, H, Q, K]
attn_mask = attn_mask.any(dim=(1, 2)) # [B, K]
else:
raise ValueError(f"Unsupported attention mask shape: {attn_mask.shape}")
if attn_mask.shape != (batch_size, seq_len_k):
raise ValueError(
f"Normalized attention mask shape mismatch: got {attn_mask.shape}, expected ({batch_size}, {seq_len_k})"
)
return attn_mask
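# Examples (sketch), with batch_size=2 and seq_len_k=4:
#   1D mask [4]          -> broadcast over the batch to [2, 4]
#   2D mask [1, 4]       -> expanded to [2, 4]
#   3D mask [2, q, 4]    -> any() over the query dim -> [2, 4]
#   4D mask [2, h, q, 4] -> any() over the head and query dims -> [2, 4]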
def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
return q_idx >= kv_idx
# ===== Attention provider implementations =====
# Adapted from: https://github.com/Dao-AILab/flash-attention/blob/fd2fc9d85c8e54e5c20436465bca709bc1a6c5a1/flash_attn/flash_attn_interface.py#L807
class _flash_attn_flash_attention(torch.autograd.Function):
@staticmethod
def forward(
ctx: torch.autograd.function.FunctionCtx,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
dropout_p: float = 0.0,
softmax_scale: Optional[float] = None,
causal: bool = False,
window_size: Tuple[int, int] = (-1, -1),
softcap: float = 0.0,
alibi_slopes: Optional[torch.Tensor] = None,
deterministic: bool = False,
return_softmax: bool = False,
):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
ctx.dropout_p = dropout_p
ctx.softmax_scale = softmax_scale
ctx.causal = causal
ctx.window_size = window_size
ctx.softcap = softcap
ctx.alibi_slopes = alibi_slopes
ctx.deterministic = deterministic
q, k, v = (x.permute(0, 2, 1, 3) for x in (q, k, v))
out, lse, S_dmask, rng_state = _flash_attn_forward(
q=q,
k=k,
v=v,
dropout_p=dropout_p,
softmax_scale=softmax_scale,
causal=causal,
window_size_left=window_size[0],
window_size_right=window_size[1],
softcap=softcap,
alibi_slopes=alibi_slopes,
return_softmax=return_softmax,
)
ctx.save_for_backward(q, k, v, out, lse, rng_state)
out = out.permute(0, 2, 1, 3)
return (out, lse) if return_softmax else out
@staticmethod
def backward(
ctx: torch.autograd.function.FunctionCtx,
grad_out: torch.Tensor,
*args: torch.Tensor,
):
q, k, v, out, lse, rng_state = ctx.saved_tensors
grad_out = grad_out.permute(0, 2, 1, 3).contiguous() # [B, N, S, D] -> [B, S, N, D]
grad_query, grad_key, grad_value = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
head_size_og = grad_out.size(3)
grad_out_padded = grad_out
if head_size_og % 8 != 0:
grad_out_padded = torch.nn.functional.pad(grad_out, [0, 8 - head_size_og % 8])
# NOTE: cannot pass out as keyword argument because max-autotune fails due to the presence of "out" keyword.
# fails at: torch._dynamo.variables.torch.py::TorchInGraphFunctionVariable.call_function
_flash_attn_backward(
grad_out_padded,
q,
k,
v,
out,
lse,
grad_query,
grad_key,
grad_value,
dropout_p=ctx.dropout_p,
softmax_scale=ctx.softmax_scale,
causal=ctx.causal,
window_size_left=ctx.window_size[0],
window_size_right=ctx.window_size[1],
softcap=ctx.softcap,
alibi_slopes=ctx.alibi_slopes,
deterministic=ctx.deterministic,
rng_state=rng_state,
)
# Head dimension could have been padded
grad_query = grad_query[..., : grad_out.shape[-1]]
grad_key = grad_key[..., : grad_out.shape[-1]]
grad_value = grad_value[..., : grad_out.shape[-1]]
grad_query, grad_key, grad_value = (x.permute(0, 2, 1, 3).contiguous() for x in (grad_query, grad_key, grad_value))
return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
@_AttentionProviderRegistry.register(
AttentionProvider.FLASH,
constraints=[_check_attn_mask_is_none, _check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
supports_cp=False,
)
def flash_attn_flash_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
dropout_p: float = 0.0,
scale: Optional[float] = None,
is_causal: bool = False,
window_size: Tuple[int, int] = (-1, -1),
softcap: float = 0.0,
alibi_slopes: Optional[torch.Tensor] = None,
deterministic: bool = False,
return_lse: bool = False,
) -> torch.Tensor:
dispatch_fn = _flash_attn_flash_attention
return dispatch_fn.apply(
query, key, value, dropout_p, scale, is_causal, window_size, softcap, alibi_slopes, deterministic, return_lse
)
class _flash_attn_flash_attention_3(torch.autograd.Function):
@staticmethod
def forward(
ctx: torch.autograd.function.FunctionCtx,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
softmax_scale: Optional[float] = None,
causal: bool = False,
qv=None,
q_descale=None,
k_descale=None,
v_descale=None,
window_size: Tuple[int, int] = (-1, -1),
attention_chunk=0,
softcap: float = 0.0,
num_splits: int = 1,
pack_gqa=None,
deterministic: bool = False,
sm_margin: int = 0,
return_softmax: bool = False,
):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
ctx.softmax_scale = softmax_scale
ctx.causal = causal
ctx.window_size = window_size
ctx.attention_chunk = attention_chunk
ctx.softcap = softcap
ctx.deterministic = deterministic
ctx.sm_margin = sm_margin
q, k, v = (x.permute(0, 2, 1, 3) for x in (q, k, v))
out, lse, *rest = _flash_attn_3_forward(
q,
k,
v,
None,
None,
qv,
None,
cu_seqlens_q=None,
cu_seqlens_k=None,
cu_seqlens_k_new=None,
seqused_q=None,
seqused_k=None,
max_seqlen_q=None,
max_seqlen_k=None,
page_table=None,
kv_batch_idx=None,
leftpad_k=None,
rotary_cos=None,
rotary_sin=None,
seqlens_rotary=None,
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
softmax_scale=softmax_scale,
causal=causal,
window_size=window_size,
attention_chunk=attention_chunk,
softcap=softcap,
rotary_interleaved=True,
scheduler_metadata=None,
num_splits=num_splits,
pack_gqa=pack_gqa,
sm_margin=sm_margin,
)
ctx.save_for_backward(q, k, v, out, lse)
out = out.permute(0, 2, 1, 3)
return (out, lse) if return_softmax else out
@staticmethod
def backward(
ctx: torch.autograd.function.FunctionCtx,
grad_out: torch.Tensor,
*args: torch.Tensor,
):
q, k, v, out, lse = ctx.saved_tensors
grad_out = grad_out.permute(0, 2, 1, 3).contiguous() # [B, N, S, D] -> [B, S, N, D]
assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
grad_query, grad_key, grad_value = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
# NOTE: cannot pass out as keyword argument because max-autotune fails due to the presence of "out" keyword.
# fails at: torch._dynamo.variables.torch.py::TorchInGraphFunctionVariable.call_function
_flash_attn_3_backward(
grad_out,
q,
k,
v,
out,
lse,
cu_seqlens_q=None,
cu_seqlens_k=None,
sequed_q=None,
sequed_k=None,
max_seqlen_q=None,
max_seqlen_k=None,
dq=grad_query,
dk=grad_key,
dv=grad_value,
softmax_scale=ctx.softmax_scale,
causal=ctx.causal,
window_size=ctx.window_size,
softcap=ctx.softcap,
deterministic=ctx.deterministic,
sm_margin=ctx.sm_margin,
)
# Head dimension could have been padded
grad_query = grad_query[..., : q.shape[-1]]
grad_key = grad_key[..., : k.shape[-1]]
grad_value = grad_value[..., : v.shape[-1]]
grad_query, grad_key, grad_value = (x.permute(0, 2, 1, 3).contiguous() for x in (grad_query, grad_key, grad_value))
return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None, None, None, None, None, None, None
@_AttentionProviderRegistry.register(
AttentionProvider._FLASH_3,
constraints=[_check_attn_mask_is_none, _check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
supports_cp=False,
)
def flash_attn_flash_attention_3(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
scale: Optional[float] = None,
is_causal: bool = False,
window_size: Tuple[int, int] = (-1, -1),
softcap: float = 0.0,
deterministic: bool = False,
return_lse: bool = False,
) -> torch.Tensor:
dispatch_fn = _flash_attn_flash_attention_3
return dispatch_fn.apply(
query, key, value, scale, is_causal, None, None, None, None, window_size, 0, softcap, 1, None, deterministic, 0, return_lse
)
@_AttentionProviderRegistry.register(
AttentionProvider.FLASH_VARLEN,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
supports_cp=False,
)
def _flash_varlen_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_k: Optional[torch.Tensor] = None,
max_seqlen_q: Optional[int] = None,
max_seqlen_k: Optional[int] = None,
dropout_p: float = 0.0,
scale: Optional[float] = None,
is_causal: bool = False,
window_size: Tuple[int, int] = (-1, -1),
softcap: float = 0.0,
alibi_slopes: Optional[torch.Tensor] = None,
deterministic: bool = False,
return_attn_probs: bool = False,
attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
batch_size, _, seq_len_q, _ = query.shape
_, _, seq_len_kv, _ = key.shape
if attn_mask is not None:
attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
(_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
_prepare_for_flash_attn_or_sage_varlen(
batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
)
)
else:
seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device)
cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
key_valid, value_valid = [], []
for b in range(batch_size):
valid_len = seqlens_k[b]
key_valid.append(key[b, :valid_len])
value_valid.append(value[b, :valid_len])
query_packed = query.flatten(0, 1)
key_packed = torch.cat(key_valid, dim=0)
value_packed = torch.cat(value_valid, dim=0)
if _AttentionProviderRegistry.context_parallel_enabled():
return_attn_probs = True
out = flash_attn_varlen_func(
q=query_packed,
k=key_packed,
v=value_packed,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
dropout_p=dropout_p,
softmax_scale=scale,
causal=is_causal,
window_size=window_size,
softcap=softcap,
alibi_slopes=alibi_slopes,
deterministic=deterministic,
return_attn_probs=return_attn_probs,
)
rest = None
if return_attn_probs:
out, *rest = out
out = out.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3) # .contiguous()
if return_attn_probs:
return out, *rest[:1]
return out
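# Example (sketch): with FLASH_VARLEN active, a bool key-padding mask of shape
# [batch_size, seq_len_kv] (True = attend, padding at the end) is converted into per-sample
# seqlens and the padded key/value rows are dropped before the packed varlen call, e.g.
#   mask = torch.ones(2, 128, dtype=torch.bool, device="cuda")
#   mask[1, 100:] = False
#   out = attention_dispatch(query, key, value, attn_mask=mask)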
@_AttentionProviderRegistry.register(
AttentionProvider.FLEX,
constraints=[_check_attn_mask_or_causal, _check_device, _check_shape],
supports_cp=False,
)
def _native_flex_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[Union[torch.Tensor, "flex_attention.BlockMask"]] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
return_lse: bool = False,
kernel_options: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
# TODO: should we LRU cache the block mask creation?
score_mod = None
block_mask = None
batch_size, num_heads, seq_len_q, _ = query.shape
_, _, seq_len_kv, _ = key.shape
if attn_mask is None or isinstance(attn_mask, flex_attention.BlockMask):
block_mask = attn_mask
elif is_causal:
block_mask = flex_attention.create_block_mask(
_flex_attention_causal_mask_mod, None, None, seq_len_q, seq_len_kv, query.device
)
elif torch.is_tensor(attn_mask):
if attn_mask.ndim == 2:
# A 2D mask is interpreted as [batch_size, seq_len_kv], so broadcast it over heads and queries.
attn_mask = attn_mask.view(attn_mask.size(0), 1, 1, attn_mask.size(1))
attn_mask = attn_mask.expand(batch_size, num_heads, seq_len_q, seq_len_kv)
if attn_mask.dtype == torch.bool:
# TODO: this probably does not work but verify!
def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
return attn_mask[batch_idx, head_idx, q_idx, kv_idx]
block_mask = flex_attention.create_block_mask(
mask_mod, batch_size, None, seq_len_q, seq_len_kv, query.device
)
else:
def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
return score + attn_mask[batch_idx, head_idx, q_idx, kv_idx]
else:
raise ValueError("Attention mask must be either None, a BlockMask, or a 2D/4D tensor.")
return flex_attention.flex_attention(
query=query,
key=key,
value=value,
score_mod=score_mod,
block_mask=block_mask,
scale=scale,
enable_gqa=enable_gqa,
return_lse=return_lse,
kernel_options=kernel_options,
)
@_AttentionProviderRegistry.register(
AttentionProvider.NATIVE,
constraints=[_check_device, _check_shape],
supports_cp=False,
)
def _native_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
) -> torch.Tensor:
return native_sdpa(
query=query,
key=key,
value=value,
attn_mask=attn_mask,
dropout_p=dropout_p,
is_causal=is_causal,
scale=scale,
enable_gqa=enable_gqa,
)
class _native_cudnn_attention(torch.autograd.Function):
# https://github.com/pytorch/pytorch/blob/8904ba638726f8c9a5aff5977c4aa76c9d2edfa6/aten/src/ATen/native/native_functions.yaml#L14958
# forward declaration:
# aten::_scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0., bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
# backward declaration:
# aten::_scaled_dot_product_cudnn_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor attn_bias, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, *, float? scale=None) -> (Tensor, Tensor, Tensor)
@staticmethod
def forward(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
return_lse: bool = False,
):
ctx.dropout_p = dropout_p
ctx.is_causal = is_causal
ctx.scale = scale
ctx.attn_mask = attn_mask
# TODO(aryan): investigate why this requires contiguous to work with max-autotune
query, key, value = (x.contiguous() for x in (query, key, value))
out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = (
torch.ops.aten._scaled_dot_product_cudnn_attention(
query=query,
key=key,
value=value,
attn_bias=attn_mask,
compute_log_sumexp=True,
dropout_p=dropout_p,
is_causal=is_causal,
return_debug_mask=False,
scale=scale,
)
)
out = out.contiguous()
ctx.max_q = max_q
ctx.max_k = max_k
ctx.save_for_backward(query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset)
return (out, lse) if return_lse else out
@staticmethod
def backward(
ctx: torch.autograd.function.FunctionCtx,
grad_out: torch.Tensor,
*args: torch.Tensor,
):
query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset = ctx.saved_tensors
# NOTE: cannot pass out as keyword argument because max-autotune fails due to the presence of "out" keyword.
# fails at: torch._dynamo.variables.torch.py::TorchInGraphFunctionVariable.call_function
grad_query, grad_key, grad_value = torch.ops.aten._scaled_dot_product_cudnn_attention_backward(
grad_out,
query,
key,
value,
out,
lse,
philox_seed=philox_seed,
philox_offset=philox_offset,
attn_bias=ctx.attn_mask,
cum_seq_q=cum_seq_q,
cum_seq_k=cum_seq_k,
max_q=ctx.max_q,
max_k=ctx.max_k,
dropout_p=ctx.dropout_p,
is_causal=ctx.is_causal,
scale=ctx.scale,
)
return grad_query, grad_key, grad_value, None, None, None, None, None
class _native_ring_native_cudnn_attention(torch.autograd.Function):
@staticmethod
def forward(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
return_lse: bool = False,
):
_AttentionProviderRegistry._raise_cp_error_if_mesh_not_set()
ctx.dropout_p = dropout_p
ctx.is_causal = is_causal
ctx.scale = scale
ctx.attn_mask = attn_mask
out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = (
_templated_ring_attention(
mesh=_AttentionProviderRegistry._mesh,
seq_dim=2,
op=torch.ops.aten._scaled_dot_product_cudnn_attention,
query=query,
key=key,
value=value,
attn_bias=attn_mask,
compute_log_sumexp=True,
dropout_p=dropout_p,
is_causal=is_causal,
return_debug_mask=False,
scale=scale,
)
)
ctx.max_q = max_q
ctx.max_k = max_k
ctx.save_for_backward(query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset)
return (out, lse) if return_lse else out
@staticmethod
def backward(
ctx: torch.autograd.function.FunctionCtx,
grad_out: torch.Tensor,
*args: torch.Tensor,
):
_AttentionProviderRegistry._raise_cp_error_if_mesh_not_set()
query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset = ctx.saved_tensors
grad_query, grad_key, grad_value = _templated_ring_attention_backward(
mesh=_AttentionProviderRegistry._mesh,
seq_dim=2,
op=torch.ops.aten._scaled_dot_product_cudnn_attention_backward,
grad_out=grad_out,
grad_out_name="grad_out",
query=query,
key=key,
value=value,
out=out,
logsumexp=lse,
philox_seed=philox_seed,
philox_offset=philox_offset,
attn_bias=ctx.attn_mask,
cum_seq_q=cum_seq_q,
cum_seq_k=cum_seq_k,
max_q=ctx.max_q,
max_k=ctx.max_k,
dropout_p=ctx.dropout_p,
is_causal=ctx.is_causal,
scale=ctx.scale,
)
return grad_query, grad_key, grad_value, None, None, None, None, None
@_AttentionProviderRegistry.register(
AttentionProvider._NATIVE_CUDNN,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
supports_cp=True,
)
def native_cudnn_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
return_lse: bool = False,
) -> torch.Tensor:
dispatch_fn = (
_native_ring_native_cudnn_attention
if _AttentionProviderRegistry.context_parallel_enabled()
else _native_cudnn_attention
)
return dispatch_fn.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, return_lse)
class _native_efficient_attention(torch.autograd.Function):
# https://github.com/pytorch/pytorch/blob/8904ba638726f8c9a5aff5977c4aa76c9d2edfa6/aten/src/ATen/native/native_functions.yaml#L14946
# forward declaration:
# aten::_scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0., bool is_causal=False, *, float? scale=None) -> (Tensor output, Tensor log_sumexp, Tensor philox_seed, Tensor philox_offset)
# backward declaration:
# aten::_scaled_dot_product_efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor attn_bias, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, float dropout_p, bool[4] grad_input_mask, bool is_causal=False, *, float? scale=None) -> (Tensor, Tensor, Tensor, Tensor)
@staticmethod
def forward(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
return_lse: bool = False,
):
ctx.dropout_p = dropout_p
ctx.is_causal = is_causal
ctx.scale = scale
ctx.attn_mask = attn_mask
# NOTE: Uses finetrainers registered op because of LSE alignment issue. See the op registration for more details.
out, lse, philox_seed, philox_offset = _finetrainers_scaled_dot_product_efficient_attention_forward(
query=query,
key=key,
value=value,
attn_bias=attn_mask,
compute_log_sumexp=True,
dropout_p=dropout_p,
is_causal=is_causal,
scale=scale,
)
ctx.save_for_backward(query, key, value, out, lse, philox_seed, philox_offset)
return (out, lse) if return_lse else out
@staticmethod
def backward(
ctx: torch.autograd.function.FunctionCtx,
grad_out: torch.Tensor,
*args: torch.Tensor,
):
query, key, value, out, lse, philox_seed, philox_offset = ctx.saved_tensors
# NOTE: Uses finetrainers registered op because of LSE alignment issue. See the op registration for more details.
grad_query, grad_key, grad_value, grad_attn_bias = (
_finetrainers_scaled_dot_product_efficient_attention_backward(
grad_out_=grad_out,
query=query,
key=key,
value=value,
attn_bias=ctx.attn_mask,
out=out,
logsumexp=lse,
philox_seed=philox_seed,
philox_offset=philox_offset,
dropout_p=ctx.dropout_p,
grad_input_mask=[True, True, True, False],
is_causal=ctx.is_causal,
scale=ctx.scale,
)
)
return grad_query, grad_key, grad_value, None, None, None, None, None
class _native_ring_native_efficient_attention(torch.autograd.Function):
@staticmethod
def forward(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
return_lse: bool = False,
):
_AttentionProviderRegistry._raise_cp_error_if_mesh_not_set()
ctx.dropout_p = dropout_p
ctx.is_causal = is_causal
ctx.scale = scale
ctx.attn_mask = attn_mask
# NOTE: Uses finetrainers registered op because of LSE alignment issue. See the op registration for more details.
out, lse, philox_seed, philox_offset = _templated_ring_attention(
mesh=_AttentionProviderRegistry._mesh,
seq_dim=2,
op=_finetrainers_scaled_dot_product_efficient_attention_forward,
query=query,
key=key,
value=value,
attn_bias=attn_mask,
compute_log_sumexp=True,
dropout_p=dropout_p,
is_causal=is_causal,
scale=scale,
)
ctx.save_for_backward(query, key, value, out, lse, philox_seed, philox_offset)
return (out, lse) if return_lse else out
@staticmethod
def backward(
ctx: torch.autograd.function.FunctionCtx,
grad_out: torch.Tensor,
*args: torch.Tensor,
):
_AttentionProviderRegistry._raise_cp_error_if_mesh_not_set()
query, key, value, out, lse, philox_seed, philox_offset = ctx.saved_tensors
# NOTE: Uses finetrainers registered op because of LSE alignment issue. See the op registration for more details.
grad_query, grad_key, grad_value, grad_attn_bias = _templated_ring_attention_backward(
mesh=_AttentionProviderRegistry._mesh,
seq_dim=2,
op=_finetrainers_scaled_dot_product_efficient_attention_backward,
grad_out=grad_out,
grad_out_name="grad_out_",
query=query,
key=key,
value=value,
attn_bias=ctx.attn_mask,
out=out,
logsumexp=lse,
philox_seed=philox_seed,
philox_offset=philox_offset,
dropout_p=ctx.dropout_p,
grad_input_mask=[True, True, True, False],
is_causal=ctx.is_causal,
scale=ctx.scale,
)
return grad_query, grad_key, grad_value, None, None, None, None, None
@_AttentionProviderRegistry.register(
AttentionProvider._NATIVE_EFFICIENT,
constraints=[_check_device, _check_shape],
supports_cp=True,
)
def native_efficient_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
) -> torch.Tensor:
dispatch_fn = (
_native_ring_native_efficient_attention
if _AttentionProviderRegistry.context_parallel_enabled()
else _native_efficient_attention
)
return dispatch_fn.apply(query, key, value, attn_mask, dropout_p, is_causal, scale)
class _native_flash_attention(torch.autograd.Function):
# https://github.com/pytorch/pytorch/blob/8904ba638726f8c9a5aff5977c4aa76c9d2edfa6/aten/src/ATen/native/native_functions.yaml#L14910
# forward declaration:
# aten::_scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0., bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
# backward declaration:
# aten::_scaled_dot_product_flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value)
@staticmethod
def forward(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
return_lse: bool = False,
):
ctx.dropout_p = dropout_p
ctx.is_causal = is_causal
ctx.scale = scale
out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = (
torch.ops.aten._scaled_dot_product_flash_attention(
query=query,
key=key,
value=value,
dropout_p=dropout_p,
is_causal=is_causal,
return_debug_mask=False,
scale=scale,
)
)
ctx.max_q = max_q
ctx.max_k = max_k
ctx.save_for_backward(query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset)
return (out, lse) if return_lse else out
@staticmethod
def backward(
ctx: torch.autograd.function.FunctionCtx,
grad_out: torch.Tensor,
*args: torch.Tensor,
):
query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset = ctx.saved_tensors
# NOTE: cannot pass out as keyword argument because max-autotune fails due to the presence of "out" keyword.
# fails at: torch._dynamo.variables.torch.py::TorchInGraphFunctionVariable.call_function
grad_query, grad_key, grad_value = torch.ops.aten._scaled_dot_product_flash_attention_backward(
grad_out,
query,
key,
value,
out,
lse,
cum_seq_q=cum_seq_q,
cum_seq_k=cum_seq_k,
max_q=ctx.max_q,
max_k=ctx.max_k,
dropout_p=ctx.dropout_p,
is_causal=ctx.is_causal,
philox_seed=philox_seed,
philox_offset=philox_offset,
scale=ctx.scale,
)
return grad_query, grad_key, grad_value, None, None, None, None
class _native_ring_native_flash_attention(torch.autograd.Function):
@staticmethod
def forward(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
return_lse: bool = False,
):
_AttentionProviderRegistry._raise_cp_error_if_mesh_not_set()
ctx.dropout_p = dropout_p
ctx.is_causal = is_causal
ctx.scale = scale
out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = (
_templated_ring_attention(
mesh=_AttentionProviderRegistry._mesh,
seq_dim=2,
op=torch.ops.aten._scaled_dot_product_flash_attention,
query=query,
key=key,
value=value,
dropout_p=dropout_p,
is_causal=is_causal,
scale=scale,
)
)
ctx.max_q = max_q
ctx.max_k = max_k
ctx.save_for_backward(query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset)
return (out, lse) if return_lse else out
@staticmethod
def backward(
ctx: torch.autograd.function.FunctionCtx,
grad_out: torch.Tensor,
*args: torch.Tensor,
):
_AttentionProviderRegistry._raise_cp_error_if_mesh_not_set()
query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset = ctx.saved_tensors
grad_query, grad_key, grad_value, *_ = _templated_ring_attention_backward(
mesh=_AttentionProviderRegistry._mesh,
seq_dim=2,
op=torch.ops.aten._scaled_dot_product_flash_attention_backward,
grad_out=grad_out,
grad_out_name="grad_out",
query=query,
key=key,
value=value,
out=out,
logsumexp=lse,
dropout_p=ctx.dropout_p,
is_causal=ctx.is_causal,
scale=ctx.scale,
cum_seq_q=cum_seq_q,
cum_seq_k=cum_seq_k,
max_q=ctx.max_q,
max_k=ctx.max_k,
philox_seed=philox_seed,
philox_offset=philox_offset,
)
return grad_query, grad_key, grad_value, None, None, None, None
@_AttentionProviderRegistry.register(
AttentionProvider._NATIVE_FLASH,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
supports_cp=True,
)
def native_flash_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
return_lse: bool = False,
) -> torch.Tensor:
dispatch_fn = (
_native_ring_native_flash_attention
if _AttentionProviderRegistry.context_parallel_enabled()
else _native_flash_attention
)
return dispatch_fn.apply(query, key, value, dropout_p, is_causal, scale, return_lse)
# class _native_math_attention(torch.autograd.Function):
# # https://github.com/pytorch/pytorch/blob/8904ba638726f8c9a5aff5977c4aa76c9d2edfa6/aten/src/ATen/native/native_functions.yaml#L14901
# # forward declaration:
# # aten::_scaled_dot_product_attention_math(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0., bool is_causal=False, Tensor? dropout_mask=None, *, float? scale=None, bool enable_gqa=False) -> (Tensor, Tensor)
# # backward declaration:
# # does not exist
# @staticmethod
# def forward(
# ctx: torch.autograd.function.FunctionCtx,
# query: torch.Tensor,
# key: torch.Tensor,
# value: torch.Tensor,
# attn_mask: Optional[torch.Tensor] = None,
# dropout_p: float = 0.0,
# is_causal: bool = False,
# dropout_mask: Optional[torch.Tensor] = None,
# scale: Optional[float] = None,
# enable_gqa: bool = False,
# return_scores: bool = False,
# ):
# ctx.dropout_p = dropout_p
# ctx.is_causal = is_causal
# ctx.scale = scale
# ctx.enable_gqa = enable_gqa
# print(f"query.shape: {query.shape}")
# with torch.enable_grad():
# out, scores = torch.ops.aten._scaled_dot_product_attention_math(
# query=query,
# key=key,
# value=value,
# attn_mask=attn_mask,
# dropout_p=dropout_p,
# is_causal=is_causal,
# dropout_mask=dropout_mask,
# scale=scale,
# enable_gqa=enable_gqa,
# )
# ctx.save_for_backward(query, key, value, out)
# return (out, scores) if return_scores else out
# @staticmethod
# def backward(
# ctx: torch.autograd.function.FunctionCtx,
# grad_out: torch.Tensor,
# ):
# raise NotImplementedError("Backward pass for native math attention is not implemented.")
@_AttentionProviderRegistry.register(
AttentionProvider._NATIVE_MATH,
constraints=[_check_device, _check_shape],
supports_cp=False,
)
def native_math_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
) -> torch.Tensor:
with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH):
return native_sdpa(
query=query,
key=key,
value=value,
attn_mask=attn_mask,
dropout_p=dropout_p,
is_causal=is_causal,
scale=scale,
enable_gqa=enable_gqa,
)
@_AttentionProviderRegistry.register(
AttentionProvider.SAGE,
constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
supports_cp=False,
)
def _sage_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
is_causal: bool = False,
scale: Optional[float] = None,
return_lse: bool = False,
) -> torch.Tensor:
if _AttentionProviderRegistry.context_parallel_enabled():
return_lse = True
kwargs = {
"q": query,
"k": key,
"v": value,
"tensor_layout": "HND",
"is_causal": is_causal,
"sm_scale": scale,
"return_lse": return_lse,
}
out = sageattn(**kwargs)
rest = None
if return_lse:
out, *rest = out
if return_lse:
return out, *rest[:1]
return out
@_AttentionProviderRegistry.register(
AttentionProvider.SAGE_VARLEN,
constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
)
def _sage_varlen_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_k: Optional[torch.Tensor] = None,
max_seqlen_q: Optional[int] = None,
max_seqlen_k: Optional[int] = None,
is_causal: bool = False,
scale: Optional[float] = None,
smooth_k: bool = True,
attn_mask: Optional[torch.Tensor] = None,
enable_gqa: bool = False,
) -> torch.Tensor:
batch_size, _, seq_len_q, _ = query.shape
_, _, seq_len_kv, _ = key.shape
if attn_mask is not None:
attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
if enable_gqa:
# TODO
pass
if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
(_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
_prepare_for_flash_attn_or_sage_varlen(
batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
)
)
else:
seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device)
cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
key_valid, value_valid = [], []
for b in range(batch_size):
valid_len = seqlens_k[b]
key_valid.append(key[b, :valid_len])
value_valid.append(value[b, :valid_len])
query_packed = query.flatten(0, 1)
key_packed = torch.cat(key_valid, dim=0)
value_packed = torch.cat(value_valid, dim=0)
out = sageattn_varlen(
q=query_packed,
k=key_packed,
v=value_packed,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
is_causal=is_causal,
sm_scale=scale,
smooth_k=smooth_k,
)
out = out.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3) # .contiguous()
return out
@_AttentionProviderRegistry.register(
AttentionProvider._SAGE_QK_INT8_PV_FP8_CUDA,
constraints=[_check_device_cuda_atleast_smXY(9, 0), _check_shape],
supports_cp=False,
)
def _sage_qk_int8_pv_fp8_cuda_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
is_causal: bool = False,
scale: Optional[float] = None,
qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread",
pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32",
smooth_k: bool = True,
smooth_v: bool = False,
return_lse: bool = False,
) -> torch.Tensor:
return sageattn_qk_int8_pv_fp8_cuda(
q=query,
k=key,
v=value,
tensor_layout="HND",
is_causal=is_causal,
qk_quant_gran=qk_quant_gran,
sm_scale=scale,
pv_accum_dtype=pv_accum_dtype,
smooth_k=smooth_k,
smooth_v=smooth_v,
return_lse=return_lse,
)
@_AttentionProviderRegistry.register(
AttentionProvider._SAGE_QK_INT8_PV_FP8_CUDA_SM90,
constraints=[_check_device_cuda_atleast_smXY(9, 0), _check_shape],
supports_cp=False,
)
def _sage_qk_int8_pv_fp8_cuda_sm90_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
is_causal: bool = False,
scale: Optional[float] = None,
qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread",
pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32",
smooth_k: bool = True,
return_lse: bool = False,
) -> torch.Tensor:
return sageattn_qk_int8_pv_fp8_cuda_sm90(
q=query,
k=key,
v=value,
tensor_layout="HND",
is_causal=is_causal,
qk_quant_gran=qk_quant_gran,
sm_scale=scale,
pv_accum_dtype=pv_accum_dtype,
smooth_k=smooth_k,
return_lse=return_lse,
)
@_AttentionProviderRegistry.register(
AttentionProvider._SAGE_QK_INT8_PV_FP16_CUDA,
constraints=[_check_device_cuda_atleast_smXY(8, 0), _check_shape],
supports_cp=False,
)
def _sage_qk_int8_pv_fp16_cuda_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
is_causal: bool = False,
scale: Optional[float] = None,
qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread",
pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32",
smooth_k: bool = True,
smooth_v: bool = False,
return_lse: bool = False,
) -> torch.Tensor:
return sageattn_qk_int8_pv_fp16_cuda(
q=query,
k=key,
v=value,
tensor_layout="HND",
is_causal=is_causal,
qk_quant_gran=qk_quant_gran,
sm_scale=scale,
pv_accum_dtype=pv_accum_dtype,
smooth_k=smooth_k,
smooth_v=smooth_v,
return_lse=return_lse,
)
@_AttentionProviderRegistry.register(
AttentionProvider._SAGE_QK_INT8_PV_FP16_TRITON,
constraints=[_check_device_cuda_atleast_smXY(8, 0), _check_shape],
supports_cp=False,
)
def _sage_qk_int8_pv_fp16_triton_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
is_causal: bool = False,
scale: Optional[float] = None,
quantization_backend: _SAGE_ATTENTION_QUANTIZATION_BACKEND = "triton",
smooth_k: bool = True,
return_lse: bool = False,
) -> torch.Tensor:
return sageattn_qk_int8_pv_fp16_triton(
q=query,
k=key,
v=value,
tensor_layout="HND",
quantization_backend=quantization_backend,
is_causal=is_causal,
sm_scale=scale,
smooth_k=smooth_k,
return_lse=return_lse,
)
@_AttentionProviderRegistry.register(
AttentionProvider.XFORMERS,
constraints=[_check_attn_mask_or_causal, _check_device, _check_shape],
)
def _xformers_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
) -> torch.Tensor:
batch_size, num_heads_q, seq_len_q, _ = query.shape
_, num_heads_kv, seq_len_kv, _ = key.shape
# TODO: check if `contiguous` is really needed since it may cause unnecessary slowdowns
if is_causal:
attn_mask = xops.LowerTriangularMask()
elif attn_mask is not None:
if attn_mask.ndim == 2:
# A 2D mask is interpreted as [batch_size, seq_len_kv], so broadcast it over heads and queries.
attn_mask = attn_mask.view(attn_mask.size(0), 1, 1, attn_mask.size(1))
elif attn_mask.ndim != 4:
raise ValueError("Only 2D and 4D attention masks are supported for xformers attention.")
attn_mask = attn_mask.expand(batch_size, num_heads_q, seq_len_q, seq_len_kv).type_as(query)
# QKV need to be in [batch, seq_len, num_heads, head_dim] format for xformers
# query, key, value = (x.permute(0, 2, 1, 3).contiguous() for x in (query, key, value))
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
if enable_gqa:
if num_heads_q % num_heads_kv != 0:
raise ValueError("Number of heads in query must be divisible by number of heads in key/value.")
num_heads_per_group = num_heads_q // num_heads_kv
query = query.unflatten(2, (num_heads_kv, -1))
key = key.unflatten(2, (num_heads_kv, -1)).expand(-1, -1, -1, num_heads_per_group, -1)
value = value.unflatten(2, (num_heads_kv, -1)).expand(-1, -1, -1, num_heads_per_group, -1)
out = xops.memory_efficient_attention(query, key, value, attn_mask, dropout_p, scale)
if enable_gqa:
out = out.flatten(2, 3)
out = out.permute(0, 2, 1, 3) # .contiguous()
return out
Usage

import torch
torch.nn.functional.scaled_dot_product_attention = attention_dispatch

with attention_provider("flash_varlen"):
    model(...)
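
A slightly fuller sketch (the tensor shapes, dtype, and device are illustrative; inputs follow the [batch, heads, seq_len, head_dim] layout the providers expect, and the chosen provider's backend must be available):

import torch

query = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.bfloat16)
key = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.bfloat16)
value = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.bfloat16)

# Route every scaled_dot_product_attention call through the dispatcher
torch.nn.functional.scaled_dot_product_attention = attention_dispatch

with attention_provider(AttentionProvider._NATIVE_CUDNN):
    out = attention_dispatch(query, key, value)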