@alexarmbr
Created March 5, 2025 18:47
a minimal correctness test and benchmark of Ulysses-style parallel attention from ParaAttention
"""
test the performance and correctness of ulysses-style parallel attention vs. single-gpu attention

run with:
    torchrun --nproc-per-node 2 benchmark_attn.py

on two H100s I get:
    Rank 0 single gpu attention: 1698.14 ms
    Rank 0 ulysses attention: 912.84 ms

`pip install para-attn` should install everything needed
"""
import torch
import torch.distributed as dist
import torch.nn.functional as F
from para_attn.para_attn_interface import ulysses_attn_func
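# ulysses-style attention, roughly: each rank starts with a shard of the sequence, an
# all-to-all exchanges it for a shard of the attention heads so every rank sees the full
# sequence for a subset of heads, attention is computed locally, then a second all-to-all
# restores the sequence sharding of the output
# (a sketch of the general idea; see para_attn for the actual implementation)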
dist.init_process_group(backend='nccl')
torch.cuda.set_device(dist.get_rank())
torch.manual_seed(42)
world_size = dist.get_world_size()
rank = dist.get_rank()
N_ITERS = 5
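# q, k, v below have shape (B, H, S, D): batch size, number of attention heads, sequence length, head dim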
B = 1
H = 12
S = 108540
D = 128
##################################
# benchmark single gpu attention #
##################################
print(f"Rank {dist.get_rank()} on device {torch.cuda.current_device()}")
# because the seed is set, we will generate the same q, k, v on all ranks
q = torch.randn(B, H, S, D, device='cuda')
k = torch.randn(B, H, S, D, device='cuda')
v = torch.randn(B, H, S, D, device='cuda')
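# note: in fp32 each of q, k, v is B * H * S * D * 4 bytes ~= 0.62 GiB, so for the single gpu
# baseline every rank holds the full-sequence tensors (plus the attention output) in memory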
# warmup
for i in range(N_ITERS):
    F.scaled_dot_product_attention(q, k, v)
# sync across all ranks
dist.barrier()
# time N_ITERS of single gpu attention
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
for i in range(N_ITERS):
    local_out = F.scaled_dot_product_attention(q, k, v)
end_event.record()
torch.cuda.synchronize()
elapsed_time = start_event.elapsed_time(end_event) / N_ITERS
if rank == 0:
    print(f"Rank {dist.get_rank()} single gpu attention: {elapsed_time:.2f} ms")
dist.barrier()
################################
# benchmark parallel attention #
################################
del start_event, end_event
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
# split q, k, v evenly across all ranks along the sequence length dimension
# in preparation for ulysses attention
local_q = q.chunk(world_size, dim=2)[rank]
local_k = k.chunk(world_size, dim=2)[rank]
local_v = v.chunk(world_size, dim=2)[rank]
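# each rank now holds a (B, H, S // world_size, D) shard; this assumes S divides evenly by
# world_size (it does here: 108540 / 2 = 54270), which also keeps the all_gather below simple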
# time N_ITERS of ulysses attention
start_event.record()
for i in range(N_ITERS):
    ulysses_out = ulysses_attn_func(local_q, local_k, local_v)
end_event.record()
torch.cuda.synchronize()
elapsed_time = start_event.elapsed_time(end_event) / N_ITERS
if rank == 0:
    print(f"Rank {dist.get_rank()} ulysses attention: {elapsed_time:.2f} ms")
# the output on each rank is split along the sequence length dimension
# so we need to gather and concatenate the outputs across all ranks
out = [torch.zeros_like(ulysses_out) for _ in range(world_size)]
dist.all_gather(tensor_list=out, tensor=ulysses_out)
# now each rank has the full output, so we can compare to computing the full output on a single gpu
out = torch.cat(out, dim=2)
assert torch.allclose(out, local_out)
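# optional wrap-up: report success on rank 0 and tear down the process group cleanly
if rank == 0:
    print("ulysses attention matches single gpu attention")
dist.destroy_process_group()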