Attention-free transformer
| """ | |
| Implementation of "An Attention-Free Transformer": https://arxiv.org/abs/2105.14103 | |
| """ | |
| import contextlib | |
| import functools | |
| import torch | |
| import triton | |
| import triton.language as tl | |
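# For reference, a paraphrase of the weighting the modules below implement (taken from the
# paper linked above, not from this file): AFT-full computes, per position t,
#     Y_t = sigmoid(Q_t) * (sum_{t'} exp(K_{t'} + w_{t,t'}) * V_{t'}) / (sum_{t'} exp(K_{t'} + w_{t,t'}))
# with all products elementwise and w a learned pairwise position bias. AFT-simple is the
# special case w = 0, and AFT-local keeps w_{t,t'} only for |t - t'| < window_seq_len and
# uses a bias of 0 outside that window.
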
class AFTFull(torch.nn.Module):
    def __init__(self, embedding_dim: int, hidden_dim: int, max_seq_len: int, bias: bool = False, eps: float = 1e-8) -> None:
        super().__init__()

        self.max_seq_len = max_seq_len
        self.bias = bias
        self.eps = eps

        self.to_qkv = torch.nn.Linear(embedding_dim, 3 * hidden_dim, bias=bias)
        self.to_out = torch.nn.Linear(hidden_dim, embedding_dim, bias=bias)
        self.w = torch.nn.Parameter(torch.empty(max_seq_len, max_seq_len))

    def init_weights(self) -> None:
        torch.nn.init.xavier_uniform_(self.to_qkv.weight)
        torch.nn.init.xavier_uniform_(self.to_out.weight)
        torch.nn.init.xavier_uniform_(self.w)
        if self.bias:
            torch.nn.init.zeros_(self.to_qkv.bias)
            torch.nn.init.zeros_(self.to_out.bias)

    def _aft_forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        seq_len = q.size(1)

        # Naive formulation: terrible memory usage (materializes a [batch_size, seq_len, seq_len, hidden_dim] tensor)
        # k = k[:, None, :, :]  # [batch_size, 1, seq_len, hidden_dim]
        # v = v[:, None, :, :]  # [batch_size, 1, seq_len, hidden_dim]
        # w = self.w[:seq_len, :seq_len][None, :, :, None]  # [1, seq_len, seq_len, 1]
        # q = torch.sigmoid(q)
        # s = torch.exp(k + w)  # [batch_size, seq_len, seq_len, hidden_dim]
        # nr = torch.mul(s, v).sum(dim=2)  # [batch_size, seq_len, hidden_dim]
        # dr = s.sum(dim=2) + self.eps  # [batch_size, seq_len, hidden_dim]
        # y = torch.mul(q, nr / dr)

        # Equivalent reformulation: exp(k + w) factors into exp(w) * exp(k), so the sums over t'
        # become two matmuls against the [seq_len, seq_len] matrix exp(w).
        # Note: w_exp has a leading batch dim of 1, so torch.bmm below only works for batch_size == 1.
        w = self.w[:seq_len, :seq_len][None, :, :]
        w_exp = torch.exp(w)
        k_exp = torch.exp(k)
        kv = torch.mul(k_exp, v)
        nr = torch.bmm(w_exp, kv)
        dr = torch.bmm(w_exp, k_exp) + self.eps
        q = torch.sigmoid(q)
        y = torch.mul(q, nr / dr)
        return y

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        y = self._aft_forward(q, k, v)
        o = self.to_out(y)
        return o


class AFTSimple(torch.nn.Module):
    def __init__(self, embedding_dim: int, hidden_dim: int, bias: bool = False, eps: float = 1e-8) -> None:
        super().__init__()

        self.bias = bias
        self.eps = eps

        self.to_qkv = torch.nn.Linear(embedding_dim, 3 * hidden_dim, bias=bias)
        self.to_out = torch.nn.Linear(hidden_dim, embedding_dim, bias=bias)

    def init_weights(self) -> None:
        torch.nn.init.xavier_uniform_(self.to_qkv.weight)
        torch.nn.init.xavier_uniform_(self.to_out.weight)
        if self.bias:
            torch.nn.init.zeros_(self.to_qkv.bias)
            torch.nn.init.zeros_(self.to_out.bias)

    def _aft_forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        # Naive formulation: terrible memory usage
        # k = k[:, None, :, :]
        # v = v[:, None, :, :]
        # q = torch.sigmoid(q)
        # s = torch.exp(k)  # [batch_size, 1, seq_len, hidden_dim]
        # nr = torch.mul(s, v).sum(dim=2)  # [batch_size, 1, hidden_dim]
        # dr = s.sum(dim=2) + self.eps  # [batch_size, 1, hidden_dim]
        # y = torch.mul(q, nr / dr)
        k_exp = torch.exp(k)
        kv = torch.mul(k_exp, v)
        nr = kv.sum(dim=1, keepdim=True)
        dr = k_exp.sum(dim=1, keepdim=True) + self.eps
        q = torch.sigmoid(q)
        y = torch.mul(q, nr / dr)
        return y

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        y = self._aft_forward(q, k, v)
        o = self.to_out(y)
        return o

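
# A note on AFTSimple above (paraphrasing the paper): since exp(k_{t'}) / sum_{t'} exp(k_{t'}) is
# a softmax over the sequence dimension (taken independently per feature), the update can be read as
#     y_t = sigmoid(q_t) * sum_{t'} (softmax(k, dim=seq) * v)_{t'}
# i.e. a global, query-gated pooling of the values with no pairwise position bias.
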
class AFTLocal(torch.nn.Module):
    def __init__(self, embedding_dim: int, hidden_dim: int, max_seq_len: int, window_seq_len: int, bias: bool = False, eps: float = 1e-8) -> None:
        super().__init__()

        self.max_seq_len = max_seq_len
        self.window_seq_len = window_seq_len
        self.bias = bias
        self.eps = eps

        self.to_qkv = torch.nn.Linear(embedding_dim, 3 * hidden_dim, bias=bias)
        self.to_out = torch.nn.Linear(hidden_dim, embedding_dim, bias=bias)
        self.w = torch.nn.Parameter(torch.empty(max_seq_len, max_seq_len))

        indices = torch.arange(self.max_seq_len)
        mask = torch.abs(indices[:, None] - indices[None, :]) >= window_seq_len
        self.register_buffer("_mask", mask, persistent=False)

    def init_weights(self) -> None:
        torch.nn.init.xavier_uniform_(self.to_qkv.weight)
        torch.nn.init.xavier_uniform_(self.to_out.weight)
        torch.nn.init.xavier_uniform_(self.w)
        if self.bias:
            torch.nn.init.zeros_(self.to_qkv.bias)
            torch.nn.init.zeros_(self.to_out.bias)

    def _aft_forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        seq_len = q.size(1)

        # Naive formulation: terrible memory usage
        # k = k[:, None, :, :]
        # v = v[:, None, :, :]
        # mask = self._prepare_mask(seq_len, self.window_seq_len, q.device)
        # w = self.w[:seq_len, :seq_len].clone()
        # w[mask] = float("-inf")
        # w = w[None, :, :, None]
        # q = torch.sigmoid(q)
        # s = torch.exp(k + w)
        # nr = torch.mul(s, v).sum(dim=2)
        # dr = s.sum(dim=2) + self.eps
        # y = torch.mul(q, nr / dr)

        # Out-of-window positions get a bias of 0 (not -inf), following the paper's AFT-local
        # definition: they still contribute with weight exp(k), as in AFT-simple. The commented
        # naive path above instead drops them entirely with -inf.
        mask = self._mask[:seq_len, :seq_len]
        w = self.w[:seq_len, :seq_len].masked_fill(mask, 0)[None, :, :]
        w_exp = torch.exp(w)
        k_exp = torch.exp(k)
        kv = torch.mul(k_exp, v)
        nr = torch.bmm(w_exp, kv)
        dr = torch.bmm(w_exp, k_exp) + self.eps
        q = torch.sigmoid(q)
        y = torch.mul(q, nr / dr)
        return y

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        y = self._aft_forward(q, k, v)
        o = self.to_out(y)
        return o


if __name__ == "__main__":
    device = torch.device("cuda")
    dtype = torch.bfloat16

    torch.manual_seed(42)

    max_seq_len = 8192
    embedding_dim = 512
    hidden_dim = 512
    test_seq_len = 1024

    model_aft_full = AFTFull(embedding_dim, hidden_dim, max_seq_len).to(device=device, dtype=dtype)
    model_aft_full.init_weights()
    model_aft_full.eval()

    model_aft_simple = AFTSimple(embedding_dim, hidden_dim).to(device=device, dtype=dtype)
    model_aft_simple.init_weights()
    model_aft_simple.eval()

    model_aft_local = AFTLocal(embedding_dim, hidden_dim, max_seq_len, window_seq_len=128).to(device=device, dtype=dtype)
    model_aft_local.init_weights()
    model_aft_local.eval()

    with torch.inference_mode():
        x = torch.randn(1, test_seq_len, embedding_dim, device=device, dtype=dtype)

        output = model_aft_full(x)
        assert output.shape == (1, test_seq_len, embedding_dim), "Output shape mismatch for AFTFull"

        output = model_aft_simple(x)
        assert output.shape == (1, test_seq_len, embedding_dim), "Output shape mismatch for AFTSimple"

        output = model_aft_local(x)
        assert output.shape == (1, test_seq_len, embedding_dim), "Output shape mismatch for AFTLocal"


@contextlib.contextmanager
def _aft_method_replace(model: AFTFull | AFTLocal | AFTSimple, method_name: str, new_method: callable):
    # Temporarily swap out a method on `model`, so the Triton path runs through the same
    # `forward` as the eager path.
    new_method = functools.partial(new_method, model)
    original_method = getattr(model, method_name)
    setattr(model, method_name, new_method)
    try:
        yield
    finally:
        setattr(model, method_name, original_method)


def get_pointwise_exp_autotune_configs():
    configs = []
    for BLOCK_N in [512, 1024, 2048]:
        for num_warps in [4, 8]:
            for num_stages in [1]:
                configs.append(triton.Config({"BLOCK_N": BLOCK_N}, num_warps=num_warps, num_stages=num_stages))
    return configs


@triton.autotune(configs=get_pointwise_exp_autotune_configs(), key=["seq_len"])
@triton.jit
def _pointwise_exp(w_ptr, w_stride, out_ptr, seq_len: int, BLOCK_N: tl.constexpr):
    # Elementwise exp over the top-left [seq_len, seq_len] block of the position bias w.
    pid = tl.program_id(0)
    start = pid * BLOCK_N
    offsets = start + tl.arange(0, BLOCK_N)
    mask = offsets < (seq_len * seq_len)
    row_offsets = offsets // seq_len
    col_offsets = offsets % seq_len
    w = tl.load(w_ptr + row_offsets * w_stride + col_offsets, mask=mask, other=0.0, eviction_policy="evict_last").to(tl.float32)
    w_exp = tl.exp(w)
    tl.store(out_ptr + offsets, w_exp, mask=mask)


def get_pointwise_mul_kexp_v_autotune_configs():
    configs = []
    for BLOCK_N in [512, 1024, 2048]:
        for num_warps in [4, 8]:
            for num_stages in [1]:
                configs.append(triton.Config({"BLOCK_N": BLOCK_N}, num_warps=num_warps, num_stages=num_stages))
    return configs


@triton.autotune(configs=get_pointwise_mul_kexp_v_autotune_configs(), key=["batch_size", "seq_len", "dim"])
@triton.jit
def _pointwise_mul_kexp_v(k_ptr, k_stride, v_ptr, v_stride, k_exp_ptr, kv_ptr, batch_size, seq_len, dim, BLOCK_N: tl.constexpr):
    # Fused elementwise kernel: writes exp(k) and exp(k) * v in a single pass.
    pid = tl.program_id(0)
    start = pid * BLOCK_N
    offsets = start + tl.arange(0, BLOCK_N)
    mask = offsets < (batch_size * seq_len * dim)
    row_offsets = offsets // dim
    col_offsets = offsets % dim
    k = tl.load(k_ptr + row_offsets * k_stride + col_offsets, mask=mask, other=0.0, eviction_policy="evict_last").to(tl.float32)
    v = tl.load(v_ptr + row_offsets * v_stride + col_offsets, mask=mask, other=0.0, eviction_policy="evict_last").to(tl.float32)
    k_exp = tl.exp(k)
    kv = k_exp * v
    tl.store(k_exp_ptr + offsets, k_exp, mask=mask)
    tl.store(kv_ptr + offsets, kv, mask=mask)


def get_pointwise_mul_qsigmoid_nr_dr_autotune_configs():
    configs = []
    for BLOCK_N in [512, 1024, 2048]:
        for num_warps in [4, 8]:
            for num_stages in [1]:
                configs.append(triton.Config({"BLOCK_N": BLOCK_N}, num_warps=num_warps, num_stages=num_stages))
    return configs


@triton.autotune(configs=get_pointwise_mul_qsigmoid_nr_dr_autotune_configs(), key=["batch_size", "seq_len", "dim"])
@triton.jit
def _pointwise_mul_qsigmoid_nr_dr(
    q_ptr, q_stride, nr_ptr, nr_stride, dr_ptr, dr_stride, out_ptr, batch_size, seq_len, dim, eps: tl.constexpr, BLOCK_N: tl.constexpr
):
    # Fused elementwise kernel for the final step: y = sigmoid(q) * nr / (dr + eps).
    pid = tl.program_id(0)
    start = pid * BLOCK_N
    offsets = start + tl.arange(0, BLOCK_N)
    mask = offsets < (batch_size * seq_len * dim)
    row_offsets = offsets // dim
    col_offsets = offsets % dim
    q = tl.load(q_ptr + row_offsets * q_stride + col_offsets, mask=mask, other=0.0, eviction_policy="evict_last").to(tl.float32)
    nr = tl.load(nr_ptr + row_offsets * nr_stride + col_offsets, mask=mask, other=0.0, eviction_policy="evict_last").to(tl.float32)
    dr = tl.load(dr_ptr + row_offsets * dr_stride + col_offsets, mask=mask, other=0.0, eviction_policy="evict_last").to(tl.float32)
    q = tl.sigmoid(q)
    nr_dr = nr / (dr + eps)
    y = q * nr_dr
    tl.store(out_ptr + offsets, y, mask=mask)


def _aft_full_fwd(model: AFTFull, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    batch_size, seq_len, dim = q.size()
    assert batch_size == 1, "Only batch size 1 is supported in this implementation."

    w = model.w
    w_exp = q.new_empty((seq_len, seq_len))
    grid = lambda META: (triton.cdiv(seq_len * seq_len, META["BLOCK_N"]),)
    _pointwise_exp[grid](w, w.stride(0), w_exp, seq_len)
    w_exp = w_exp[None, :, :]

    k_exp = q.new_empty((batch_size, seq_len, dim))
    kv = q.new_empty((batch_size, seq_len, dim))
    grid = lambda META: (triton.cdiv(batch_size * seq_len * dim, META["BLOCK_N"]),)
    # We pass stride(1) because batch size is 1 (other batch sizes are not supported here)
    _pointwise_mul_kexp_v[grid](k, k.stride(1), v, v.stride(1), k_exp, kv, batch_size, seq_len, dim)

    nr = torch.bmm(w_exp, kv)
    dr = torch.bmm(w_exp, k_exp)

    y = q.new_empty((batch_size, seq_len, dim))
    grid = lambda META: (triton.cdiv(batch_size * seq_len * dim, META["BLOCK_N"]),)
    # eps is added to dr inside the kernel, matching the `+ self.eps` of the eager path
    _pointwise_mul_qsigmoid_nr_dr[grid](q, q.stride(1), nr, nr.stride(1), dr, dr.stride(1), y, batch_size, seq_len, dim, model.eps)
    return y


def solution_torch(model: torch.nn.Module, tensor: torch.Tensor) -> torch.Tensor:
    return model(tensor)


def solution_triton_aft_full(model: torch.nn.Module, tensor: torch.Tensor) -> torch.Tensor:
    with _aft_method_replace(model, "_aft_forward", _aft_full_fwd):
        return model(tensor)
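
# A minimal sanity-check sketch, assuming `x` and `model_aft_full` from the __main__ block above
# are in scope; the tolerances are rough guesses for bfloat16:
#
#     y_ref = solution_torch(model_aft_full, x)
#     y_tri = solution_triton_aft_full(model_aft_full, x)
#     torch.testing.assert_close(y_tri, y_ref, atol=1e-2, rtol=1e-2)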