qkvpacked flash attn: varlen vs padded, v2.6.3

import time
from itertools import accumulate

import numpy as np
import torch
from flash_attn import flash_attn_varlen_qkvpacked_func, flash_attn_qkvpacked_func

device = torch.device('cuda:0')
dtype = torch.bfloat16
num_head = 16
head_dim = 72
DEFAULT_AR_MAP = {
    "144p": (144, 256),
    "240p": (240, 426),
    "360p": (360, 640),
    "480p": (480, 854),
    "720p": (720, 1280),
    "1080p": (1080, 1920),
    "2k": (1440, 2560),
}
len_map = len(DEFAULT_AR_MAP)
keys = list(DEFAULT_AR_MAP.keys())
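
# For every pair of resolutions (i < j) and every combination of frame counts,
# the same workload is benchmarked two ways:
#   1. varlen:  all sequences packed back to back into one tensor and run
#      through flash_attn_varlen_qkvpacked_func with cumulative sequence lengths;
#   2. vanilla: one uniform-length batch per resolution run through
#      flash_attn_qkvpacked_func, with the two kernels' times summed.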
for i in range(len_map):
    for j in range(i+1, len_map):
        for nf1 in [1, 4, 8, 16]:
            for nf2 in [1, 4, 8, 16]:
                ar1 = DEFAULT_AR_MAP[keys[i]]
                ar2 = DEFAULT_AR_MAP[keys[j]]
                # tokens per sequence at each resolution
                seq_len = [
                    ar1[0]//16*ar1[1]//16,  # keys[i] (lower resolution)
                    ar2[0]//16*ar2[1]//16,  # keys[j] (higher resolution)
                ]
                bs = [nf1, nf2]  # number of sequences at each resolution
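
                # --- varlen path: concatenate every sequence and hand the kernel
                # the cumulative sequence lengths (cu_seqlens) ---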
                batch_lens = []
                for _len, _bs in zip(seq_len, bs):
                    batch_lens += [_len] * _bs
                max_seqlen = max(batch_lens)
                # cumulative lengths: [0, len_0, len_0 + len_1, ...]
                batch_lens = list(accumulate(batch_lens, initial=0))
                seqlens = torch.tensor(batch_lens, dtype=torch.int32, device=device)
                qkv_base = torch.empty((batch_lens[-1], 3, num_head, head_dim), dtype=dtype, device=device)
                grad_base = torch.empty((batch_lens[-1], num_head, head_dim), dtype=dtype, device=device)
                fwd_times, bwd_times = [], []
                # warmup
                for _ in range(2):
                    torch.rand(qkv_base.size(), device=device, dtype=dtype, out=qkv_base)
                    qkv = qkv_base.detach().requires_grad_(True)
                    torch.rand(grad_base.size(), device=device, dtype=dtype, out=grad_base)
                    grad = grad_base.detach()
                    assert qkv.requires_grad
                    out = flash_attn_varlen_qkvpacked_func(qkv, seqlens, max_seqlen)
                    out.backward(grad)
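
                # timed runs: refill the inputs, then time forward and backward
                # separately with explicit CUDA synchronization around each call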
                for _ in range(10):
                    torch.rand(qkv_base.size(), device=device, dtype=dtype, out=qkv_base)
                    qkv = qkv_base.detach().requires_grad_(True)
                    torch.rand(grad_base.size(), device=device, dtype=dtype, out=grad_base)
                    grad = grad_base.detach()

                    torch.cuda.synchronize()
                    start = time.time()
                    out = flash_attn_varlen_qkvpacked_func(qkv, seqlens, max_seqlen)
                    torch.cuda.synchronize()
                    fwd_times.append(time.time() - start)

                    start = time.time()
                    out.backward(grad)
                    torch.cuda.synchronize()
                    bwd_times.append(time.time() - start)
                # report seconds as mean ~ std over the 10 timed runs
                print(f"({keys[i]}*{nf1}, {keys[j]}*{nf2}) varlen:\n"
                      f" - fwd: {np.mean(fwd_times):.4f} ~ {np.std(fwd_times):.4f}\n"
                      f" - bwd: {np.mean(bwd_times):.4f} ~ {np.std(bwd_times):.4f}")
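
                # --- vanilla path: run each resolution as its own uniform-length
                # batch through flash_attn_qkvpacked_func and sum the two times ---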
                qkv1_base = torch.empty((bs[0], seq_len[0], 3, num_head, head_dim), dtype=dtype, device=device)
                qkv2_base = torch.empty((bs[1], seq_len[1], 3, num_head, head_dim), dtype=dtype, device=device)
                grad1_base = torch.empty((bs[0], seq_len[0], num_head, head_dim), dtype=dtype, device=device)
                grad2_base = torch.empty((bs[1], seq_len[1], num_head, head_dim), dtype=dtype, device=device)
                fwd_times, bwd_times = [], []
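                # warmup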
                for _ in range(2):
                    torch.rand(qkv1_base.size(), device=device, dtype=dtype, out=qkv1_base)
                    torch.rand(qkv2_base.size(), device=device, dtype=dtype, out=qkv2_base)
                    torch.rand(grad1_base.size(), device=device, dtype=dtype, out=grad1_base)
                    torch.rand(grad2_base.size(), device=device, dtype=dtype, out=grad2_base)
                    qkv1 = qkv1_base.detach().requires_grad_(True)
                    qkv2 = qkv2_base.detach().requires_grad_(True)
                    grad1 = grad1_base.detach()
                    grad2 = grad2_base.detach()
                    out = flash_attn_qkvpacked_func(qkv1)
                    out.backward(grad1)
                    out = flash_attn_qkvpacked_func(qkv2)
                    out.backward(grad2)
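
                # timed runs: per iteration, the forward/backward times of the two
                # batches are accumulated into a single fwd_time / bwd_time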
                for _ in range(10):
                    torch.rand(qkv1_base.size(), device=device, dtype=dtype, out=qkv1_base)
                    torch.rand(qkv2_base.size(), device=device, dtype=dtype, out=qkv2_base)
                    torch.rand(grad1_base.size(), device=device, dtype=dtype, out=grad1_base)
                    torch.rand(grad2_base.size(), device=device, dtype=dtype, out=grad2_base)
                    qkv1 = qkv1_base.detach().requires_grad_(True)
                    qkv2 = qkv2_base.detach().requires_grad_(True)
                    grad1 = grad1_base.detach()
                    grad2 = grad2_base.detach()

                    torch.cuda.synchronize()
                    start = time.time()
                    out = flash_attn_qkvpacked_func(qkv1)
                    torch.cuda.synchronize()
                    fwd_time = time.time() - start

                    start = time.time()
                    out.backward(grad1)
                    torch.cuda.synchronize()
                    bwd_time = time.time() - start

                    torch.cuda.synchronize()
                    start = time.time()
                    out = flash_attn_qkvpacked_func(qkv2)
                    torch.cuda.synchronize()
                    fwd_time += time.time() - start

                    start = time.time()
                    out.backward(grad2)
                    torch.cuda.synchronize()
                    bwd_time += time.time() - start

                    fwd_times.append(fwd_time)
                    bwd_times.append(bwd_time)
                print(f"({keys[i]}*{nf1}, {keys[j]}*{nf2}) vanilla:\n"
                      f" - fwd: {np.mean(fwd_times):.4f} ~ {np.std(fwd_times):.4f}\n"
                      f" - bwd: {np.mean(bwd_times):.4f} ~ {np.std(bwd_times):.4f}")