Use TokenFormer to train a small NN for MNIST digit recognition
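The Pattention layer used below (from TokenFormer) replaces the final fully connected layer with cross-attention against a small set of learnable key/value "parameter tokens", so model capacity can grow during training simply by appending tokens. Roughly, out = normalize(x @ K.T) @ V. A minimal sketch of the shapes involved, with illustrative variable names that are not part of the script itself:

from tinygrad import Tensor

x = Tensor.randn(256, 1600)         # batch of flattened conv features
K = Tensor.randn(4, 1600) * 0.02    # 4 learnable key parameter tokens
V = Tensor.randn(4, 10) * 0.00001   # 4 learnable value parameter tokens

scores = x @ K.transpose()                                          # (256, 4)
scores = scores / (scores ** 2).sum(axis=-1, keepdim=True).sqrt()   # l2_norm
out = (scores * 4 ** 0.5) @ V                                       # (256, 10) logits

The full training script: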
from tinygrad import Tensor, nn, Device, TinyJit
from tinygrad.nn.datasets import mnist
from tinygrad.nn.state import safe_save, get_state_dict
import math

print(f"Using device: {Device.DEFAULT}")
# normalization for Pattention
def nonlinear_normalization(inputs, normalization_type, dim=-1):
  if normalization_type == 'softmax':
    outputs = inputs.softmax(axis=dim)  # * math.sqrt(inputs.shape[dim])
  elif normalization_type == 'scaled_softmax':  # note: this one works better without growth
    scale = 1.0 / math.sqrt(inputs.shape[dim])
    outputs = (inputs * scale).softmax(axis=dim)
  elif normalization_type == 'scaled_softmax2':
    scale = 1.0 / math.sqrt(inputs.shape[dim])
    inputs = inputs * scale
    max_val = inputs.max(axis=dim, keepdim=True)
    exp_inputs = (inputs - max_val).exp()
    outputs = exp_inputs / exp_inputs.sum(axis=dim, keepdim=True)
  elif normalization_type == 'l1_norm':
    norm = inputs.abs().sum(axis=dim, keepdim=True)
    outputs = inputs / norm * math.sqrt(inputs.shape[dim])
  elif normalization_type == 'l2_norm':
    norm = (inputs ** 2).sum(axis=dim, keepdim=True).sqrt()
    outputs = inputs / norm * math.sqrt(inputs.shape[dim])
  elif normalization_type == 'gelu_l2_norm':
    nonlinear_outputs = inputs.gelu()
    norm = (nonlinear_outputs ** 2).sum(axis=dim, keepdim=True).sqrt()
    outputs = nonlinear_outputs / norm * math.sqrt(inputs.shape[dim])
  elif normalization_type == 'l2_norm_gelu':
    norm = (inputs ** 2).sum(axis=dim, keepdim=True).sqrt()
    norm_outputs = inputs / norm * math.sqrt(inputs.shape[dim])
    outputs = norm_outputs.gelu()
  else:
    raise NotImplementedError
  return outputs
class Pattention:
  def __init__(self, input_channels, output_channels, token_num, normalization_type):
    self.input_channels = input_channels
    self.output_channels = output_channels
    self.normalization_type = normalization_type
    # Initialize with small random values and enable gradients
    self.key_param_tokens = Tensor.randn(token_num, input_channels) * 0.02
    self.key_param_tokens.requires_grad = True
    self.value_param_tokens = Tensor.randn(token_num, output_channels) * 0.00001
    self.value_param_tokens.requires_grad = True

  def grow_parameters(self, num_to_add):
    # Create new parameters
    new_keys = Tensor.randn(num_to_add, self.input_channels, requires_grad=True) * 0.02
    new_values = Tensor.randn(num_to_add, self.output_channels, requires_grad=True) * 0.00001
    # Concatenate while preserving gradients
    self.key_param_tokens = Tensor.cat(self.key_param_tokens, new_keys, dim=0)
    self.value_param_tokens = Tensor.cat(self.value_param_tokens, new_values, dim=0)

  def __call__(self, inputs):
    # Cross-attention against learnable parameter tokens: score, normalize, mix values
    attn_weights = inputs @ self.key_param_tokens.transpose()
    attn_weights = nonlinear_normalization(attn_weights, self.normalization_type)
    output = attn_weights @ self.value_param_tokens
    return output
# Define the model
class Model:
  def __init__(self):
    self.l1 = nn.Conv2d(1, 32, kernel_size=(3,3))
    self.l2 = nn.Conv2d(32, 64, kernel_size=(3,3))
    # Replace Linear with Pattention
    self.pattention = Pattention(input_channels=1600, output_channels=10,
                                 token_num=4, normalization_type='l2_norm')

  def __call__(self, x:Tensor) -> Tensor:
    x = self.l1(x).relu().max_pool2d((2,2))
    x = self.l2(x).relu().max_pool2d((2,2))
    # Remove dropout and use Pattention instead of Linear
    return self.pattention(x.flatten(1))
# Move all training code inside a main block
if __name__ == "__main__":
  # Load dataset
  X_train, Y_train, X_test, Y_test = mnist()

  # Initialize model and optimizer
  model = Model()
  optim = nn.optim.Adam(nn.state.get_parameters(model))
  batch_size = 256

  # Define training step
  def step():
    Tensor.training = True
    samples = Tensor.randint(batch_size, high=X_train.shape[0])
    X, Y = X_train[samples], Y_train[samples]
    optim.zero_grad()
    loss = model(X).sparse_categorical_crossentropy(Y).backward()
    optim.step()
    return loss

  # JIT compile the step function
  jit_step = TinyJit(step)

  # Configuration for growing parameters
  GROWTH_START = 0       # Start growing after N steps
  GROWTH_INTERVAL = 50   # Grow parameters every N steps
  GROWTH_RATE = 2        # Tokens added per growth event
  MAX_TOKENS = 64        # Stop growing once this many tokens exist

  # Training loop
  for step_num in range(7000):
    loss = jit_step()

    # Grow parameters periodically
    if step_num > GROWTH_START and (step_num - GROWTH_START) % GROWTH_INTERVAL == 0 and model.pattention.key_param_tokens.shape[0] < MAX_TOKENS:
      model.pattention.grow_parameters(GROWTH_RATE)
      # Reinitialize the optimizer with the new parameters
      optim = nn.optim.Adam(nn.state.get_parameters(model))
      # Need to recompile after changing parameter shapes
      jit_step = TinyJit(step)
      print(f"Growing parameters at step {step_num}. New token count: {model.pattention.key_param_tokens.shape[0]}")

    if step_num % 100 == 0:
      Tensor.training = False
      acc = (model(X_test).argmax(axis=1) == Y_test).mean().item()
      print(f"step {step_num:4d}, loss {loss.item():.3f}, acc {acc*100.:.2f}%")

  # Save the trained model
  state_dict = get_state_dict(model)
  safe_save(state_dict, "mnist_model_pattention2.safetensors")
  print("Model saved successfully")