
@gau-nernst
Created August 25, 2025 12:23
PyTorch's built-in varlen attention
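
This snippet calls PyTorch's private torch.ops.aten._flash_attention_forward op directly to get variable-length (packed) attention without the external flash_attn package, then checks its forward output and gradients against flash_attn.flash_attn_varlen_func.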
import torch
from torch import Tensor


def varlen_attn(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    cum_seq_q: Tensor,
    cum_seq_k: Tensor,
    max_q: int,
    max_k: int,
    dropout_p: float = 0.0,
    is_causal: bool = False,
) -> Tensor:
    # _flash_attention_forward is a private ATen op, so its signature may
    # change between PyTorch versions. Its output supports autograd: the
    # test below backprops through it and compares gradients.
    output, softmax_lse, rng_state, _, _ = torch.ops.aten._flash_attention_forward(
        query,
        key,
        value,
        cum_seq_q,
        cum_seq_k,
        max_q,
        max_k,
        dropout_p,
        is_causal,
        return_debug_mask=False,
    )
    return output
if __name__ == "__main__":
    import flash_attn

    num_heads_q = 4
    num_heads_k = 2  # fewer KV heads than query heads, i.e. GQA
    DIM = 128

    # Build 5 packed sequences of random lengths, plus the cumulative
    # sequence-length offsets that varlen attention expects.
    q_list = []
    k_list = []
    v_list = []
    offsets_q = [0]
    offsets_k = [0]
    max_q = 0
    max_k = 0
    for _ in range(5):
        len_q, len_k = torch.randint(16, 1024, size=(2,)).tolist()
        q_list.append(torch.randn(len_q, num_heads_q, DIM, device="cuda", dtype=torch.bfloat16))
        k_list.append(torch.randn(len_k, num_heads_k, DIM, device="cuda", dtype=torch.bfloat16))
        v_list.append(torch.randn(len_k, num_heads_k, DIM, device="cuda", dtype=torch.bfloat16))
        offsets_q.append(offsets_q[-1] + len_q)
        offsets_k.append(offsets_k[-1] + len_k)
        max_q = max(max_q, len_q)
        max_k = max(max_k, len_k)

    q = torch.cat(q_list).requires_grad_()
    k = torch.cat(k_list).requires_grad_()
    v = torch.cat(v_list).requires_grad_()
    cu_q = torch.tensor(offsets_q, device="cuda", dtype=torch.int32)
    cu_k = torch.tensor(offsets_k, device="cuda", dtype=torch.int32)

    # Leaf copies of the same data for the flash_attn reference.
    q_ref = q.detach().requires_grad_()
    k_ref = k.detach().requires_grad_()
    v_ref = v.detach().requires_grad_()

    out = varlen_attn(q, k, v, cu_q, cu_k, max_q, max_k)
    out_ref = flash_attn.flash_attn_varlen_func(q_ref, k_ref, v_ref, cu_q, cu_k, max_q, max_k)

    # Backprop the same upstream gradient through both implementations.
    grad = torch.randn_like(out)
    out.backward(grad)
    out_ref.backward(grad)
    @torch.no_grad()
    def check(out: Tensor, ref: Tensor):
        diff = out.float() - ref.float()
        rel_diff = diff.abs() / ref.abs().clip(1e-4)
        mean_rel_diff = rel_diff.mean().item()
        max_rel_diff = rel_diff.max().item()
        pct = (rel_diff < 1e-6).float().mean()
        print(f"{mean_rel_diff=:.2e}, {max_rel_diff=:.2e}, {pct * 100:.2f}% elements have relative error <1e-6")

    check(out, out_ref)        # mean_rel_diff=8.22e-07, max_rel_diff=3.00e-01, 99.99% elements have relative error <1e-6
    check(q.grad, q_ref.grad)  # mean_rel_diff=4.00e-06, max_rel_diff=1.46e-01, 99.95% elements have relative error <1e-6
    check(k.grad, k_ref.grad)  # mean_rel_diff=1.10e-05, max_rel_diff=6.10e-01, 99.92% elements have relative error <1e-6
    check(v.grad, v_ref.grad)  # mean_rel_diff=7.12e-06, max_rel_diff=3.33e-01, 99.95% elements have relative error <1e-6
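
For a dependency-free sanity check (not part of the original gist), the same packed layout can also be compared against plain torch.nn.functional.scaled_dot_product_attention applied per sequence. varlen_attn_ref below is a hypothetical helper sketch: it slices each sequence out of the packed tensors and repeats the KV heads to match the query heads for GQA.

import torch
import torch.nn.functional as F
from torch import Tensor

def varlen_attn_ref(q: Tensor, k: Tensor, v: Tensor, cu_q: Tensor, cu_k: Tensor, is_causal: bool = False) -> Tensor:
    # Naive reference: run standard SDPA on each packed sequence slice.
    outputs = []
    for i in range(cu_q.numel() - 1):
        qi = q[cu_q[i] : cu_q[i + 1]]  # (len_q, num_heads_q, DIM)
        ki = k[cu_k[i] : cu_k[i + 1]]  # (len_k, num_heads_k, DIM)
        vi = v[cu_k[i] : cu_k[i + 1]]
        # GQA: repeat KV heads so the head counts match.
        rep = qi.shape[1] // ki.shape[1]
        ki = ki.repeat_interleave(rep, dim=1)
        vi = vi.repeat_interleave(rep, dim=1)
        # SDPA expects (..., heads, seq, dim), so swap the first two dims.
        oi = F.scaled_dot_product_attention(
            qi.transpose(0, 1), ki.transpose(0, 1), vi.transpose(0, 1),
            is_causal=is_causal,
        ).transpose(0, 1)
        outputs.append(oi)
    return torch.cat(outputs)

With the tensors from the test above, check(out, varlen_attn_ref(q, k, v, cu_q, cu_k)) should show relative errors of the same order as the flash_attn comparison, up to bf16 rounding.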