@huseinzol05
Last active February 6, 2025 06:23
import torch
import torch.nn.functional as F
from torch.nn.attention.flex_attention import flex_attention, create_block_mask
head_num = 16
dim = 128
seq_len = 100
chunk_size = 5
batch_size = 1
q = torch.randn(batch_size, head_num, seq_len, dim, requires_grad=True).cuda().to(torch.bfloat16)
k = torch.randn(batch_size, head_num, seq_len, dim, requires_grad=True).cuda().to(torch.bfloat16)
v = torch.randn(batch_size, head_num, seq_len, dim, requires_grad=True).cuda().to(torch.bfloat16)
g0 = torch.randn(batch_size, head_num, seq_len, dim, dtype=torch.bfloat16, device='cuda')
g1 = torch.randn(batch_size, head_num, seq_len, dtype=torch.bfloat16, device='cuda')
q.retain_grad()
k.retain_grad()
v.retain_grad()
actual_out, actual_lse = flex_attention(q, k, v, block_mask = None, return_lse=True)
# Compare the tensors saved for backward against the values returned by flex_attention.
print((actual_out.grad_fn.saved_tensors[3] == actual_out).float().mean())  # fraction of entries matching the returned output
print((actual_out.grad_fn.saved_tensors[4] == actual_lse).float().mean())  # fraction of entries matching the returned LSE
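# g0 and g1 defined above match the shapes of actual_out and actual_lse, so they look like
# upstream gradients. A minimal backward sketch under that assumption (not part of the
# original gist; g1 would be the analogous upstream gradient for actual_lse):
actual_out.backward(g0, retain_graph=True)  # retain the graph so saved_tensors stays inspectable
print(q.grad.shape, k.grad.shape, v.grad.shape)  # populated thanks to the retain_grad() calls above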
huseinzol05 commented Feb 6, 2025

[Screenshot: notebook output, 2025-02-06]

actual_out.grad_fn.saved_tensors[4] != actual_lse: the logsumexp saved for backward does not match the LSE returned with return_lse=True, even though saved_tensors[3] matches the returned output exactly.
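One plausible explanation (my assumption, not something stated in the gist): the flex_attention kernel accumulates the logsumexp in base 2, and the Python wrapper rescales it by ln 2 before returning it when return_lse=True. A quick check under that assumption:

import math
saved_lse = actual_out.grad_fn.saved_tensors[4]
# If the saved LSE is in base 2, rescaling by ln(2) should recover the returned value.
print(torch.allclose(saved_lse * math.log(2), actual_lse))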
