Created
January 12, 2025 22:24
-
-
Save alexdremov/22d1e01ecd5eab923506380aee2204cd to your computer and use it in GitHub Desktop.
Code snippet uploaded via Python script (py)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def self_attn_fwd(...): | |
# loading sample len | |
seq_len = ... | |
# running qk^T max (initialized by -inf) | |
m_i = tl.zeros([TILE_Q_SIZE], dtype=tl.float32) - float("inf") | |
# current softmax denominator | |
l_i = tl.zeros([TILE_Q_SIZE], dtype=tl.float32) | |
# result tile | |
# we will accumulate here (softmax numerator) @ V | |
# then, we will divide it by softmax denominator in the very end | |
acc = tl.zeros([TILE_Q_SIZE, HEAD_DIM], dtype=tl.float32) | |
# notice: we accumulate all values above | |
# in fp32 for higher precision | |
# account for variable length of samples in batch | |
q_tile_indices = q_token_idx + tl.arange(0, TILE_Q_SIZE) | |
q_lens_mask = ( | |
q_tile_indices[:, None] < seq_len | |
) | |
# loading q tile into SRAM, shape (TILE_Q_SIZE, HEAD_DIM) | |
q_tile = ... | |
# softmax scale, multiplying by log_2(e) | |
# to use faster exp2(...) instead of exp(...) | |
softmax_scale: tl.constexpr = tl.cast(SM_SCALE * log_2(e), q_tile.dtype) | |
# indices of tokens inside kv tile | |
tile_k_arange = tl.arange(0, TILE_K_SIZE) | |
# iterate over all tiles in k, v | |
for kv_tile_idx in tl.range( | |
0, tl.cdiv(seq_len, TILE_K_SIZE), num_stages=PIPELINING | |
): | |
# index of the first token in the kv tile | |
kv_token_idx = kv_tile_idx * TILE_K_SIZE | |
kt_tile = ... # load into SRAM K^T tile no. kv_tile_idx | |
v_tile = ... # load into SRAM V tile no. kv_tile_idx | |
# compute tile of QK^T | |
qk = tl.dot( | |
q_tile * softmax_scale, | |
kt_tile, | |
input_precision=INPUT_PRECISION, | |
out_dtype=tl.float32 | |
) | |
# masking out kv tokens after the sequence length | |
kv_indices = kv_token_idx + tile_k_arange | |
mask = q_lens_mask & ( | |
kv_indices[None, :] < seq_len | |
) | |
# set masked out values to -inf | |
# for softmax to ignore them | |
qk = tl.where(mask, qk, tl.cast(-float("inf"), qk.dtype)) | |
# calculating new maximum over seq len | |
# m(x) = m(m(x1), m(x2)) | |
m_ij = tl.maximum(m_i, tl.max(qk, 1)) | |
# e^(x2 - m(x)) | |
p = tl.math.exp2(qk - m_ij[:, None]) | |
# current tile softmax denominator | |
l_ij = tl.sum(p, 1) | |
# from softmax formula: e^(m(x1) - m(x)) | |
alpha = tl.math.exp2(m_i - m_ij) | |
# updating denominator using the formula | |
# l(x) = e^(m(x1) - m(x)) * l(x1) + e^(0)l(x2) | |
# notice: e^(0) as we subtract m(x) from x2 above | |
l_i = l_i * alpha + l_ij | |
# update previous acc to address maximum change | |
# as e^(xi - m(x1)) * alpha = e^(xi - m(x)) | |
acc = acc * alpha[:, None] | |
# multiply p by v and adding to acc | |
acc += tl.dot( | |
p.to(v_tile.dtype), | |
v_tile, | |
input_precision=INPUT_PRECISION, | |
out_dtype=tl.float32, | |
) | |
# storing new maximum | |
m_i = m_ij | |
# finally incorporate softmax denominator | |
acc = acc / l_i[:, None] | |
# set fully masked token values to 0 to avoid garbage values | |
# in the result | |
acc = tl.where(q_lens_mask, acc, 0.0) | |
# save the result | |
tl.save(acc, ...) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment