@alexarmbr
Last active March 15, 2025 05:59
Ring-Flash Attention
"""
test performance and correctness of ring attention vs. single gpu attention
torchrun --nproc-per-node 4 ring_attn.py
using 4 H100s I get:
Rank 0 single gpu attention: 261.78 ms
Rank 0 ring attention: 73.34 ms
"""
import os
import math
import torch
import torch.distributed as dist
import torch.nn.functional as F
from flash_attn.flash_attn_interface import _flash_attn_forward
N_ITERS = 5
B = 1 # batch size
H = 16 # number of heads
S = 108540 # sequence length
D = 128 # head dimension
flash_attn_kwargs = {
    'causal': False,
    'dropout_p': 0.0,
    'softmax_scale': 1 / math.sqrt(D),
    'window_size_left': -1,
    'window_size_right': -1,
    'softcap': 0.0,
    'alibi_slopes': None,
    'return_softmax': False
}
def update_out_and_lse(acc_out, new_out, acc_lse, new_lse):
    # our accumulator is B, H, S, D
    # but flash_attn returns B, S, H, D
    # so transpose new_out to match the accumulator
    new_out = new_out.transpose(1, 2)
    # add a dummy trailing head_dim dimension to the new lse so that broadcasting works
    new_lse = new_lse.unsqueeze(-1)
    # from here:
    # https://github.com/zhuzilin/ring-flash-attention/blob/main/ring_flash_attn/utils.py#L33
    lse = acc_lse + torch.log(1 + torch.exp(new_lse - acc_lse))
    out = torch.exp(acc_lse - lse) * acc_out + torch.exp(new_lse - lse) * new_out
    return out, lse
def ring_attn_func(q, k, v):
    # permute q, k, v from B, H, S, D to B, S, H, D for compatibility with flash_attn
    q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
    b, s, h, d = q.shape
    my_rank = dist.get_rank()
    world_size = dist.get_world_size()
    # output accumulator
    acc_out = torch.zeros(b, h, s, d, dtype=q.dtype, device=q.device)
    # log sum exp accumulator (log of softmax denominator)
    acc_lse = torch.zeros(b, h, s, 1, dtype=q.dtype, device=q.device)
    send_rank = (my_rank + 1) % world_size
    recv_rank = (my_rank - 1) % world_size
    last_iteration = False
    for i in range(world_size):
        if i == world_size - 1:
            last_iteration = True
        # stack k and v for convenience, so we only need to move one tensor rather than k and v individually
        kv_send = torch.stack([k, v], dim=0).contiguous()
        kv_recv = torch.zeros_like(kv_send)
        # asynchronous send to the next rank, asynchronous recv from the previous rank
        # you have to alternate like this, otherwise isend/irecv will deadlock
        if not last_iteration:
            if my_rank % 2 == 0:
                send_tag = dist.isend(kv_send, dst=send_rank)
                recv_tag = dist.irecv(kv_recv, src=recv_rank)
            else:
                recv_tag = dist.irecv(kv_recv, src=recv_rank)
                send_tag = dist.isend(kv_send, dst=send_rank)
        # use flash_attn rather than F.scaled_dot_product_attention because we need the log sum exp of each chunk
        # call flash_attn on the local slice of q, k, v
        out, lse, _, _ = _flash_attn_forward(q, k, v, **flash_attn_kwargs)
        # update accumulators using the flash attention update rule
        acc_out, acc_lse = update_out_and_lse(acc_out, out, acc_lse, lse)
        if not last_iteration:
            # wait for send and recv to complete
            # ideally by the time the flash attn call completes, the send/recv will have completed, so we don't need to wait
            send_tag.wait()
            recv_tag.wait()
            k, v = kv_recv.chunk(2, dim=0)
            k = k.squeeze(0)
            v = v.squeeze(0)
    # the accumulator is already in B, H, S, D layout; just make it contiguous and match the input dtype
    acc_out = acc_out.contiguous().to(q.dtype)
    return acc_out
dist.init_process_group(backend='nccl')
torch.cuda.set_device(dist.get_rank())
torch.manual_seed(42)
world_size = dist.get_world_size()
rank = dist.get_rank()
##################################
# benchmark single gpu attention #
##################################
# because the seed is set, we will generate the same q, k, v on all ranks
q = torch.randn(B, H, S, D, dtype=torch.bfloat16, device='cuda')
k = torch.randn(B, H, S, D, dtype=torch.bfloat16, device='cuda')
v = torch.randn(B, H, S, D, dtype=torch.bfloat16, device='cuda')
# warmup
for i in range(N_ITERS):
    F.scaled_dot_product_attention(q, k, v)
# sync across all ranks
dist.barrier(device_ids = [rank])
# time N_ITERS of single gpu attention
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
for i in range(N_ITERS):
    local_out = F.scaled_dot_product_attention(q, k, v)
end_event.record()
torch.cuda.synchronize()
elapsed_time = start_event.elapsed_time(end_event) / N_ITERS
if rank == 0:
    print(f"Rank {dist.get_rank()} single gpu attention: {elapsed_time:.2f} ms")
dist.barrier(device_ids = [rank])
################################
# benchmark parallel attention #
################################
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
# split q, k, v evenly across all ranks along the sequence length dimension
# in preparation for ring attention
local_q = q.chunk(world_size, dim=2)[rank]
local_k = k.chunk(world_size, dim=2)[rank]
local_v = v.chunk(world_size, dim=2)[rank]
# warmup
for i in range(N_ITERS):
    ring_out = ring_attn_func(local_q, local_k, local_v)
# time N_ITERS of ring attention
start_event.record()
for i in range(N_ITERS):
    ring_out = ring_attn_func(local_q, local_k, local_v)
end_event.record()
torch.cuda.synchronize()
elapsed_time = start_event.elapsed_time(end_event) / N_ITERS
if rank == 0:
    print(f"Rank {dist.get_rank()} ring attention: {elapsed_time:.2f} ms")
# the output on each rank is split along the sequence length dimension
# so we need to gather and concatenate the outputs across all ranks
out = [torch.zeros_like(ring_out).contiguous() for _ in range(world_size)]
dist.all_gather(tensor_list=out, tensor=ring_out)
# now each rank has the full output, so we can compare to computing the full output on a single gpu
out = torch.cat(out, dim=2)
# since we are using flash attention style accumulation, results will differ slightly
assert torch.allclose(out, local_out, atol=1e-2)
dist.destroy_process_group()
alexarmbr commented Mar 13, 2025

Using nsight systems, we can see the nccl kernels and the flash attention kernel on two separate streams, overlapping with each other. Compute is not stalled waiting for data to arrive, which means this is at least somewhat efficient. I'm not sure why the nccl operation on the first iteration is tiny compared to the ones on subsequent iterations.

[nsight systems trace: ring_attn_ncu]

alexarmbr commented Mar 13, 2025

What exactly is update_out_and_lse doing?

Ring attention and flash attention both compute attention block by block. The only difference is that in flash attention, a block is local to a particular streaming multiprocessor, whereas in ring attention a block is local to a particular GPU. So ring attention is basically a distributed version of flash attention. The genius of flash attention is that it allows us to compute the softmax denominator in a block by block manner, with only a single pass through each row, using some clever recurrence relations. We can use the exact same math in ring attention.

According to the flash-attention author himself, the lse returned from _flash_attn_forward is the natural log of the softmax denominator. Specifically, if we are computing the softmax of a vector $x$ with elements $x_i$, then the softmax denominator is $\sum_i e^{x_i}$, and the natural log of the softmax denominator is $\mathrm{lse}(x) = \log \sum_i e^{x_i}$.

Since we are computing attention one block at a time, we never have all of the $x_i$ in memory at once. So given two blocks of the output matrix, $O^A$ and $O^B$, we need a way to combine them to get the combined block $O^{[A,B]}$. In order to do this we need the log sum exp terms $LSE^A$ and $LSE^B$ for each block.

def update_out_and_lse(acc_out, new_out, acc_lse, new_lse):

    # our accumulator is B, H, S, D
    # but flash_attn returns B, S, H, D
    # so transpose new_out to match accumulator
    new_out = new_out.transpose(1, 2)
    
    # add a dummy trailing head_dim dimension to the new lse so that broadcasting works
    new_lse = new_lse.unsqueeze(-1)


    lse = acc_lse + torch.log(1 + torch.exp(new_lse - acc_lse))
    out = torch.exp(acc_lse - lse) * acc_out + torch.exp(new_lse - lse) * new_out
    return out, lse
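As a quick sanity check, here is a minimal sketch of the same combination on toy tensors: plain float64 softmax stands in for the flash kernel (via a naive_block_attn helper that is not part of the gist), producing the B, S, H, D output and B, H, S lse layouts the gist assumes; the accumulators are seeded with block A, and block B is folded in with update_out_and_lse from above.

```python
import torch

torch.manual_seed(0)
B, H, S, D = 1, 2, 8, 4  # tiny toy sizes, CPU only
q = torch.randn(B, H, S, D, dtype=torch.float64)
k = torch.randn(B, H, S, D, dtype=torch.float64)
v = torch.randn(B, H, S, D, dtype=torch.float64)
scale = D ** -0.5

def naive_block_attn(q, k_blk, v_blk):
    # plain attention against one block of k/v, returning out in the
    # B, S, H, D layout and lse in the B, H, S layout (flash_attn style)
    scores = torch.einsum('bhsd,bhtd->bhst', q, k_blk) * scale
    out = torch.softmax(scores, dim=-1) @ v_blk  # B, H, S, D
    lse = torch.logsumexp(scores, dim=-1)        # B, H, S
    return out.transpose(1, 2), lse

# reference: attention over the full k/v in one shot
ref, _ = naive_block_attn(q, k, v)
ref = ref.transpose(1, 2)                        # back to B, H, S, D

# block-wise: split k/v into blocks A and B along the sequence dimension
(k_a, k_b), (v_a, v_b) = k.chunk(2, dim=2), v.chunk(2, dim=2)
out_a, lse_a = naive_block_attn(q, k_a, v_a)
out_b, lse_b = naive_block_attn(q, k_b, v_b)

# seed the accumulators with block A, then fold in block B
acc_out = out_a.transpose(1, 2)                  # B, H, S, D
acc_lse = lse_a.unsqueeze(-1)                    # B, H, S, 1
acc_out, acc_lse = update_out_and_lse(acc_out, out_b, acc_lse, lse_b)

print(torch.allclose(acc_out, ref))  # True
```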

Some Derivations

What is this line doing?

lse = acc_lse + torch.log(1 + torch.exp(new_lse - acc_lse))

Let's call acc_lse $LSE^A$ and new_lse $LSE^B$. We want to compute the log sum exp of the combined block, $LSE^{[A,B]}$. So if:

$$LSE^A = \log\sum_i e^{x_i^A}$$ and $$LSE^B = \log\sum_i e^{x_i^B}$$

then we want to compute

$$LSE^{[A,B]} = \log(\sum_i e^{x_i^A} + \sum_i e^{x_i^B})$$

If our code (which I got from zhuzilin's ring-flash-attention repo, linked in the comment above) is correct, then

$$\log(\sum_i e^{x_i^A} + \sum_i e^{x_i^B}) = LSE^A + \log(1 + e^{LSE^B - LSE^A})$$

Here is a derivation that shows it is computing what we want:

$$LSE^{[A,B]} = LSE^A + \log(1 + e^{LSE^B - LSE^A})$$

$$LSE^{[A,B]} = LSE^A + \log(1 + \frac{e^{LSE^B}}{e^{LSE^A}})$$

$$LSE^{[A,B]} = LSE^A + \log(1 + \frac{e^{\log\sum_i e^{x_i^B}}}{e^{\log\sum_i e^{x_i^A}}})$$

$$LSE^{[A,B]} = LSE^A + \log(1 + \frac{\sum_i e^{x_i^B}}{\sum_i e^{x_i^A}})$$

$$LSE^{[A,B]} = LSE^A + \log(\frac{\sum_i e^{x_i^A} + \sum_i e^{x_i^B}}{\sum_i e^{x_i^A}})$$

$$LSE^{[A,B]} = LSE^A + \log(\sum_i e^{x_i^A} + \sum_i e^{x_i^B}) - \log(\sum_i e^{x_i^A})$$

$$LSE^{[A,B]} = LSE^A + \log(\sum_i e^{x_i^A} + \sum_i e^{x_i^B}) - LSE^A$$

$$LSE^{[A,B]} = \log(\sum_i e^{x_i^A} + \sum_i e^{x_i^B})$$

So it's correct!
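In other words, the recurrence is just the numerically stable way of computing $\log(e^{LSE^A} + e^{LSE^B})$, which torch also exposes directly as torch.logaddexp. A quick check on random scores (a sketch, not part of the gist):

```python
import torch

torch.manual_seed(0)
x_a = torch.randn(7, dtype=torch.float64)  # scores in block A
x_b = torch.randn(9, dtype=torch.float64)  # scores in block B

lse_a = torch.logsumexp(x_a, dim=0)
lse_b = torch.logsumexp(x_b, dim=0)

# the recurrence from the code
combined = lse_a + torch.log(1 + torch.exp(lse_b - lse_a))

# log of the summed denominators, computed two other ways
full = torch.logsumexp(torch.cat([x_a, x_b]), dim=0)
print(torch.allclose(combined, full),
      torch.allclose(combined, torch.logaddexp(lse_a, lse_b)))  # True True
```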

What about:

out = torch.exp(acc_lse - lse) * acc_out + torch.exp(new_lse - lse) * new_out

This line combines the output we have accumulated so far with the new output, using the LSE terms to reweight the contributions. Let's call acc_out $O^A$ and new_out $O^B$, and let's consider only the scalar case for simplicity, because the important part is how this expression normalizes the denominator. Each element of $O^A$ and $O^B$ is a single softmax score times a row of $v$.

So the $i^{th}$ element of $O^A$ is:
$$O^A_i = \frac{e^{x^A_i}}{\sum_j e^{x_j^A}} * v$$

and the $i^{th}$ element of $O^B$ is:
$$O^B_i = \frac{e^{x^B_i}}{\sum_j e^{x_j^B}} * v$$

and we want to compute the $i^{th}$ element of the combined output $O^{[A,B]}$:
$$O^{[A,B]}_i = \frac{e^{x^A_i} + e^{x^B_i}}{\sum_j e^{x_j^B} + \sum_j e^{x_j^A}} * v$$

Our code claims that this is equivalent to:
$$O^{[A,B]}_i = e^{LSE^A - LSE^{[A,B]}} * O^A + e^{LSE^B - LSE^{[A,B]}} * O^B$$

plugging in the LSE terms, we have:

$$O^{[A,B]}_i = e^{\log\sum_j e^{x_j^A} - \log(\sum_j e^{x_j^A} + \sum_j e^{x_j^B})} * O^A + e^{\log\sum_j e^{x_j^B} - \log(\sum_j e^{x_j^A} + \sum_j e^{x_j^B})} * O^B$$

$$O^{[A,B]}_i = \frac{e^{\log\sum_j e^{x_j^A}}}{e^{\log(\sum_j e^{x_j^A} + \sum_j e^{x_j^B})}} * O^A + \frac{e^{\log\sum_j e^{x_j^B}}}{e^{\log(\sum_j e^{x_j^A} + \sum_j e^{x_j^B})}} * O^B$$

$$O^{[A,B]}_i = \frac{\sum_j e^{x_j^A}}{\sum_j e^{x_j^A} + \sum_j e^{x_j^B}} * O^A + \frac{\sum_j e^{x_j^B}}{\sum_j e^{x_j^A} + \sum_j e^{x_j^B}} * O^B$$

and plugging in for $O^A$ and $O^B$, we have:

$$O^{[A,B]}_i = \frac{\sum_j e^{x_j^A}}{\sum_j e^{x_j^A} + \sum_j e^{x_j^B}} * \frac{e^{x^A_i}}{\sum_j e^{x_j^A}} * v + \frac{\sum_j e^{x_j^B}}{\sum_j e^{x_j^A} + \sum_j e^{x_j^B}} * \frac{e^{x^B_i}}{\sum_j e^{x_j^B}} * v$$

and things cancel out to give us what we want:

$$O^{[A,B]}_i = \frac{e^{x^A_i} + e^{x^B_i}}{\sum_j e^{x_j^A} + \sum_j e^{x_j^B}} * v$$
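The same cancellation can be checked numerically for a single query row with its key/value scores split into blocks A and B, using plain float64 torch in place of the actual flash kernels (a sketch, not part of the gist):

```python
import torch

torch.manual_seed(0)
# attention scores of one query against two blocks of keys, and one scalar value per key
x_a, x_b = torch.randn(6, dtype=torch.float64), torch.randn(6, dtype=torch.float64)
v_a, v_b = torch.randn(6, dtype=torch.float64), torch.randn(6, dtype=torch.float64)

# per-block outputs and log-sum-exps, as in the derivation
o_a = (torch.softmax(x_a, dim=0) * v_a).sum()
o_b = (torch.softmax(x_b, dim=0) * v_b).sum()
lse_a, lse_b = torch.logsumexp(x_a, dim=0), torch.logsumexp(x_b, dim=0)

# combine with the reweighting used in the code
lse_ab = lse_a + torch.log(1 + torch.exp(lse_b - lse_a))
o_ab = torch.exp(lse_a - lse_ab) * o_a + torch.exp(lse_b - lse_ab) * o_b

# reference: softmax over all scores at once, weighting all values
ref = (torch.softmax(torch.cat([x_a, x_b]), dim=0) * torch.cat([v_a, v_b])).sum()
print(torch.allclose(o_ab, ref))  # True
```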

Conclusion

If you look at this comment on zhuzilin's implementation, you can see that these two lines

lse = acc_lse + torch.log(1 + torch.exp(new_lse - acc_lse))
out = torch.exp(acc_lse - lse) * acc_out + torch.exp(new_lse - lse) * new_out

are equivalent to

out = out - F.sigmoid(block_lse - lse) * (out - block_out)
lse = lse - F.logsigmoid(lse - block_lse)

This latter form is a bit better because

  • there is no explicit exp; instead there is a sigmoid. The sigmoid has an exp inside it, but there are tricks for computing it with maximum numerical stability, and we can assume that the torch implementation uses them.
  • it is probably a bit faster, because there is no data dependency between the out and lse, as there is in the version that we derived. So out and lse can be computed at the same time, which will utilize more of the GPU. But as the nsight systems trace shows, the compute from this update is tiny compared to the compute from the attention, so this doesn't really matter in the grand scheme of things.
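As a quick numerical check that the two forms really agree, here is a small sketch with random tensors, using out/lse for the accumulated values and block_out/block_lse for the new block, with the lse broadcast over the last dimension as in the real code:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
out = torch.randn(4, 8, dtype=torch.float64)        # accumulated output
block_out = torch.randn(4, 8, dtype=torch.float64)  # output of the new block
lse = torch.randn(4, 1, dtype=torch.float64)        # accumulated lse
block_lse = torch.randn(4, 1, dtype=torch.float64)  # lse of the new block

# the form derived above
new_lse_v1 = lse + torch.log(1 + torch.exp(block_lse - lse))
new_out_v1 = torch.exp(lse - new_lse_v1) * out + torch.exp(block_lse - new_lse_v1) * block_out

# zhuzilin's sigmoid / logsigmoid form
new_out_v2 = out - F.sigmoid(block_lse - lse) * (out - block_out)
new_lse_v2 = lse - F.logsigmoid(lse - block_lse)

print(torch.allclose(new_out_v1, new_out_v2),
      torch.allclose(new_lse_v1, new_lse_v2))  # True True
```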
