@zxgx
Last active September 19, 2024 19:32
qkvpacked flash attn: varlen vs padded, v2.6.3
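Benchmarks flash_attn_varlen_qkvpacked_func (all sequences packed into one ragged batch described by cumulative sequence lengths) against flash_attn_qkvpacked_func called once per fixed-length batch, sweeping pairs of video resolutions and frame counts. Each configuration reports mean ~ std of forward and backward wall-clock time over 10 timed iterations (after 2 warmup iterations) on cuda:0 in bfloat16.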
import time
import numpy as np
import torch
from itertools import accumulate
from flash_attn import flash_attn_varlen_qkvpacked_func, flash_attn_qkvpacked_func
device = torch.device('cuda:0')
dtype = torch.bfloat16
num_head = 16
head_dim = 72
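# (height, width) per resolution label; a frame at (H, W) contributes H//16 * W//16 tokens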
DEFAULT_AR_MAP = {
    "144p": (144, 256),
    "240p": (240, 426),
    "360p": (360, 640),
    "480p": (480, 854),
    "720p": (720, 1280),
    "1080p": (1080, 1920),
    "2k": (1440, 2560),
}
len_map = len(DEFAULT_AR_MAP)
keys = list(DEFAULT_AR_MAP.keys())
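# Sweep every pair of distinct resolutions, crossed with frame counts nf1 and nf2, so each
# case mixes nf1 frames of the smaller resolution with nf2 frames of the larger one.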
for i in range(len_map):
    for j in range(i+1, len_map):
        for nf1 in [1, 4, 8, 16]:
            for nf2 in [1, 4, 8, 16]:
                ar1 = DEFAULT_AR_MAP[keys[i]]
                ar2 = DEFAULT_AR_MAP[keys[j]]
                seq_len = [
                    ar1[0]//16 * ar1[1]//16,  # tokens per frame at the lower resolution
                    ar2[0]//16 * ar2[1]//16,  # tokens per frame at the higher resolution
                ]
                bs = [nf1, nf2]

                # varlen path: pack all frames into one ragged batch described by cumulative seqlens
                batch_lens = []
                for _len, _bs in zip(seq_len, bs):
                    batch_lens += [_len] * _bs
                max_seqlen = max(batch_lens)
                batch_lens = list(accumulate(batch_lens, initial=0))
                seqlens = torch.tensor(batch_lens, dtype=torch.int32, device=device)
                qkv_base = torch.empty((seqlens[-1], 3, num_head, head_dim), dtype=dtype, device=device)
                grad_base = torch.empty((seqlens[-1], num_head, head_dim), dtype=dtype, device=device)
                fwd_times, bwd_times = [], []
                # warmup
                for _ in range(2):
                    torch.rand(qkv_base.size(), device=device, dtype=dtype, out=qkv_base)
                    qkv = qkv_base.detach().requires_grad_(True)
                    torch.rand(grad_base.size(), device=device, dtype=dtype, out=grad_base)
                    grad = grad_base.detach()
                    assert qkv.requires_grad
                    out = flash_attn_varlen_qkvpacked_func(qkv, seqlens, max_seqlen)
                    out.backward(grad)
                # timed runs
                for _ in range(10):
                    torch.rand(qkv_base.size(), device=device, dtype=dtype, out=qkv_base)
                    qkv = qkv_base.detach().requires_grad_(True)
                    torch.rand(grad_base.size(), device=device, dtype=dtype, out=grad_base)
                    grad = grad_base.detach()
                    torch.cuda.synchronize()
                    start = time.time()
                    out = flash_attn_varlen_qkvpacked_func(qkv, seqlens, max_seqlen)
                    torch.cuda.synchronize()
                    fwd_times.append(time.time() - start)
                    start = time.time()
                    out.backward(grad)
                    torch.cuda.synchronize()
                    bwd_times.append(time.time() - start)
                print(f"({keys[i]}*{nf1}, {keys[j]}*{nf2}) varlen:\n - fwd: {np.mean(fwd_times):.4f} ~ {np.std(fwd_times):.4f}\n"
                      f" - bwd: {np.mean(bwd_times):.4f} ~ {np.std(bwd_times):.4f}")
                # padded/vanilla path: one fixed-length flash_attn_qkvpacked_func call per resolution
                qkv1_base = torch.empty((bs[0], seq_len[0], 3, num_head, head_dim), dtype=dtype, device=device)
                qkv2_base = torch.empty((bs[1], seq_len[1], 3, num_head, head_dim), dtype=dtype, device=device)
                grad1_base = torch.empty((bs[0], seq_len[0], num_head, head_dim), dtype=dtype, device=device)
                grad2_base = torch.empty((bs[1], seq_len[1], num_head, head_dim), dtype=dtype, device=device)
                fwd_times, bwd_times = [], []
                # warmup
                for _ in range(2):
                    torch.rand(qkv1_base.size(), device=device, dtype=dtype, out=qkv1_base)
                    torch.rand(qkv2_base.size(), device=device, dtype=dtype, out=qkv2_base)
                    torch.rand(grad1_base.size(), device=device, dtype=dtype, out=grad1_base)
                    torch.rand(grad2_base.size(), device=device, dtype=dtype, out=grad2_base)
                    qkv1 = qkv1_base.detach().requires_grad_(True)
                    qkv2 = qkv2_base.detach().requires_grad_(True)
                    grad1 = grad1_base.detach()
                    grad2 = grad2_base.detach()
                    out = flash_attn_qkvpacked_func(qkv1)
                    out.backward(grad1)
                    out = flash_attn_qkvpacked_func(qkv2)
                    out.backward(grad2)
                # timed runs: sum the two per-resolution calls so totals are comparable to the varlen path
                for _ in range(10):
                    torch.rand(qkv1_base.size(), device=device, dtype=dtype, out=qkv1_base)
                    torch.rand(qkv2_base.size(), device=device, dtype=dtype, out=qkv2_base)
                    torch.rand(grad1_base.size(), device=device, dtype=dtype, out=grad1_base)
                    torch.rand(grad2_base.size(), device=device, dtype=dtype, out=grad2_base)
                    qkv1 = qkv1_base.detach().requires_grad_(True)
                    qkv2 = qkv2_base.detach().requires_grad_(True)
                    grad1 = grad1_base.detach()
                    grad2 = grad2_base.detach()
                    torch.cuda.synchronize()
                    start = time.time()
                    out = flash_attn_qkvpacked_func(qkv1)
                    torch.cuda.synchronize()
                    fwd_time = time.time() - start
                    start = time.time()
                    out.backward(grad1)
                    torch.cuda.synchronize()
                    bwd_time = time.time() - start
                    torch.cuda.synchronize()
                    start = time.time()
                    out = flash_attn_qkvpacked_func(qkv2)
                    torch.cuda.synchronize()
                    fwd_time += time.time() - start
                    start = time.time()
                    out.backward(grad2)
                    torch.cuda.synchronize()
                    bwd_time += time.time() - start
                    fwd_times.append(fwd_time)
                    bwd_times.append(bwd_time)
                print(f"({keys[i]}*{nf1}, {keys[j]}*{nf2}) vanilla:\n - fwd: {np.mean(fwd_times):.4f} ~ {np.std(fwd_times):.4f}\n"
                      f" - bwd: {np.mean(bwd_times):.4f} ~ {np.std(bwd_times):.4f}")