Skip to content

Instantly share code, notes, and snippets.

@a-r-r-o-w
Created July 11, 2025 11:29
Show Gist options
  • Select an option

  • Save a-r-r-o-w/4cab6a83ee580376fbab4492986ad49d to your computer and use it in GitHub Desktop.

Select an option

Save a-r-r-o-w/4cab6a83ee580376fbab4492986ad49d to your computer and use it in GitHub Desktop.
Attention-free transformer
"""
Implementation of "An Attention-Free Transformer": https://arxiv.org/abs/2105.14103
"""
import contextlib
import functools
from collections.abc import Callable, Iterator

import torch
import triton
import triton.language as tl
class AFTFull(torch.nn.Module):
    """AFT-full layer from "An Attention-Free Transformer".

    For every target position ``t`` the layer computes::

        y_t = sigmoid(q_t) * sum_j(exp(k_j + w[t, j]) * v_j)
                           / (sum_j(exp(k_j + w[t, j])) + eps)

    where ``w`` is a learned ``[max_seq_len, max_seq_len]`` table of pairwise
    position biases and all products are element-wise over the hidden dimension.

    ``w`` is allocated with ``torch.empty``; call :meth:`init_weights` before
    running the layer.
    """

    def __init__(self, embedding_dim: int, hidden_dim: int, max_seq_len: int, bias: bool = False, eps: float = 1e-8) -> None:
        """Create the projections and the position-bias table.

        Args:
            embedding_dim: Size of the last dimension of the layer input/output.
            hidden_dim: Size of each of the q/k/v projections.
            max_seq_len: Side length of the square position-bias table ``w``.
            bias: Whether the two linear projections carry a bias term.
            eps: Constant added to the denominator of the weighted average.
        """
        super().__init__()
        self.max_seq_len = max_seq_len
        self.bias = bias
        self.eps = eps
        self.to_qkv = torch.nn.Linear(embedding_dim, 3 * hidden_dim, bias=bias)
        self.to_out = torch.nn.Linear(hidden_dim, embedding_dim, bias=bias)
        # Uninitialized storage: values are only meaningful after init_weights().
        self.w = torch.nn.Parameter(torch.empty(max_seq_len, max_seq_len))

    def init_weights(self) -> None:
        """Xavier-initialize the projection weights and ``w``; zero the biases."""
        torch.nn.init.xavier_uniform_(self.to_qkv.weight)
        torch.nn.init.xavier_uniform_(self.to_out.weight)
        torch.nn.init.xavier_uniform_(self.w)
        if self.bias:
            torch.nn.init.zeros_(self.to_qkv.bias)
            torch.nn.init.zeros_(self.to_out.bias)

    def _aft_forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        """Apply the AFT-full mixing to ``q``, ``k`` and ``v``.

        Args:
            q, k, v: Tensors of shape ``[batch_size, seq_len, hidden_dim]``.

        Returns:
            Tensor of shape ``[batch_size, seq_len, hidden_dim]``.

        Raises:
            ValueError: If ``seq_len`` is larger than ``max_seq_len``.
        """
        seq_len = q.size(1)
        if seq_len > self.max_seq_len:
            raise ValueError(f"seq_len ({seq_len}) exceeds max_seq_len ({self.max_seq_len})")

        # exp(k_j + w[t, j]) factors as exp(w[t, j]) * exp(k_j), so the sums over j
        # become matrix products with exp(w). This avoids materializing the
        # [batch_size, seq_len, seq_len, hidden_dim] tensor of the direct formula.
        w_exp = torch.exp(self.w[:seq_len, :seq_len])  # [seq_len, seq_len]
        k_exp = torch.exp(k)
        kv = torch.mul(k_exp, v)
        # matmul broadcasts the 2-D w_exp over the batch dimension; torch.bmm
        # would require both operands to share the same batch size.
        nr = torch.matmul(w_exp, kv)
        dr = torch.matmul(w_exp, k_exp) + self.eps
        q = torch.sigmoid(q)
        y = torch.mul(q, nr / dr)
        return y

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` to q/k/v, mix them with AFT-full and project back.

        Args:
            x: Tensor of shape ``[batch_size, seq_len, embedding_dim]``.

        Returns:
            Tensor of shape ``[batch_size, seq_len, embedding_dim]``.
        """
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        y = self._aft_forward(q, k, v)
        o = self.to_out(y)
        return o
class AFTSimple(torch.nn.Module):
    """AFT-simple layer: AFT without any learned position bias.

    Every position receives the same context vector, a weighted average of
    ``v`` over the sequence with weights ``exp(k)``, which is then gated
    element-wise by ``sigmoid(q)``.
    """

    def __init__(self, embedding_dim: int, hidden_dim: int, bias: bool = False, eps: float = 1e-8) -> None:
        """Create the fused q/k/v projection and the output projection.

        Args:
            embedding_dim: Size of the last dimension of the layer input/output.
            hidden_dim: Size of each of the q/k/v projections.
            bias: Whether the two linear projections carry a bias term.
            eps: Constant added to the denominator of the weighted average.
        """
        super().__init__()
        self.bias = bias
        self.eps = eps
        self.to_qkv = torch.nn.Linear(embedding_dim, 3 * hidden_dim, bias=bias)
        self.to_out = torch.nn.Linear(hidden_dim, embedding_dim, bias=bias)

    def init_weights(self) -> None:
        """Xavier-initialize both projection weights and zero any biases."""
        projections = (self.to_qkv, self.to_out)
        for layer in projections:
            torch.nn.init.xavier_uniform_(layer.weight)
        if not self.bias:
            return
        for layer in projections:
            torch.nn.init.zeros_(layer.bias)

    def _aft_forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        """Gate the ``exp(k)``-weighted average of ``v`` with ``sigmoid(q)``.

        The reductions run over dim 1 with ``keepdim=True``, so the single
        context vector per batch element broadcasts against ``q``.
        """
        weights = torch.exp(k)
        numerator = (weights * v).sum(dim=1, keepdim=True)
        denominator = weights.sum(dim=1, keepdim=True) + self.eps
        gate = torch.sigmoid(q)
        return gate * (numerator / denominator)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` to q/k/v, apply AFT-simple, and project back."""
        query, key, value = torch.chunk(self.to_qkv(x), 3, dim=-1)
        mixed = self._aft_forward(query, key, value)
        return self.to_out(mixed)
class AFTLocal(torch.nn.Module):
    """AFT-local layer from "An Attention-Free Transformer".

    Same computation as AFT-full, except that the learned position bias
    ``w[t, j]`` is only used when ``|t - j| < window_seq_len``. Outside the
    window the bias is replaced by ``0`` (so its exponential is ``1``): distant
    positions still contribute to the weighted average, just without a learned
    bias.

    ``w`` is allocated with ``torch.empty``; call :meth:`init_weights` before
    running the layer.
    """

    def __init__(self, embedding_dim: int, hidden_dim: int, max_seq_len: int, window_seq_len: int, bias: bool = False, eps: float = 1e-8) -> None:
        """Create the projections, the position-bias table and the window mask.

        Args:
            embedding_dim: Size of the last dimension of the layer input/output.
            hidden_dim: Size of each of the q/k/v projections.
            max_seq_len: Side length of the square position-bias table ``w``.
            window_seq_len: Position pairs at distance ``>= window_seq_len``
                have their learned bias replaced by zero.
            bias: Whether the two linear projections carry a bias term.
            eps: Constant added to the denominator of the weighted average.
        """
        super().__init__()
        self.max_seq_len = max_seq_len
        self.window_seq_len = window_seq_len
        self.bias = bias
        self.eps = eps
        self.to_qkv = torch.nn.Linear(embedding_dim, 3 * hidden_dim, bias=bias)
        self.to_out = torch.nn.Linear(hidden_dim, embedding_dim, bias=bias)
        # Uninitialized storage: values are only meaningful after init_weights().
        self.w = torch.nn.Parameter(torch.empty(max_seq_len, max_seq_len))
        indices = torch.arange(self.max_seq_len)
        # True where the pair (t, j) lies outside the local window.
        mask = torch.abs(indices[:, None] - indices[None, :]) >= window_seq_len
        # Non-persistent: the mask is derived from the constructor arguments, so
        # it is kept out of the state dict.
        self.register_buffer("_mask", mask, persistent=False)

    def init_weights(self) -> None:
        """Xavier-initialize the projection weights and ``w``; zero the biases."""
        torch.nn.init.xavier_uniform_(self.to_qkv.weight)
        torch.nn.init.xavier_uniform_(self.to_out.weight)
        torch.nn.init.xavier_uniform_(self.w)
        if self.bias:
            torch.nn.init.zeros_(self.to_qkv.bias)
            torch.nn.init.zeros_(self.to_out.bias)

    def _aft_forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        """Apply the AFT-local mixing to ``q``, ``k`` and ``v``.

        Args:
            q, k, v: Tensors of shape ``[batch_size, seq_len, hidden_dim]``.

        Returns:
            Tensor of shape ``[batch_size, seq_len, hidden_dim]``.

        Raises:
            ValueError: If ``seq_len`` is larger than ``max_seq_len``.
        """
        seq_len = q.size(1)
        if seq_len > self.max_seq_len:
            raise ValueError(f"seq_len ({seq_len}) exceeds max_seq_len ({self.max_seq_len})")

        mask = self._mask[:seq_len, :seq_len]
        # Out-of-window biases become 0, i.e. a multiplicative weight of exp(0) = 1.
        w = self.w[:seq_len, :seq_len].masked_fill(mask, 0)
        # exp(k_j + w[t, j]) factors as exp(w[t, j]) * exp(k_j), so the sums over j
        # become matrix products with exp(w) instead of a 4-D intermediate.
        w_exp = torch.exp(w)  # [seq_len, seq_len]
        k_exp = torch.exp(k)
        kv = torch.mul(k_exp, v)
        # matmul broadcasts the 2-D w_exp over the batch dimension; torch.bmm
        # would require both operands to share the same batch size.
        nr = torch.matmul(w_exp, kv)
        dr = torch.matmul(w_exp, k_exp) + self.eps
        q = torch.sigmoid(q)
        y = torch.mul(q, nr / dr)
        return y

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` to q/k/v, mix them with AFT-local and project back.

        Args:
            x: Tensor of shape ``[batch_size, seq_len, embedding_dim]``.

        Returns:
            Tensor of shape ``[batch_size, seq_len, embedding_dim]``.
        """
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        y = self._aft_forward(q, k, v)
        o = self.to_out(y)
        return o
