@kroggen
Created April 24, 2025 21:33
Use TokenFormer to train a small NN for MNIST digit recognition
from tinygrad import Tensor, nn, Device, TinyJit
from tinygrad.nn.datasets import mnist
from tinygrad.nn.state import safe_save, get_state_dict
import math

print(f"Using device: {Device.DEFAULT}")
# normalization for Pattention
def nonlinear_normalization(inputs, normalization_type, dim=-1):
    if normalization_type == 'softmax':
        outputs = inputs.softmax(axis=dim)  # * math.sqrt(inputs.shape[dim])
    elif normalization_type == 'scaled_softmax':  # note: this one works better without growth
        scale = 1.0 / math.sqrt(inputs.shape[dim])
        outputs = (inputs * scale).softmax(axis=dim)
    elif normalization_type == 'scaled_softmax2':
        scale = 1.0 / math.sqrt(inputs.shape[dim])
        inputs = inputs * scale
        max_val = inputs.max(axis=dim, keepdim=True)
        exp_inputs = (inputs - max_val).exp()
        outputs = exp_inputs / exp_inputs.sum(axis=dim, keepdim=True)
    elif normalization_type == 'l1_norm':
        norm = inputs.abs().sum(axis=dim, keepdim=True)
        outputs = inputs / norm * math.sqrt(inputs.shape[dim])
    elif normalization_type == 'l2_norm':
        norm = (inputs ** 2).sum(axis=dim, keepdim=True).sqrt()
        outputs = inputs / norm * math.sqrt(inputs.shape[dim])
    elif normalization_type == 'gelu_l2_norm':
        nonlinear_outputs = inputs.gelu()
        norm = (nonlinear_outputs ** 2).sum(axis=dim, keepdim=True).sqrt()
        outputs = nonlinear_outputs / norm * math.sqrt(inputs.shape[dim])
    elif normalization_type == 'l2_norm_gelu':
        norm = (inputs ** 2).sum(axis=dim, keepdim=True).sqrt()
        norm_outputs = inputs / norm * math.sqrt(inputs.shape[dim])
        outputs = norm_outputs.gelu()
    else:
        raise NotImplementedError
    return outputs
class Pattention:
    def __init__(self, input_channels, output_channels, token_num, normalization_type):
        self.input_channels = input_channels
        self.output_channels = output_channels
        self.normalization_type = normalization_type
        # Initialize with small random values and enable gradients
        self.key_param_tokens = Tensor.randn(token_num, input_channels) * 0.02
        self.key_param_tokens.requires_grad = True
        self.value_param_tokens = Tensor.randn(token_num, output_channels) * 0.00001
        self.value_param_tokens.requires_grad = True

    def grow_parameters(self, num_to_add):
        # Create new parameters
        new_keys = Tensor.randn(num_to_add, self.input_channels, requires_grad=True) * 0.02
        new_values = Tensor.randn(num_to_add, self.output_channels, requires_grad=True) * 0.00001
        # Concatenate while preserving gradients
        self.key_param_tokens = Tensor.cat(self.key_param_tokens, new_keys, dim=0)
        self.value_param_tokens = Tensor.cat(self.value_param_tokens, new_values, dim=0)

    def __call__(self, inputs):
        attn_weights = inputs @ self.key_param_tokens.transpose()
        attn_weights = nonlinear_normalization(attn_weights, self.normalization_type)
        output = attn_weights @ self.value_param_tokens
        return output
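
# Illustrative usage (a sketch, not part of the training script): Pattention maps
# (batch, input_channels) -> (batch, output_channels), so it drops in wherever nn.Linear
# would normally be used. The numbers below mirror the layer built in Model.__init__:
#   layer = Pattention(input_channels=1600, output_channels=10,
#                      token_num=4, normalization_type='l2_norm')
#   out = layer(Tensor.randn(256, 1600))   # out.shape == (256, 10)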
# Define the model
class Model:
    def __init__(self):
        self.l1 = nn.Conv2d(1, 32, kernel_size=(3,3))
        self.l2 = nn.Conv2d(32, 64, kernel_size=(3,3))
        # Replace Linear with Pattention
        self.pattention = Pattention(input_channels=1600, output_channels=10,
                                     token_num=4, normalization_type='l2_norm')

    def __call__(self, x:Tensor) -> Tensor:
        x = self.l1(x).relu().max_pool2d((2,2))
        x = self.l2(x).relu().max_pool2d((2,2))
        # Remove dropout and use Pattention instead of Linear
        return self.pattention(x.flatten(1))
# Move all training code inside a main block
if __name__ == "__main__":
    # Load dataset
    X_train, Y_train, X_test, Y_test = mnist()

    # Initialize model and optimizer
    model = Model()
    optim = nn.optim.Adam(nn.state.get_parameters(model))
    batch_size = 256

    # Define training step
    def step():
        Tensor.training = True
        samples = Tensor.randint(batch_size, high=X_train.shape[0])
        X, Y = X_train[samples], Y_train[samples]
        optim.zero_grad()
        loss = model(X).sparse_categorical_crossentropy(Y).backward()
        optim.step()
        return loss

    # JIT compile the step function
    jit_step = TinyJit(step)

    # Configuration for growing parameters
    GROWTH_START = 0       # Start growing after N steps
    GROWTH_INTERVAL = 50   # Grow parameters every N steps
    GROWTH_RATE = 2        # Tokens added per growth event
    MAX_TOKENS = 64        # Stop growing once this many tokens are reached

    # Training loop
    for step_num in range(7000):
        loss = jit_step()

        # Grow parameters periodically
        if step_num > GROWTH_START and (step_num - GROWTH_START) % GROWTH_INTERVAL == 0 and model.pattention.key_param_tokens.shape[0] < MAX_TOKENS:
            model.pattention.grow_parameters(GROWTH_RATE)
            # Reinitialize the optimizer with the new parameters
            optim = nn.optim.Adam(nn.state.get_parameters(model))
            # Need to recompile after changing parameter shapes
            jit_step = TinyJit(step)
            print(f"Growing parameters at step {step_num}. New token count: {model.pattention.key_param_tokens.shape[0]}")

        if step_num % 100 == 0:
            Tensor.training = False
            acc = (model(X_test).argmax(axis=1) == Y_test).mean().item()
            print(f"step {step_num:4d}, loss {loss.item():.3f}, acc {acc*100.:.2f}%")

    # Save the trained model
    state_dict = get_state_dict(model)
    safe_save(state_dict, "mnist_model_pattention2.safetensors")
    print("Model saved successfully")