@gabrielhuang
Created June 19, 2025 21:02
Toy transformer implementing character-level MQA/MHA attention with RoPE embeddings
import torch
from torch import nn
import torch.nn.functional as F
embedding_size = 16
vocab = "abcdefghijklmnopqrstuvwxyz .!?'\"\n$" # $ is also EOS
vocab_size = len(vocab)
def tokenize(x: str) -> torch.Tensor:
    x = x.lower()
    x = ''.join([ch for ch in x if ch in vocab])
    # print(f"Preprocessing x to {x}")
    indices = [vocab.index(ch) for ch in x]
    return torch.LongTensor([indices])
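# Quick check (illustrative, not part of the original gist): tokenize returns a
# (1, seq_len) LongTensor of vocab indices; characters outside the vocab are dropped.
assert tokenize("Hi!").shape == (1, 3)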
input_text = """Once upon a time, there was a chicken named Kevin.
Kevin lived on a quiet farm with no internet.
One day, Kevin saw a pigeon texting.
He was amazed.
"Where did you get Wi-Fi?" he asked.
The pigeon replied, "I hacked the farmer's router."
Kevin was inspired.
He pecked the barn door open.
He waddled into the farmhouse.
He found the router behind a toaster."""
# Scratch from earlier experiments (unused below: the Transformer class defines its own
# nn.Embedding, and the real input_sequence is tokenized from input_text later).
embeddings = torch.arange(vocab_size)[:, None] + torch.arange(embedding_size)[None, :] / 10  # torch.randn(vocab_size, embedding_size)
input_sequence = torch.asarray([2, 4, 3], dtype=torch.long)[None, :]
# input_embeddings = embeddings[input_sequence][None, :, :]
# ignore the positional encodings for now
# let's do the forward pass
import math
class MHA(nn.Module):
    def __init__(self, embedding_size=embedding_size, heads=3, multiquery=True):
        super().__init__()
        # Note: in this toy version each head works in the full embedding_size dimension
        # (rather than embedding_size // heads).
        self.q_network = nn.Linear(embedding_size, embedding_size*heads)
        if not multiquery:
            # Multi-head attention (MHA): one key/value projection per head.
            self.k_network = nn.Linear(embedding_size, embedding_size*heads)
            self.v_network = nn.Linear(embedding_size, embedding_size*heads)
        else:
            # Multi-query attention (MQA): a single key/value projection shared by all heads.
            self.k_network = nn.Linear(embedding_size, embedding_size)
            self.v_network = nn.Linear(embedding_size, embedding_size)
        self.projection = nn.Linear(heads*embedding_size, embedding_size)
        self.heads = heads
        self.multiquery = multiquery
    def apply_rope(self, x):
        # x: (batch, head, seq_len, head_dim)
        seq_len = x.size(2)
        half_dim = x.size(-1) // 2
        device = x.device
        # Standard RoPE frequencies: theta_j = 10000^(-j / half_dim)
        freq = torch.exp(
            -torch.arange(0, half_dim, dtype=torch.float32, device=device)
            * math.log(10000) / half_dim
        )  # (half_dim,)
        pos = torch.arange(seq_len, dtype=torch.float32, device=device)  # (seq_len,)
        sinusoid = torch.einsum("i,j->ij", pos, freq)  # (seq_len, half_dim)
        sin = sinusoid.sin()[None, None, :, :]  # (1, 1, seq_len, half_dim)
        cos = sinusoid.cos()[None, None, :, :]  # (1, 1, seq_len, half_dim)
        # Rotate each (x1, x2) pair by the position-dependent angle.
        x1, x2 = x[..., :half_dim], x[..., half_dim:]
        x_rotated = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
        return x_rotated
    def forward(self, x):
        # x.size() == (batch, sequence, embedding)
        q = self.q_network(x).reshape((x.size(0), x.size(1), self.heads, x.size(2))).transpose(1, 2)  # (batch, head, seq, emb)
        if not self.multiquery:
            k = self.k_network(x).reshape((x.size(0), x.size(1), self.heads, x.size(2))).transpose(1, 2)  # (batch, head, seq, emb)
            v = self.v_network(x).reshape((x.size(0), x.size(1), self.heads, x.size(2))).transpose(1, 2)  # (batch, head, seq, emb)
        else:
            # Single shared key/value head; it broadcasts over the head dimension below.
            k = self.k_network(x).reshape((x.size(0), x.size(1), 1, x.size(2))).transpose(1, 2)  # (batch, 1, seq, emb)
            v = self.v_network(x).reshape((x.size(0), x.size(1), 1, x.size(2))).transpose(1, 2)  # (batch, 1, seq, emb)
        q = self.apply_rope(q)
        k = self.apply_rope(k)
        # Causal (masked) scaled dot product between q and k
        seq_len = x.size(1)
        mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
        mask[mask == 1] = float('inf')
        weights = (q[:, :, :, None, :] * k[:, :, None, :, :] / math.sqrt(q.size(-1))).sum(-1) - mask[None, None, :, :]
        norm_weights = F.softmax(weights, dim=-1)
        weighted_v = torch.matmul(norm_weights, v)  # (batch, head, seq, emb)
        weighted_v = weighted_v.permute((0, 2, 1, 3))  # (batch, seq, head, emb)
        weighted_v = weighted_v.reshape((weighted_v.size(0), weighted_v.size(1), -1))  # concatenate heads
        combined = self.projection(weighted_v)
        return combined
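# Quick sanity check (illustrative, not part of the original gist): the block is
# shape-preserving on (batch, seq, embedding), and the multi-query variant has fewer
# key/value parameters than the full multi-head variant.
_x = torch.randn(2, 5, embedding_size)
assert MHA(multiquery=True)(_x).shape == (2, 5, embedding_size)
assert MHA(multiquery=False)(_x).shape == (2, 5, embedding_size)
assert (sum(p.numel() for p in MHA(multiquery=True).parameters())
        < sum(p.numel() for p in MHA(multiquery=False).parameters()))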
class FF(torch.nn.Module):
    def __init__(self, embedding_size=embedding_size):
        super().__init__()
        self.linear1 = nn.Linear(embedding_size, embedding_size)
        self.linear2 = nn.Linear(embedding_size, embedding_size)

    def forward(self, x):
        return self.linear2(F.relu(self.linear1(x)))
class Layer(torch.nn.Module):
    def __init__(self, embedding_size=embedding_size):
        super().__init__()
        self.mha = MHA(embedding_size)
        self.norm1 = nn.LayerNorm(embedding_size)
        self.norm2 = nn.LayerNorm(embedding_size)
        self.ff = FF(embedding_size)

    def forward(self, x):
        # x.size() == (batch, sequence, embedding)
        # Post-norm residual block: attention, then feed-forward.
        weighted_v = self.mha(x)
        x = x + weighted_v
        x = self.norm1(x)
        ff_output = self.ff(x)
        x = x + ff_output
        x = self.norm2(x)
        return x
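# Quick shape check (illustrative): a Layer maps (batch, seq, embedding) to the same shape.
assert Layer()(torch.randn(2, 7, embedding_size)).shape == (2, 7, embedding_size)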
class Transformer(torch.nn.Module):
    def __init__(self, embedding_size=embedding_size, layers=3):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_size)
        self.layers = nn.ModuleList([Layer(embedding_size) for i in range(layers)])
        self.final_linear = nn.Linear(embedding_size, vocab_size)

    def forward(self, x):
        x = self.embedding(x)
        for layer_idx, layer in enumerate(self.layers):
            x = layer(x)
        x = self.final_linear(x)
        # x = F.softmax(x, dim=-1)  # not needed: F.cross_entropy expects raw logits
        return x
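# Quick check (illustrative): the model maps (batch, seq) token ids to (batch, seq, vocab_size) logits.
assert Transformer()(torch.zeros(1, 4, dtype=torch.long)).shape == (1, 4, vocab_size)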
# Train on the first 120 characters of the story (a single sequence).
input_text = input_text[:120]
input_sequence = tokenize(input_text)
# Next-character targets: shift left by one and append the EOS token '$' (the last vocab index).
target = torch.concat((input_sequence[:, 1:], torch.LongTensor([[len(vocab) - 1]])), -1)
transformer = Transformer()
y = transformer(input_sequence)
from torch.optim.adamw import AdamW

optimizer = AdamW(transformer.parameters(), lr=1e-2)
for i in range(50):
    print(f'Iteration {i}')
    output = transformer(input_sequence)
    loss = F.cross_entropy(output.reshape((-1, output.size(-1))), target.flatten())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(f'Loss {loss:.2f}')
def generate(starter: str = 'o'):
    # Greedy decoding: re-tokenize the growing string each step and append the argmax
    # character from the logits at the last position.
    current = starter
    for i in range(400):
        input_sequence = tokenize(current)
        output = transformer(input_sequence)
        new_letter = vocab[output[0, -1].argmax()]
        print(new_letter, flush=True, end="")
        current += new_letter

generate()
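# Variation (not in the original gist): sample from the softmax over the logits instead
# of taking the argmax; with an optional temperature this gives more varied generations.
def generate_sampled(starter: str = 'o', temperature: float = 1.0, steps: int = 400):
    current = starter
    for _ in range(steps):
        logits = transformer(tokenize(current))[0, -1] / temperature
        probs = F.softmax(logits, dim=-1)
        new_letter = vocab[torch.multinomial(probs, 1).item()]
        print(new_letter, flush=True, end="")
        current += new_letter

# generate_sampled()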