A minimal correctness test and benchmark of Ulysses-style parallel attention from ParaAttention.
""" | |
test performance and correctness of ulysses parallel attention vs single gpu attention | |
torchrun --nproc-per-node 2 benchmark_attn.py | |
using two H100s I get: | |
Rank 0 single gpu attention: 1698.14 ms | |
Rank 0 ulysses attention: 912.84 ms | |
running pip install para-attn should install everything needed | |
""" | |
import torch | |
import torch.distributed as dist | |
import torch.nn.functional as F | |
from para_attn.para_attn_interface import ulysses_attn_func | |
dist.init_process_group(backend='nccl') | |
torch.cuda.set_device(dist.get_rank()) | |
torch.manual_seed(42) | |
world_size = dist.get_world_size() | |
rank = dist.get_rank() | |
N_ITERS = 5 | |
B = 1 | |
H = 12 | |
S = 108540 | |
D = 128 | |
################################## | |
# benchmark single gpu attention # | |
################################## | |
print(f"Rank {dist.get_rank()} on device {torch.cuda.current_device()}") | |
# because the seed is set, we will generate the same q, k, v on all ranks | |
q = torch.randn(B, H, S, D, device='cuda') | |
k = torch.randn(B, H, S, D, device='cuda') | |
v = torch.randn(B, H, S, D, device='cuda') | |
# warmup | |
for i in range(N_ITERS): | |
F.scaled_dot_product_attention(q, k, v) | |
# sync across all ranks | |
dist.barrier() | |
# time N_ITERS of single gpu attention | |
start_event = torch.cuda.Event(enable_timing=True) | |
end_event = torch.cuda.Event(enable_timing=True) | |
start_event.record() | |
for i in range(N_ITERS): | |
local_out = F.scaled_dot_product_attention(q, k, v) | |
end_event.record() | |
torch.cuda.synchronize() | |
elapsed_time = start_event.elapsed_time(end_event) / N_ITERS | |
if rank == 0: | |
print(f"Rank {dist.get_rank()} single gpu attention: {elapsed_time:.2f} ms") | |
dist.barrier() | |
################################ | |
# benchmark parallel attention # | |
################################ | |
del start_event, end_event | |
start_event = torch.cuda.Event(enable_timing=True) | |
end_event = torch.cuda.Event(enable_timing=True) | |
# split q, k, v evenly across all ranks along the sequence length dimension | |
# in preparation for ulysses attention | |
local_q = q.chunk(world_size, dim=2)[rank] | |
local_k = k.chunk(world_size, dim=2)[rank] | |
local_v = v.chunk(world_size, dim=2)[rank] | |
# time N_ITERS of ulysses attention | |
start_event.record() | |
for i in range(N_ITERS): | |
ulysses_out = ulysses_attn_func(local_q, local_k, local_v) | |
end_event.record() | |
torch.cuda.synchronize() | |
elapsed_time = start_event.elapsed_time(end_event) / N_ITERS | |
if rank == 0: | |
print(f"Rank {dist.get_rank()} ulysses attention: {elapsed_time:.2f} ms") | |
# the output on rank is split along the sequence length dimension | |
# so we need to gather and concatenate the outputs across all ranks | |
out = [torch.zeros_like(ulysses_out) for _ in range(world_size)] | |
dist.all_gather(tensor_list=out, tensor=ulysses_out) | |
# now each rank has the full output, so we can compare to computing the full output on a single gpu | |
out = torch.cat(out, dim=2) | |
assert torch.allclose(out, local_out) |
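For reference, here is a minimal sketch of the Ulysses all-to-all pattern that `ulysses_attn_func` is built around. This is not ParaAttention's implementation; `naive_ulysses_attn` and its helpers are hypothetical names, and the sketch assumes H and S are both divisible by the world size (true for the shapes above with 2 GPUs).

def naive_ulysses_attn(q, k, v):
    """q, k, v: [B, H, S / world_size, D], sharded along the sequence dim."""
    ws = dist.get_world_size()

    def seq_shard_to_head_shard(x):
        # [B, H, S/ws, D] -> [B, H/ws, S, D]
        # each rank sends one block of heads to every other rank and receives
        # every rank's sequence shard for its own block of heads
        send = [c.contiguous() for c in x.chunk(ws, dim=1)]
        recv = [torch.empty_like(c) for c in send]
        dist.all_to_all(recv, send)
        return torch.cat(recv, dim=2)

    def head_shard_to_seq_shard(x):
        # [B, H/ws, S, D] -> [B, H, S/ws, D], the inverse all-to-all
        send = [c.contiguous() for c in x.chunk(ws, dim=2)]
        recv = [torch.empty_like(c) for c in send]
        dist.all_to_all(recv, send)
        return torch.cat(recv, dim=1)

    # after the first all-to-all each rank holds H/ws heads with the full
    # sequence, so ordinary attention over the full sequence is exact
    q, k, v = map(seq_shard_to_head_shard, (q, k, v))
    o = F.scaled_dot_product_attention(q, k, v)
    return head_shard_to_seq_shard(o)

Swapping this sketch in for `ulysses_attn_func` should give the same gathered output, since every head still attends over the full sequence; only the heads are partitioned across GPUs while the attention kernel runs.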