@alexdremov
Created January 12, 2025 22:24
self_attn_fwd.py: a FlashAttention-style attention forward kernel sketch in Triton (online softmax over k/v tiles)
import triton
import triton.language as tl


@triton.jit
def self_attn_fwd(...):
    # loading sample len
    seq_len = ...

    # running row-max of qk^T (initialized to -inf)
    m_i = tl.zeros([TILE_Q_SIZE], dtype=tl.float32) - float("inf")

    # current softmax denominator
    l_i = tl.zeros([TILE_Q_SIZE], dtype=tl.float32)

    # result tile: we accumulate (softmax numerator) @ V here,
    # then divide it by the softmax denominator in the very end.
    # notice: all values above are accumulated in fp32 for higher precision
    acc = tl.zeros([TILE_Q_SIZE, HEAD_DIM], dtype=tl.float32)

    # account for the variable length of samples in the batch;
    # shape (TILE_Q_SIZE, 1), so it broadcasts against the kv masks below
    q_tile_indices = q_token_idx + tl.arange(0, TILE_Q_SIZE)
    q_lens_mask = q_tile_indices[:, None] < seq_len

    # loading q tile into SRAM, shape (TILE_Q_SIZE, HEAD_DIM)
    q_tile = ...

    # softmax scale with log2(e) folded in,
    # so we can use the faster exp2(...) instead of exp(...)
    LOG2_E: tl.constexpr = 1.4426950408889634  # log2(e)
    softmax_scale: tl.constexpr = tl.cast(SM_SCALE * LOG2_E, q_tile.dtype)
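    # (identity used here: exp(x) = exp2(x * log2(e)); pre-scaling q once
    #  makes every exponent in the loop below a plain exp2, which maps to
    #  a faster hardware instruction than exp)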
    # indices of tokens inside a kv tile
    tile_k_arange = tl.arange(0, TILE_K_SIZE)

    # iterate over all tiles of k, v
    for kv_tile_idx in tl.range(
        0, tl.cdiv(seq_len, TILE_K_SIZE), num_stages=PIPELINING
    ):
        # index of the first token in the kv tile
        kv_token_idx = kv_tile_idx * TILE_K_SIZE
        kt_tile = ...  # load K^T tile no. kv_tile_idx into SRAM
        v_tile = ...   # load V tile no. kv_tile_idx into SRAM
        # compute a tile of QK^T
        qk = tl.dot(
            q_tile * softmax_scale,
            kt_tile,
            input_precision=INPUT_PRECISION,
            out_dtype=tl.float32,
        )
        # mask out kv tokens past the sequence length
        kv_indices = kv_token_idx + tile_k_arange
        mask = q_lens_mask & (kv_indices[None, :] < seq_len)

        # set masked-out values to -inf
        # so that softmax ignores them
        qk = tl.where(mask, qk, tl.cast(-float("inf"), qk.dtype))
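        # (a q row that lies entirely past seq_len is all -inf at this point;
        #  its exp2 below produces garbage values, which is why such rows are
        #  zeroed out after the loop)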
        # calculate the new running maximum over the seq len:
        # m(x) = max(m(x1), m(x2))
        m_ij = tl.maximum(m_i, tl.max(qk, 1))

        # e^(x2 - m(x))
        p = tl.math.exp2(qk - m_ij[:, None])

        # softmax denominator of the current tile
        l_ij = tl.sum(p, 1)

        # from the softmax formula: e^(m(x1) - m(x))
        alpha = tl.math.exp2(m_i - m_ij)

        # update the denominator using the formula
        # l(x) = e^(m(x1) - m(x)) * l(x1) + e^0 * l(x2)
        # notice: the second factor is e^0 = 1,
        # since m(x) was already subtracted from x2 above
        l_i = l_i * alpha + l_ij

        # rescale the previous acc to account for the changed maximum,
        # as e^(xi - m(x1)) * alpha = e^(xi - m(x))
        acc = acc * alpha[:, None]
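        # (why this is exact: so far acc holds sum_j e^(x_j - m(x1)) * v_j;
        #  multiplying by alpha = e^(m(x1) - m(x)) turns every term into
        #  e^(x_j - m(x)) * v_j, i.e. the same sum relative to the new maximum)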
        # multiply p by v and add to acc
        acc += tl.dot(
            p.to(v_tile.dtype),
            v_tile,
            input_precision=INPUT_PRECISION,
            out_dtype=tl.float32,
        )

        # store the new maximum
        m_i = m_ij
    # finally, incorporate the softmax denominator
    acc = acc / l_i[:, None]

    # set fully masked-out token rows to 0
    # to avoid garbage values in the result
    acc = tl.where(q_lens_mask, acc, 0.0)

    # save the result
    tl.store(..., acc)
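
For a sanity check, the same online-softmax recurrence can be written in plain PyTorch. The sketch below is not part of the kernel: the function name, tile_size, and the single-head (Lq, D) layout are illustrative assumptions, and it uses exp instead of the kernel's exp2 trick, but the running max / denominator / rescaled accumulator updates mirror the loop above and can be verified against the naive formula.

import torch

def online_softmax_attention(q, k, v, tile_size=128, sm_scale=None):
    # q: (Lq, D); k, v: (Lk, D); single head, no length masking
    Lq, D = q.shape
    if sm_scale is None:
        sm_scale = D ** -0.5

    m_i = torch.full((Lq,), float("-inf"))      # running row max
    l_i = torch.zeros(Lq)                       # running denominator
    acc = torch.zeros(Lq, D)                    # running (numerator @ V)

    for start in range(0, k.shape[0], tile_size):
        k_tile = k[start : start + tile_size]
        v_tile = v[start : start + tile_size]

        qk = (q @ k_tile.T) * sm_scale          # tile of scaled QK^T

        m_ij = torch.maximum(m_i, qk.max(dim=1).values)
        p = torch.exp(qk - m_ij[:, None])       # numerator of this tile
        alpha = torch.exp(m_i - m_ij)           # correction for the old max

        l_i = l_i * alpha + p.sum(dim=1)
        acc = acc * alpha[:, None] + p @ v_tile
        m_i = m_ij

    return acc / l_i[:, None]

# verify against the naive attention formula
q, k, v = (torch.randn(300, 64) for _ in range(3))
ref = torch.softmax((q @ k.T) * 64 ** -0.5, dim=-1) @ v
assert torch.allclose(online_softmax_attention(q, k, v), ref, atol=1e-5)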