Toy transformer implementing character-level MQA/MHA attention with rotary position embeddings (RoPE).
import math

import torch
from torch import nn
import torch.nn.functional as F

embedding_size = 16
vocab = "abcdefghijklmnopqrstuvwxyz .!?'\"\n$"  # $ is also the EOS token
vocab_size = len(vocab)


def tokenize(x: str) -> torch.Tensor:
    # Lowercase, drop any character outside the vocabulary, and map the rest to indices.
    x = x.lower()
    x = ''.join([ch for ch in x if ch in vocab])
    # print(f"Preprocessing x to {x}")
    indices = [vocab.index(ch) for ch in x]
    return torch.LongTensor([indices])
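
# Optional sanity check (illustrative): "Hi!" is lowercased and mapped to the
# vocab indices 'h' -> 7, 'i' -> 8, '!' -> 28.
assert tokenize("Hi!").tolist() == [[7, 8, 28]]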
input_text = """Once upon a time, there was a chicken named Kevin.
Kevin lived on a quiet farm with no internet.
One day, Kevin saw a pigeon texting.
He was amazed.
"Where did you get Wi-Fi?" he asked.
The pigeon replied, "I hacked the farmer's router."
Kevin was inspired.
He pecked the barn door open.
He waddled into the farmhouse.
He found the router behind a toaster."""

# Scratch values from early experiments: `embeddings` is not used by the model below,
# and `input_sequence` is overwritten before training. The arange expression is a
# deterministic stand-in for torch.randn(vocab_size, embedding_size).
embeddings = torch.arange(vocab_size)[:, None] + torch.arange(embedding_size)[None, :] / 10
input_sequence = torch.asarray([2, 4, 3], dtype=torch.long)[None, :]
# input_embeddings = embeddings[input_sequence][None, :, :]
# ignore the positional encodings for now
# let's do the forward pass
class MHA(nn.Module):
    """Multi-head attention; with multiquery=True all heads share a single K/V projection (MQA)."""

    def __init__(self, embedding_size=embedding_size, heads=3, multiquery=True):
        super().__init__()
        # Toy setup: each head uses the full embedding_size as its head dimension,
        # so the query projection produces heads * embedding_size features.
        self.q_network = nn.Linear(embedding_size, embedding_size * heads)
        if not multiquery:
            self.k_network = nn.Linear(embedding_size, embedding_size * heads)
            self.v_network = nn.Linear(embedding_size, embedding_size * heads)
        else:
            # Multi-query attention: one key/value head shared by all query heads.
            self.k_network = nn.Linear(embedding_size, embedding_size)
            self.v_network = nn.Linear(embedding_size, embedding_size)
        self.projection = nn.Linear(heads * embedding_size, embedding_size)
        self.heads = heads
        self.multiquery = multiquery

    def apply_rope(self, x):
        # Rotary position embeddings (RoPE), split-half variant: dimension i is paired
        # with dimension i + half_dim and each pair is rotated by a position-dependent angle.
        # x: (batch, head, seq_len, head_dim)
        seq_len = x.size(2)
        half_dim = x.size(-1) // 2
        device = x.device
        freq = torch.exp(
            -torch.arange(0, half_dim, dtype=torch.float32, device=device)
            * math.log(10000) / half_dim
        )  # (half_dim,)
        pos = torch.arange(seq_len, dtype=torch.float32, device=device)  # (seq_len,)
        sinusoid = torch.einsum("i,j->ij", pos, freq)  # (seq_len, half_dim)
        sin = sinusoid.sin()[None, None, :, :]  # (1, 1, seq_len, half_dim)
        cos = sinusoid.cos()[None, None, :, :]  # (1, 1, seq_len, half_dim)
        x1, x2 = x[..., :half_dim], x[..., half_dim:]
        x_rotated = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
        return x_rotated
    def forward(self, x):
        # x.size() == (batch, sequence, embedding)
        q = self.q_network(x).reshape((x.size(0), x.size(1), self.heads, x.size(2))).transpose(1, 2)  # (batch, head, seq, emb)
        if not self.multiquery:
            k = self.k_network(x).reshape((x.size(0), x.size(1), self.heads, x.size(2))).transpose(1, 2)  # (batch, head, seq, emb)
            v = self.v_network(x).reshape((x.size(0), x.size(1), self.heads, x.size(2))).transpose(1, 2)  # (batch, head, seq, emb)
        else:
            k = self.k_network(x).reshape((x.size(0), x.size(1), 1, x.size(2))).transpose(1, 2)  # (batch, 1, seq, emb), broadcast across heads
            v = self.v_network(x).reshape((x.size(0), x.size(1), 1, x.size(2))).transpose(1, 2)  # (batch, 1, seq, emb), broadcast across heads
        q = self.apply_rope(q)
        k = self.apply_rope(k)
        # Causal scaled dot product between q and k: future positions are pushed to -inf
        # before the softmax, so their attention weights become 0.
        seq_len = x.size(1)
        mask = torch.triu(torch.ones(seq_len, seq_len, device=x.device), diagonal=1)
        mask[mask == 1] = float('inf')
        weights = (q[:, :, :, None, :] * k[:, :, None, :, :] / math.sqrt(q.size(-1))).sum(-1) - mask[None, None, :, :]
        norm_weights = F.softmax(weights, dim=-1)
        weighted_v = torch.matmul(norm_weights, v)
        # Merge the heads: (batch, head, seq, emb) -> (batch, seq, head * emb), then project back.
        weighted_v = weighted_v.permute((0, 2, 1, 3))
        weighted_v = weighted_v.reshape((weighted_v.size(0), weighted_v.size(1), -1))
        combined = self.projection(weighted_v)
        return combined
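
# Quick shape check (illustrative; uses the defaults heads=3, multiquery=True):
# the attention block maps (batch, seq, embedding) back to the same shape.
_x = torch.randn(1, 5, embedding_size)
assert MHA()(_x).shape == _x.shape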
class FF(torch.nn.Module):
    def __init__(self, embedding_size=embedding_size):
        super().__init__()
        self.linear1 = nn.Linear(embedding_size, embedding_size)
        self.linear2 = nn.Linear(embedding_size, embedding_size)

    def forward(self, x):
        return self.linear2(F.relu(self.linear1(x)))


class Layer(torch.nn.Module):
    def __init__(self, embedding_size=embedding_size):
        super().__init__()
        self.mha = MHA(embedding_size)
        self.norm1 = nn.LayerNorm(embedding_size)
        self.norm2 = nn.LayerNorm(embedding_size)
        self.ff = FF(embedding_size)

    def forward(self, x):
        # x.size() == (batch, sequence, embedding)
        # Post-norm residual block: attention + residual + LayerNorm, then FF + residual + LayerNorm.
        weighted_v = self.mha(x)
        x = x + weighted_v
        x = self.norm1(x)
        ff_output = self.ff(x)
        x = x + ff_output
        x = self.norm2(x)
        return x
class Transformer(torch.nn.Module):
    def __init__(self, embedding_size=embedding_size, layers=3):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_size)
        self.layers = nn.ModuleList([Layer(embedding_size) for i in range(layers)])
        self.final_linear = nn.Linear(embedding_size, vocab_size)

    def forward(self, x):
        x = self.embedding(x)
        for layer in self.layers:
            x = layer(x)
        x = self.final_linear(x)  # return logits; F.cross_entropy applies the softmax itself
        # x = F.softmax(x, dim=-1)
        return x
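
# Quick check (illustrative): for a (batch, seq) tensor of character indices the model
# returns per-position next-character logits of shape (batch, seq, vocab_size).
assert Transformer()(tokenize("kevin")).shape == (1, 5, vocab_size)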
# Train on the first 120 characters of the story, predicting the next character at each position.
input_text = input_text[:120]
input_sequence = tokenize(input_text)
# Targets are the inputs shifted left by one, with the EOS index ($) appended at the end.
target = torch.concat((input_sequence[:, 1:], torch.LongTensor([[len(vocab) - 1]])), -1)

transformer = Transformer()
y = transformer(input_sequence)  # one untrained forward pass as a smoke test

from torch.optim import AdamW

optimizer = AdamW(transformer.parameters(), lr=1e-2)
for i in range(50):
    print(f'Iteration {i}')
    output = transformer(input_sequence)
    loss = F.cross_entropy(output.reshape((-1, output.size(-1))), target.flatten())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(f'Loss {loss:.2f}')

def generate(starter: str = 'o'):
    # Greedy decoding: repeatedly re-encode the text so far and append the argmax character.
    current = starter
    for i in range(400):
        input_sequence = tokenize(current)
        output = transformer(input_sequence)
        new_letter = vocab[output[0, -1].argmax()]
        print(new_letter, flush=True, end="")
        current += new_letter

generate()
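
# Optional variant (a sketch; the temperature value is an arbitrary choice): sample the
# next character from the softmax distribution instead of taking the argmax, which tends
# to produce less repetitive text than greedy decoding.
def generate_sampled(starter: str = 'o', steps: int = 400, temperature: float = 0.8):
    current = starter
    for _ in range(steps):
        logits = transformer(tokenize(current))[0, -1] / temperature
        probs = F.softmax(logits, dim=-1)
        new_letter = vocab[torch.multinomial(probs, 1).item()]
        print(new_letter, flush=True, end="")
        current += new_letter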