if __name__ == "__main__":
    # Smoke test: build each AFT variant and check that a forward pass keeps the
    # input shape. Falls back to CPU so the check also runs on hosts without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.bfloat16
    torch.manual_seed(42)

    max_seq_len = 8192
    embedding_dim = 512
    hidden_dim = 512
    test_seq_len = 1024

    model_aft_full = AFTFull(embedding_dim, hidden_dim, max_seq_len).to(device=device, dtype=dtype)
    model_aft_full.init_weights()
    model_aft_full.eval()

    model_aft_simple = AFTSimple(embedding_dim, hidden_dim).to(device=device, dtype=dtype)
    model_aft_simple.init_weights()
    model_aft_simple.eval()

    model_aft_local = AFTLocal(embedding_dim, hidden_dim, max_seq_len, window_seq_len=128).to(device=device, dtype=dtype)
    model_aft_local.init_weights()
    model_aft_local.eval()

    with torch.inference_mode():
        x = torch.randn(1, test_seq_len, embedding_dim, device=device, dtype=dtype)
        expected_shape = (1, test_seq_len, embedding_dim)
        for model in (model_aft_full, model_aft_simple, model_aft_local):
            output = model(x)
            assert output.shape == expected_shape, f"Output shape mismatch for {type(model).__name__}"
@contextlib.contextmanager
def _aft_method_replace(model: AFTFull | AFTLocal | AFTSimple, method_name: str, new_method: callable):
new_method = functools.partial(new_method, model)
original_method = getattr(model, method_name)
setattr(model, method_name, new_method)
try:
yield
finally:
setattr(model, method_name, original_method)
def get_pointwise_exp_autotune_configs():
    """Return the autotuning candidates used by the ``_pointwise_exp`` kernel.

    One ``triton.Config`` is produced for every combination of ``BLOCK_N`` in
    (512, 1024, 2048) and ``num_warps`` in (4, 8), always with ``num_stages=1``.
    """
    block_sizes = (512, 1024, 2048)
    warp_counts = (4, 8)
    return [
        triton.Config({"BLOCK_N": block_n}, num_warps=warps, num_stages=1)
        for block_n in block_sizes
        for warps in warp_counts
    ]
@triton.autotune(configs=get_pointwise_exp_autotune_configs(), key=["seq_len"])
@triton.jit
def _pointwise_exp(w_ptr, w_stride, out_ptr, seq_len: int, BLOCK_N: tl.constexpr):
    """Write ``exp`` of the top-left ``seq_len x seq_len`` corner of ``w`` to ``out``.

    Each program handles ``BLOCK_N`` consecutive flat indices of the
    ``seq_len * seq_len`` output. The input is addressed as
    ``row * w_stride + col`` (unit stride along columns is assumed), the output
    by the flat index.
    """
    block_id = tl.program_id(0)
    flat_idx = block_id * BLOCK_N + tl.arange(0, BLOCK_N)
    # Lanes past the end of the output neither load nor store.
    in_bounds = flat_idx < (seq_len * seq_len)
    row = flat_idx // seq_len
    col = flat_idx % seq_len
    src_ptrs = w_ptr + row * w_stride + col
    # The exponential is evaluated in float32 regardless of the input dtype.
    values = tl.load(src_ptrs, mask=in_bounds, other=0.0, eviction_policy="evict_last").to(tl.float32)
    tl.store(out_ptr + flat_idx, tl.exp(values), in_bounds)
def get_pointwise_mul_kexp_v_autotune_configs():
    """Return the autotuning candidates used by the ``_pointwise_mul_kexp_v`` kernel.

    One ``triton.Config`` is produced for every combination of ``BLOCK_N`` in
    (512, 1024, 2048) and ``num_warps`` in (4, 8), always with ``num_stages=1``.
    """
    block_sizes = (512, 1024, 2048)
    warp_counts = (4, 8)
    return [
        triton.Config({"BLOCK_N": block_n}, num_warps=warps, num_stages=1)
        for block_n in block_sizes
        for warps in warp_counts
    ]
@triton.autotune(configs=get_pointwise_mul_kexp_v_autotune_configs(), key=["batch_size", "seq_len", "dim"])
@triton.jit
def _pointwise_mul_kexp_v(k_ptr, k_stride, v_ptr, v_stride, k_exp_ptr, kv_ptr, batch_size, seq_len, dim, BLOCK_N: tl.constexpr):
    """Compute ``exp(k)`` and ``exp(k) * v`` element-wise in one pass.

    Each program handles ``BLOCK_N`` consecutive flat indices out of
    ``batch_size * seq_len * dim``. Inputs are addressed as
    ``row * stride + col`` with ``row = index // dim`` (unit stride along the
    last dimension is assumed); both outputs are addressed by the flat index.
    """
    block_id = tl.program_id(0)
    flat_idx = block_id * BLOCK_N + tl.arange(0, BLOCK_N)
    # Lanes past the end of the tensors neither load nor store.
    in_bounds = flat_idx < (batch_size * seq_len * dim)
    row = flat_idx // dim
    col = flat_idx % dim
    k_ptrs = k_ptr + row * k_stride + col
    v_ptrs = v_ptr + row * v_stride + col
    # Arithmetic is carried out in float32 regardless of the input dtype.
    k_val = tl.load(k_ptrs, mask=in_bounds, other=0.0, eviction_policy="evict_last").to(tl.float32)
    v_val = tl.load(v_ptrs, mask=in_bounds, other=0.0, eviction_policy="evict_last").to(tl.float32)
    exp_k = tl.exp(k_val)
    tl.store(k_exp_ptr + flat_idx, exp_k, mask=in_bounds)
    tl.store(kv_ptr + flat_idx, exp_k * v_val, mask=in_bounds)
def get_pointwise_mul_qsigmoid_nr_dr_autotune_configs():
    """Return the autotuning candidates used by ``_pointwise_mul_qsigmoid_nr_dr``.

    One ``triton.Config`` is produced for every combination of ``BLOCK_N`` in
    (512, 1024, 2048) and ``num_warps`` in (4, 8), always with ``num_stages=1``.
    """
    block_sizes = (512, 1024, 2048)
    warp_counts = (4, 8)
    return [
        triton.Config({"BLOCK_N": block_n}, num_warps=warps, num_stages=1)
        for block_n in block_sizes
        for warps in warp_counts
    ]
@triton.autotune(configs=get_pointwise_mul_qsigmoid_nr_dr_autotune_configs(), key=["batch_size", "seq_len", "dim"])
@triton.jit
def _pointwise_mul_qsigmoid_nr_dr(
    q_ptr, q_stride, nr_ptr, nr_stride, dr_ptr, dr_stride, out_ptr, batch_size, seq_len, dim, eps: tl.constexpr, BLOCK_N: tl.constexpr
):
    """Compute ``sigmoid(q) * nr / (dr + eps)`` element-wise.

    Each program handles ``BLOCK_N`` consecutive flat indices out of
    ``batch_size * seq_len * dim``. Inputs are addressed as
    ``row * stride + col`` with ``row = index // dim`` (unit stride along the
    last dimension is assumed); the output is addressed by the flat index.
    """
    block_id = tl.program_id(0)
    flat_idx = block_id * BLOCK_N + tl.arange(0, BLOCK_N)
    # Lanes past the end of the tensors neither load nor store.
    in_bounds = flat_idx < (batch_size * seq_len * dim)
    row = flat_idx // dim
    col = flat_idx % dim
    q_ptrs = q_ptr + row * q_stride + col
    nr_ptrs = nr_ptr + row * nr_stride + col
    dr_ptrs = dr_ptr + row * dr_stride + col
    # Arithmetic is carried out in float32 regardless of the input dtype.
    q_val = tl.load(q_ptrs, mask=in_bounds, other=0.0, eviction_policy="evict_last").to(tl.float32)
    numer = tl.load(nr_ptrs, mask=in_bounds, other=0.0, eviction_policy="evict_last").to(tl.float32)
    denom = tl.load(dr_ptrs, mask=in_bounds, other=0.0, eviction_policy="evict_last").to(tl.float32)
    gate = tl.sigmoid(q_val)
    ratio = numer / (denom + eps)
    tl.store(out_ptr + flat_idx, gate * ratio, mask=in_bounds)
def _aft_full_fwd(model: AFTFull, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Triton-backed replacement for ``AFTFull._aft_forward``.

    The element-wise stages (``exp(w)``, ``exp(k)`` and ``exp(k) * v``, and the
    final ``sigmoid(q) * nr / (dr + eps)``) run as Triton kernels; the two
    reductions over the sequence are ``torch.bmm`` calls.

    Args:
        model: Module providing the position-bias table ``w`` and ``eps``.
        q, k, v: Tensors of shape ``[1, seq_len, dim]``. The kernels address them
            as ``row * stride(1) + col``, so unit stride along the last
            dimension is assumed.

    Returns:
        Tensor of shape ``[1, seq_len, dim]`` with the dtype and device of ``q``.

    Raises:
        ValueError: If the batch size is not 1, or if ``seq_len`` exceeds either
            dimension of ``model.w``.
    """
    batch_size, seq_len, dim = q.size()
    # Explicit raises instead of assert: these checks guard raw pointer
    # arithmetic and must not disappear under ``python -O``.
    if batch_size != 1:
        raise ValueError("Only batch size 1 is supported in this implementation.")

    w = model.w
    if seq_len > w.size(0) or seq_len > w.size(1):
        # The kernel only masks against seq_len * seq_len, so a longer sequence
        # would read past the end of ``w``.
        raise ValueError(f"seq_len ({seq_len}) exceeds the position-bias table size {tuple(w.size())}")

    w_exp = q.new_empty((seq_len, seq_len))
    grid = lambda META: (triton.cdiv(seq_len * seq_len, META["BLOCK_N"]),)
    _pointwise_exp[grid](w, w.stride(0), w_exp, seq_len)
    w_exp = w_exp[None, :, :]

    k_exp = q.new_empty((batch_size, seq_len, dim))
    kv = q.new_empty((batch_size, seq_len, dim))
    # Both remaining kernels cover batch_size * seq_len * dim elements.
    grid = lambda META: (triton.cdiv(batch_size * seq_len * dim, META["BLOCK_N"]),)
    # We pass stride(1) because batch size is 1 (other batch size not supported here)
    _pointwise_mul_kexp_v[grid](k, k.stride(1), v, v.stride(1), k_exp, kv, batch_size, seq_len, dim)

    nr = torch.bmm(w_exp, kv)
    dr = torch.bmm(w_exp, k_exp)

    y = q.new_empty((batch_size, seq_len, dim))
    # eps is added to ``dr`` inside the kernel.
    _pointwise_mul_qsigmoid_nr_dr[grid](q, q.stride(1), nr, nr.stride(1), dr, dr.stride(1), y, batch_size, seq_len, dim, model.eps)
    return y
def solution_torch(model: torch.nn.Module, tensor: torch.Tensor) -> torch.Tensor:
    """Run ``model`` on ``tensor`` unmodified and return its output."""
    output = model(tensor)
    return output
def solution_triton_aft_full(model: torch.nn.Module, tensor: torch.Tensor) -> torch.Tensor:
    """Run ``model`` on ``tensor`` with ``_aft_forward`` swapped for ``_aft_full_fwd``.

    The override is only active for the duration of this call.
    """
    patched = _aft_method_replace(model, "_aft_forward", _aft_full_fwd)
    with patched:
        output = model(tensor)
    return output
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment