Ring-Flash Attention
""" | |
test performance and correctness of ring attention vs. single gpu attention | |
torchrun --nproc-per-node 4 ring_attn.py | |
using 4 H100s I get: | |
Rank 0 single gpu attention: 261.78 ms | |
Rank 0 ring attention: 73.34 ms | |
""" | |
import os
import math
import torch
import torch.distributed as dist
import torch.nn.functional as F
from flash_attn.flash_attn_interface import _flash_attn_forward

N_ITERS = 5
B = 1 # batch size
H = 16 # number of heads
S = 108540 # sequence length
D = 128 # head dimension

flash_attn_kwargs = {
    'causal': False,
    'dropout_p': 0.0,
    'softmax_scale': 1 / math.sqrt(D),
    'window_size_left': -1,
    'window_size_right': -1,
    'softcap': 0.0,
    'alibi_slopes': None,
    'return_softmax': False
}
def update_out_and_lse(acc_out, new_out, acc_lse, new_lse):
    # our accumulator is B, H, S, D
    # but flash_attn returns B, S, H, D
    # so transpose new_out to match the accumulator
    new_out = new_out.transpose(1, 2)
    # add a trailing singleton dimension to new_lse so that it broadcasts against the output
    new_lse = new_lse.unsqueeze(-1)
    # from here:
    # https://github.com/zhuzilin/ring-flash-attention/blob/main/ring_flash_attn/utils.py#L33
    lse = acc_lse + torch.log(1 + torch.exp(new_lse - acc_lse))
    out = torch.exp(acc_lse - lse) * acc_out + torch.exp(new_lse - lse) * new_out
    return out, lse
def ring_attn_func(q, k, v):
    # permute q, k, v from B, H, S, D to B, S, H, D for compatibility with flash_attn
    q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
    b, s, h, d = q.shape
    my_rank = dist.get_rank()
    world_size = dist.get_world_size()

    # output accumulator
    acc_out = torch.zeros(b, h, s, d, dtype=q.dtype, device=q.device)
    # log sum exp accumulator (log of softmax denominator)
    acc_lse = torch.zeros(b, h, s, 1, dtype=q.dtype, device=q.device)

    send_rank = (my_rank + 1) % world_size
    recv_rank = (my_rank - 1) % world_size
    last_iteration = False
    for i in range(world_size):
        if i == world_size - 1:
            last_iteration = True
        # stack k and v for convenience, so we only need to move one tensor rather than two
        kv_send = torch.stack([k, v], dim=0).contiguous()
        kv_recv = torch.zeros_like(kv_send)
        # asynchronous send to the next rank, asynchronous recv from the previous rank
        # ranks have to alternate the send/recv order like this, otherwise isend/irecv will deadlock
        if not last_iteration:
            if my_rank % 2 == 0:
                send_tag = dist.isend(kv_send, dst=send_rank)
                recv_tag = dist.irecv(kv_recv, src=recv_rank)
            else:
                recv_tag = dist.irecv(kv_recv, src=recv_rank)
                send_tag = dist.isend(kv_send, dst=send_rank)
        # use flash_attn rather than F.scaled_dot_product_attention because we need the log sum exp of each chunk
        # call flash_attn on the local slice of q, k, v
        out, lse, _, _ = _flash_attn_forward(q, k, v, **flash_attn_kwargs)
        # update accumulators using the flash attention update rule
        acc_out, acc_lse = update_out_and_lse(acc_out, out, acc_lse, lse)
        if not last_iteration:
            # wait for send and recv to complete
            # ideally the send/recv finished while the flash attn call was running, so we don't actually wait here
            send_tag.wait()
            recv_tag.wait()
            k, v = kv_recv.chunk(2, dim=0)
            k = k.squeeze(0)
            v = v.squeeze(0)
    # the accumulator is already B, H, S, D; just make it contiguous and cast back to the input dtype
    acc_out = acc_out.contiguous().to(q.dtype)
    return acc_out
dist.init_process_group(backend='nccl')
torch.cuda.set_device(dist.get_rank())
torch.manual_seed(42)
world_size = dist.get_world_size()
rank = dist.get_rank()

##################################
# benchmark single gpu attention #
##################################

# because the seed is set, we will generate the same q, k, v on all ranks
q = torch.randn(B, H, S, D, dtype=torch.bfloat16, device='cuda')
k = torch.randn(B, H, S, D, dtype=torch.bfloat16, device='cuda')
v = torch.randn(B, H, S, D, dtype=torch.bfloat16, device='cuda')

# warmup
for i in range(N_ITERS):
    F.scaled_dot_product_attention(q, k, v)

# sync across all ranks
dist.barrier(device_ids=[rank])

# time N_ITERS of single gpu attention
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
for i in range(N_ITERS):
    local_out = F.scaled_dot_product_attention(q, k, v)
end_event.record()
torch.cuda.synchronize()
elapsed_time = start_event.elapsed_time(end_event) / N_ITERS
if rank == 0:
    print(f"Rank {dist.get_rank()} single gpu attention: {elapsed_time:.2f} ms")
dist.barrier(device_ids=[rank])
################################
# benchmark parallel attention #
################################

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

# split q, k, v evenly across all ranks along the sequence length dimension
# in preparation for ring attention
local_q = q.chunk(world_size, dim=2)[rank]
local_k = k.chunk(world_size, dim=2)[rank]
local_v = v.chunk(world_size, dim=2)[rank]

# warmup
for i in range(N_ITERS):
    ring_out = ring_attn_func(local_q, local_k, local_v)

# time N_ITERS of ring attention
start_event.record()
for i in range(N_ITERS):
    ring_out = ring_attn_func(local_q, local_k, local_v)
end_event.record()
torch.cuda.synchronize()
elapsed_time = start_event.elapsed_time(end_event) / N_ITERS
if rank == 0:
    print(f"Rank {dist.get_rank()} ring attention: {elapsed_time:.2f} ms")

# the output on each rank is split along the sequence length dimension
# so we need to gather and concatenate the outputs across all ranks
out = [torch.zeros_like(ring_out).contiguous() for _ in range(world_size)]
dist.all_gather(tensor_list=out, tensor=ring_out)
# now each rank has the full output, so we can compare to computing the full output on a single gpu
out = torch.cat(out, dim=2)
# since we are using flash attention style accumulation, results will differ slightly
assert torch.allclose(out, local_out, atol=1e-2)

dist.destroy_process_group()
What exactly is `update_out_and_lse` doing?

Ring attention and flash attention both compute attention block by block. The only difference is that in flash attention, a block is local to a particular streaming multiprocessor, whereas in ring attention a block is local to a particular GPU. So ring attention is basically a distributed version of flash attention. The genius of flash attention is that it allows us to compute the softmax denominator in a block by block manner, with only a single pass through each row, using some clever recurrence relations. We can use the exact same math in ring attention.

According to the guy himself, the `lse` returned from `_flash_attn_func` is the natural log of the softmax denominator. Specifically, if we are computing the softmax of a vector $x_i$, then the softmax denominator is $\sum e^{x_i}$, and the natural log of the softmax denominator is $lse(x_i) = \log \sum e^{x_i}$.

Since we are computing attention one block at a time, we never have all the elements of $x_i$ in memory at once. So given two blocks of the output matrix $O^A$ and $O^B$, we need a way to combine them to get the combined block $O^{[A,B]}$. In order to do this we need the log sum exp terms $LSE^A$ and $LSE^B$ for each block.
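To make the definition concrete, here is a tiny sanity check of my own (not part of the gist) showing that the log-sum-exp of a score vector is exactly the log of its softmax denominator:

```python
import torch

# a vector of attention scores for one query
x = torch.randn(16, dtype=torch.float64)

# softmax denominator and its natural log
softmax_denominator = torch.exp(x).sum()
lse = torch.logsumexp(x, dim=0)

assert torch.allclose(lse, torch.log(softmax_denominator))
```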
Some Derivations
What is this line doing?

lse = acc_lse + torch.log(1 + torch.exp(new_lse - acc_lse))

Let's call `acc_lse` $LSE^A$ and `new_lse` $LSE^B$. We want to compute the log sum exp of the combined block, $LSE^{[A,B]}$. So if:

$$LSE^A = \log \sum_j e^{x_j^A} \qquad LSE^B = \log \sum_j e^{x_j^B}$$

then we want to compute

$$LSE^{[A,B]} = \log \left( \sum_j e^{x_j^A} + \sum_j e^{x_j^B} \right)$$

If our code (which I got from here) is correct, then

$$LSE^{[A,B]} = LSE^A + \log \left( 1 + e^{LSE^B - LSE^A} \right)$$

Here is a derivation that shows it is computing what we want:

$$LSE^A + \log \left( 1 + e^{LSE^B - LSE^A} \right) = \log \left( e^{LSE^A} \left( 1 + e^{LSE^B - LSE^A} \right) \right) = \log \left( e^{LSE^A} + e^{LSE^B} \right) = \log \left( \sum_j e^{x_j^A} + \sum_j e^{x_j^B} \right)$$

So it's correct!
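The same recurrence is easy to check numerically; this is my own sketch, not part of the gist:

```python
import torch

# split a score vector into two blocks, merge their LSEs with the recurrence
# above, and compare against the LSE of the full vector
x = torch.randn(1024, dtype=torch.float64)
x_a, x_b = x.chunk(2)

lse_a = torch.logsumexp(x_a, dim=0)   # LSE^A
lse_b = torch.logsumexp(x_b, dim=0)   # LSE^B
lse_ab = lse_a + torch.log(1 + torch.exp(lse_b - lse_a))

assert torch.allclose(lse_ab, torch.logsumexp(x, dim=0))
```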
What about:

out = torch.exp(acc_lse - lse) * acc_out + torch.exp(new_lse - lse) * new_out

This line is combining the output we have accumulated so far with the new output, using the LSE terms to reweight the contributions. Let's call `acc_out` $O^A$ and `new_out` $O^B$, and let's consider only the scalar case for simplicity, because the important part is how this expression normalizes the denominator. Each element of $O^A$ and $O^B$ is a single softmax score, times a row of $v$.

So the $i^{th}$ element of $O^A$ is:

$$O^A_i = \frac{e^{x^A_i}}{\sum_j e^{x_j^A}} * v$$

and the $i^{th}$ element of $O^B$ is:

$$O^B_i = \frac{e^{x^B_i}}{\sum_j e^{x_j^B}} * v$$

and we want to compute the $i^{th}$ element of the combined output $O^{[A,B]}$:

$$O^{[A,B]}_i = \frac{e^{x^A_i} + e^{x^B_i}}{\sum_j e^{x_j^A} + \sum_j e^{x_j^B}} * v$$

Our code claims that this is equivalent to:

$$O^{[A,B]}_i = e^{LSE^A - LSE^{[A,B]}} * O^A_i + e^{LSE^B - LSE^{[A,B]}} * O^B_i$$

Plugging in the LSE terms, we have:

$$O^{[A,B]}_i = \frac{\sum_j e^{x_j^A}}{\sum_j e^{x_j^A} + \sum_j e^{x_j^B}} * O^A_i + \frac{\sum_j e^{x_j^B}}{\sum_j e^{x_j^A} + \sum_j e^{x_j^B}} * O^B_i$$

and plugging in for $O^A_i$ and $O^B_i$, we have:

$$O^{[A,B]}_i = \frac{\sum_j e^{x_j^A}}{\sum_j e^{x_j^A} + \sum_j e^{x_j^B}} * \frac{e^{x^A_i}}{\sum_j e^{x_j^A}} * v + \frac{\sum_j e^{x_j^B}}{\sum_j e^{x_j^A} + \sum_j e^{x_j^B}} * \frac{e^{x^B_i}}{\sum_j e^{x_j^B}} * v$$

and things cancel out to give us what we want:

$$O^{[A,B]}_i = \frac{e^{x^A_i} + e^{x^B_i}}{\sum_j e^{x_j^A} + \sum_j e^{x_j^B}} * v$$
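Again, a quick numerical check of this (my sketch, not part of the gist), for a single query attending to two key/value blocks:

```python
import torch

# attention scores and values for one query, split into blocks A and B
torch.manual_seed(0)
x_a, x_b = torch.randn(8, dtype=torch.float64), torch.randn(8, dtype=torch.float64)
v_a, v_b = torch.randn(8, 4, dtype=torch.float64), torch.randn(8, 4, dtype=torch.float64)

# per-block outputs and LSEs
out_a = torch.softmax(x_a, dim=0) @ v_a      # O^A
out_b = torch.softmax(x_b, dim=0) @ v_b      # O^B
lse_a = torch.logsumexp(x_a, dim=0)          # LSE^A
lse_b = torch.logsumexp(x_b, dim=0)          # LSE^B

# the two update lines from update_out_and_lse
lse = lse_a + torch.log(1 + torch.exp(lse_b - lse_a))
out = torch.exp(lse_a - lse) * out_a + torch.exp(lse_b - lse) * out_b

# reference: attention over the concatenated blocks
expected = torch.softmax(torch.cat([x_a, x_b]), dim=0) @ torch.cat([v_a, v_b])
assert torch.allclose(out, expected)
```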
conclusion
If you look at this comment on zhuzilin's implementation, you can see that these two lines

lse = acc_lse + torch.log(1 + torch.exp(new_lse - acc_lse))
out = torch.exp(acc_lse - lse) * acc_out + torch.exp(new_lse - lse) * new_out

are equivalent to a form written in terms of `sigmoid`. This latter form is a bit better for two reasons. First, the exponential is never computed directly with `exp`, but instead inside a `sigmoid`. The `sigmoid` has an `exp` inside it, but there are some tricks for computing this with maximum numerical stability, and we can assume that the `torch` implementation is using these tricks. Second, there is no data dependency between `out` and `lse`, as there is in the version that we derived. So `out` and `lse` can be computed at the same time, which will utilize more of the GPU. But as the nsight systems trace shows, the compute from this update is tiny compared to the compute from the attention, so this doesn't really matter in the grand scheme of things.
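For reference, here is a sketch of what that `sigmoid`-based update can look like, written with this gist's variable names; the use of `torch.sigmoid` / `F.logsigmoid` here is my reconstruction of the idea, not a verbatim quote of zhuzilin's code:

```python
import torch
import torch.nn.functional as F

def update_out_and_lse_sigmoid(acc_out, new_out, acc_lse, new_lse):
    # same reshaping as in the gist's update_out_and_lse
    new_out = new_out.transpose(1, 2)
    new_lse = new_lse.unsqueeze(-1)
    # equivalent to the exp/log form derived above:
    #   sigmoid(new_lse - acc_lse)     == exp(new_lse - lse_combined)
    #   -logsigmoid(acc_lse - new_lse) == log(1 + exp(new_lse - acc_lse))
    # note that neither line depends on the other's result
    out = acc_out - torch.sigmoid(new_lse - acc_lse) * (acc_out - new_out)
    lse = acc_lse - F.logsigmoid(acc_lse - new_lse)
    return out, lse
```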