""" | |
test performance and correctness of ring attention vs. single gpu attention | |
torchrun --nproc-per-node 4 ring_attn.py | |
using 4 H100s I get: | |
Rank 0 single gpu attention: 261.78 ms | |
Rank 0 ring attention: 73.34 ms | |
""" | |
import os
import math
import torch
import torch.distributed as dist
import torch.nn.functional as F
from flash_attn.flash_attn_interface import _flash_attn_forward

N_ITERS = 5
B = 1       # batch size
H = 16      # number of heads
S = 108540  # sequence length
D = 128     # head dimension

flash_attn_kwargs = {
    'causal': False,
    'dropout_p': 0.0,
    'softmax_scale': 1 / math.sqrt(D),
    'window_size_left': -1,
    'window_size_right': -1,
    'softcap': 0.0,
    'alibi_slopes': None,
    'return_softmax': False
}
def update_out_and_lse(acc_out, new_out, acc_lse, new_lse):
    # our accumulator is B, H, S, D
    # but flash_attn returns B, S, H, D
    # so transpose new_out to match the accumulator
    new_out = new_out.transpose(1, 2)
    # add a trailing singleton dimension to new_lse (B, H, S -> B, H, S, 1) so that broadcasting works
    new_lse = new_lse.unsqueeze(-1)
    # from here:
    # https://github.com/zhuzilin/ring-flash-attention/blob/main/ring_flash_attn/utils.py#L33
    lse = acc_lse + torch.log(1 + torch.exp(new_lse - acc_lse))
    out = torch.exp(acc_lse - lse) * acc_out + torch.exp(new_lse - lse) * new_out
    return out, lse
def ring_attn_func(q, k, v):
    # permute q, k, v from B, H, S, D to B, S, H, D for compatibility with flash_attn
    q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
    b, s, h, d = q.shape
    my_rank = dist.get_rank()
    world_size = dist.get_world_size()

    # output accumulator
    acc_out = torch.zeros(b, h, s, d, dtype=q.dtype, device=q.device)
    # log sum exp accumulator (log of the softmax denominator)
    acc_lse = torch.zeros(b, h, s, 1, dtype=q.dtype, device=q.device)

    send_rank = (my_rank + 1) % world_size
    recv_rank = (my_rank - 1) % world_size
    last_iteration = False
    for i in range(world_size):
        if i == world_size - 1:
            last_iteration = True

        # stack k and v into a single tensor for convenience, so we only need to move one tensor rather than k and v individually
        kv_send = torch.stack([k, v], dim=0).contiguous()
        kv_recv = torch.zeros_like(kv_send)

        # asynchronous send to the next rank, asynchronous recv from the previous rank
        # you have to alternate the order like this, otherwise isend/irecv will deadlock
        if not last_iteration:
            if my_rank % 2 == 0:
                send_tag = dist.isend(kv_send, dst=send_rank)
                recv_tag = dist.irecv(kv_recv, src=recv_rank)
            else:
                recv_tag = dist.irecv(kv_recv, src=recv_rank)
                send_tag = dist.isend(kv_send, dst=send_rank)

        # use flash_attn rather than F.scaled_dot_product_attention because we need the log sum exp of each chunk
        # call flash_attn on the local slice of q, k, v
        out, lse, _, _ = _flash_attn_forward(q, k, v, **flash_attn_kwargs)

        # update accumulators using the flash attention update rule
        acc_out, acc_lse = update_out_and_lse(acc_out, out, acc_lse, lse)

        if not last_iteration:
            # wait for send and recv to complete
            # ideally by the time the flash attn call completes, the send/recv will also have completed, so we don't actually wait here
            send_tag.wait()
            recv_tag.wait()
            k, v = kv_recv.chunk(2, dim=0)
            k = k.squeeze(0)
            v = v.squeeze(0)

    # the accumulator is already B, H, S, D; make it contiguous and cast back to the input dtype
    acc_out = acc_out.contiguous().to(q.dtype)
    return acc_out
dist.init_process_group(backend='nccl')
torch.cuda.set_device(dist.get_rank())
torch.manual_seed(42)
world_size = dist.get_world_size()
rank = dist.get_rank()

##################################
# benchmark single gpu attention #
##################################

# because the seed is set, we will generate the same q, k, v on all ranks
q = torch.randn(B, H, S, D, dtype=torch.bfloat16, device='cuda')
k = torch.randn(B, H, S, D, dtype=torch.bfloat16, device='cuda')
v = torch.randn(B, H, S, D, dtype=torch.bfloat16, device='cuda')

# warmup
for i in range(N_ITERS):
    F.scaled_dot_product_attention(q, k, v)

# sync across all ranks
dist.barrier(device_ids=[rank])

# time N_ITERS of single gpu attention
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
for i in range(N_ITERS):
    local_out = F.scaled_dot_product_attention(q, k, v)
end_event.record()
torch.cuda.synchronize()
elapsed_time = start_event.elapsed_time(end_event) / N_ITERS
if rank == 0:
    print(f"Rank {dist.get_rank()} single gpu attention: {elapsed_time:.2f} ms")
dist.barrier(device_ids=[rank])
################################
# benchmark parallel attention #
################################
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

# split q, k, v evenly across all ranks along the sequence length dimension
# in preparation for ring attention
local_q = q.chunk(world_size, dim=2)[rank]
local_k = k.chunk(world_size, dim=2)[rank]
local_v = v.chunk(world_size, dim=2)[rank]

# warmup
for i in range(N_ITERS):
    ring_out = ring_attn_func(local_q, local_k, local_v)

# time N_ITERS of ring attention
start_event.record()
for i in range(N_ITERS):
    ring_out = ring_attn_func(local_q, local_k, local_v)
end_event.record()
torch.cuda.synchronize()
elapsed_time = start_event.elapsed_time(end_event) / N_ITERS
if rank == 0:
    print(f"Rank {dist.get_rank()} ring attention: {elapsed_time:.2f} ms")

# the output on each rank is split along the sequence length dimension
# so we need to gather and concatenate the outputs across all ranks
out = [torch.zeros_like(ring_out).contiguous() for _ in range(world_size)]
dist.all_gather(tensor_list=out, tensor=ring_out)

# now each rank has the full output, so we can compare to computing the full output on a single gpu
out = torch.cat(out, dim=2)
# since we are using flash attention style accumulation, results will differ slightly
assert torch.allclose(out, local_out, atol=1e-2)
dist.destroy_process_group()
What exactly is `update_out_and_lse` doing?
Ring attention and flash attention both compute attention block by block. The only difference is that in flash attention, a block is local to a particular streaming multiprocessor, whereas in ring attention a block is local to a particular GPU. So ring attention is basically a distributed version of flash attention. The genius of flash attention is that it allows us to compute the softmax denominator in a block by block manner, with only a single pass through each row, using some clever recurrence relations. We can use the exact same math in ring attention.
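To make that concrete (this is my own notation, not lifted from either paper): for a single query $q_i$, attention over keys $k_j$ and values $v_j$ is

$$\mathrm{Attn}(q_i) = \frac{\sum_j e^{s_{ij}} v_j}{\sum_j e^{s_{ij}}}, \qquad s_{ij} = \frac{q_i \cdot k_j}{\sqrt{D}},$$

and if the keys are partitioned into blocks $B_1, \dots, B_n$ (tiles in SRAM for flash attention, per-GPU shards for ring attention), both the numerator and the denominator split into per-block sums:

$$\sum_j e^{s_{ij}} v_j = \sum_{m=1}^{n} \sum_{j \in B_m} e^{s_{ij}} v_j, \qquad \sum_j e^{s_{ij}} = \sum_{m=1}^{n} \sum_{j \in B_m} e^{s_{ij}}.$$

So each block's contribution can be computed wherever that block lives and folded into a running total; the recurrence for doing this stably in log space is exactly what `update_out_and_lse` implements.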
According to the guy himself, the `lse` returned from `_flash_attn_forward` is the natural log of the softmax denominator. Specifically, if we are computing the softmax of a vector $x \in \mathbb{R}^n$, then

$$\mathrm{lse}(x) = \log \sum_{i=1}^{n} e^{x_i}.$$

Since we are computing attention one block at a time, we never have all the elements of $x$ at once, so we have to accumulate this quantity incrementally, folding in one block of scores at a time.
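As a quick sanity check of that claim, here is a sketch (assuming, as the script above does, that `_flash_attn_forward` accepts these keyword arguments and returns an `lse` of shape `(B, H, S)`) comparing the returned `lse` against `torch.logsumexp` over the raw scaled scores:

```python
import math
import torch
from flash_attn.flash_attn_interface import _flash_attn_forward

# tiny problem so we can materialize the full score matrix
B, H, S, D = 1, 2, 64, 32
q = torch.randn(B, S, H, D, dtype=torch.bfloat16, device='cuda')
k = torch.randn(B, S, H, D, dtype=torch.bfloat16, device='cuda')
v = torch.randn(B, S, H, D, dtype=torch.bfloat16, device='cuda')
scale = 1 / math.sqrt(D)

out, lse, _, _ = _flash_attn_forward(
    q, k, v, causal=False, dropout_p=0.0, softmax_scale=scale,
    window_size_left=-1, window_size_right=-1, softcap=0.0,
    alibi_slopes=None, return_softmax=False)

# reference: scaled scores in B, H, S, S layout, then log of the softmax denominator
scores = torch.einsum('bshd,bthd->bhst', q.float(), k.float()) * scale
lse_ref = torch.logsumexp(scores, dim=-1)   # B, H, S
torch.testing.assert_close(lse.float(), lse_ref, atol=1e-2, rtol=1e-2)
```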
```python
def update_out_and_lse(acc_out, new_out, acc_lse, new_lse):
    # our accumulator is B, H, S, D
    # but flash_attn returns B, S, H, D
    # so transpose new_out to match the accumulator
    new_out = new_out.transpose(1, 2)
    # add a trailing singleton dimension to new_lse so that broadcasting works
    new_lse = new_lse.unsqueeze(-1)
    lse = acc_lse + torch.log(1 + torch.exp(new_lse - acc_lse))
    out = torch.exp(acc_lse - lse) * acc_out + torch.exp(new_lse - lse) * new_out
    return out, lse
```
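Before worrying about the distributed part at all, you can sanity check this update rule on a single GPU: split `k` and `v` into chunks, fold the chunks in one at a time with `update_out_and_lse`, and compare against full attention. This is just a sketch; it reuses `update_out_and_lse` from above, assumes the same `_flash_attn_forward` call signature as the script, and uses the same loose tolerance the script does (bfloat16 plus the zero-initialized accumulators make the result approximate).

```python
import math
import torch
import torch.nn.functional as F
from flash_attn.flash_attn_interface import _flash_attn_forward

B, H, S, D, n_chunks = 1, 4, 1024, 64, 4
kwargs = dict(causal=False, dropout_p=0.0, softmax_scale=1 / math.sqrt(D),
              window_size_left=-1, window_size_right=-1, softcap=0.0,
              alibi_slopes=None, return_softmax=False)

# q, k, v in B, S, H, D layout, as flash_attn expects
q = torch.randn(B, S, H, D, dtype=torch.bfloat16, device='cuda')
k = torch.randn(B, S, H, D, dtype=torch.bfloat16, device='cuda')
v = torch.randn(B, S, H, D, dtype=torch.bfloat16, device='cuda')

# fold in one key/value chunk at a time, just like one ring iteration per chunk
acc_out = torch.zeros(B, H, S, D, dtype=q.dtype, device=q.device)
acc_lse = torch.zeros(B, H, S, 1, dtype=q.dtype, device=q.device)
for k_chunk, v_chunk in zip(k.chunk(n_chunks, dim=1), v.chunk(n_chunks, dim=1)):
    out, lse, _, _ = _flash_attn_forward(q, k_chunk, v_chunk, **kwargs)
    acc_out, acc_lse = update_out_and_lse(acc_out, out, acc_lse, lse)

# reference: attention over all keys at once, in B, H, S, D layout
ref = F.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2))
assert torch.allclose(acc_out, ref, atol=1e-2)
```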
Some Derivations
What is this line doing?

```python
lse = acc_lse + torch.log(1 + torch.exp(new_lse - acc_lse))
```

Let's call `acc_lse` $L_1$ and `new_lse` $L_2$: the natural logs of the softmax denominators of the blocks processed so far and of the new block, respectively. Then we want to compute the log of the combined denominator,

$$L = \log\left(e^{L_1} + e^{L_2}\right).$$

If our code (which I got from here) is correct, then it is computing

$$L = L_1 + \log\left(1 + e^{L_2 - L_1}\right).$$

Here is a derivation that shows it is computing what we want:

$$L_1 + \log\left(1 + e^{L_2 - L_1}\right)
= \log\left(e^{L_1}\right) + \log\left(1 + e^{L_2 - L_1}\right)
= \log\left(e^{L_1}\left(1 + e^{L_2 - L_1}\right)\right)
= \log\left(e^{L_1} + e^{L_2}\right)
= L.$$

So it's correct!
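And a quick numerical check of the same identity, in plain PyTorch with no flash_attn involved:

```python
import torch

x = torch.randn(1000, dtype=torch.float64)
x1, x2 = x.chunk(2)
L1, L2 = torch.logsumexp(x1, dim=0), torch.logsumexp(x2, dim=0)

merged = L1 + torch.log(1 + torch.exp(L2 - L1))   # the update rule
full = torch.logsumexp(x, dim=0)                  # log-sum-exp over everything at once
assert torch.allclose(merged, full)
```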
What about:

```python
out = torch.exp(acc_lse - lse) * acc_out + torch.exp(new_lse - lse) * new_out
```

This line is combining the output we have accumulated so far with the new output, using the LSE terms to reweight the contributions. Let's call `acc_out` $O_1$ and `new_out` $O_2$, write $s_i$ for the attention score of key $i$ against the query in question (so $B_1$ is the set of keys already processed and $B_2$ is the new block), and write $v_i$ for the corresponding value vectors. So the accumulated output is

$$O_1 = \frac{\sum_{i \in B_1} e^{s_i} v_i}{e^{L_1}},$$

and the new output is

$$O_2 = \frac{\sum_{i \in B_2} e^{s_i} v_i}{e^{L_2}},$$

and we want to compute the output over both blocks of keys,

$$O = \frac{\sum_{i \in B_1 \cup B_2} e^{s_i} v_i}{e^{L}}.$$

Our code claims that this is equivalent to:

$$O = e^{L_1 - L}\, O_1 + e^{L_2 - L}\, O_2.$$

Plugging in the LSE terms, we have:

$$O = \frac{e^{L_1}}{e^{L}}\, O_1 + \frac{e^{L_2}}{e^{L}}\, O_2,$$

and plugging in for $O_1$ and $O_2$:

$$O = \frac{e^{L_1}}{e^{L}} \cdot \frac{\sum_{i \in B_1} e^{s_i} v_i}{e^{L_1}} + \frac{e^{L_2}}{e^{L}} \cdot \frac{\sum_{i \in B_2} e^{s_i} v_i}{e^{L_2}},$$

and things cancel out to give us what we want:

$$O = \frac{\sum_{i \in B_1} e^{s_i} v_i + \sum_{i \in B_2} e^{s_i} v_i}{e^{L}} = \frac{\sum_{i \in B_1 \cup B_2} e^{s_i} v_i}{e^{L}}.$$
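Here too we can check numerically, on a tiny example with one query, eight keys, and 4-dimensional values:

```python
import torch

torch.manual_seed(0)
s = torch.randn(8, dtype=torch.float64)      # attention scores for one query
v = torch.randn(8, 4, dtype=torch.float64)   # one value vector per key
s1, s2 = s.chunk(2)
v1, v2 = v.chunk(2)

# per-block lse and per-block normalized outputs
L1, L2 = torch.logsumexp(s1, 0), torch.logsumexp(s2, 0)
O1 = (torch.exp(s1) @ v1) / torch.exp(L1)
O2 = (torch.exp(s2) @ v2) / torch.exp(L2)

# merge, exactly as in update_out_and_lse
L = L1 + torch.log(1 + torch.exp(L2 - L1))
O = torch.exp(L1 - L) * O1 + torch.exp(L2 - L) * O2

# reference: softmax over all scores at once
ref = torch.softmax(s, 0) @ v
assert torch.allclose(O, ref)
```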
conclusion
If you look at this comment on zhuzilin's implementation, you can see that these two lines

```python
lse = acc_lse + torch.log(1 + torch.exp(new_lse - acc_lse))
out = torch.exp(acc_lse - lse) * acc_out + torch.exp(new_lse - lse) * new_out
```

are equivalent to

```python
out = out - F.sigmoid(block_lse - lse) * (out - block_out)
lse = lse - F.logsigmoid(lse - block_lse)
```
This latter form is a bit better because:

- there is no `exp`, instead a `sigmoid`. The `sigmoid` has an `exp` inside it, but there are some tricks for computing it with maximum numerical stability, and we can assume that the `torch` implementation is using these tricks.
- it is probably a bit faster, because there is no data dependency between the `out` and `lse` updates, as there is in the version that we derived. So `out` and `lse` can be computed at the same time, which will utilize more of the GPU. But as the nsight systems trace shows, the compute from this update is tiny compared to the compute from the attention, so this doesn't really matter in the grand scheme of things.
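A quick check (plain PyTorch, per-query `lse` vectors and small value matrices) that the two updates really are numerically equivalent:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
lse, block_lse = torch.randn(16, dtype=torch.float64), torch.randn(16, dtype=torch.float64)
out, block_out = torch.randn(16, 4, dtype=torch.float64), torch.randn(16, 4, dtype=torch.float64)

# the form we derived above
new_lse = lse + torch.log(1 + torch.exp(block_lse - lse))
new_out = (torch.exp(lse - new_lse).unsqueeze(-1) * out
           + torch.exp(block_lse - new_lse).unsqueeze(-1) * block_out)

# the sigmoid / logsigmoid form from zhuzilin's implementation
out2 = out - torch.sigmoid(block_lse - lse).unsqueeze(-1) * (out - block_out)
lse2 = lse - F.logsigmoid(lse - block_lse)

assert torch.allclose(new_out, out2) and torch.allclose(new_lse, lse2)
```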
Using nsight systems, we can see the nccl kernels and the flash attention kernels on two separate streams, overlapping with each other. Compute is not stalled waiting for data to arrive, which means this is at least somewhat efficient. I'm not sure why the nccl operation on the first iteration is tiny compared to the ones on subsequent iterations.