Claude Code-generated script for running nanochat models on CPU on macOS

# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "torch",
#     "tiktoken",
#     "numpy",
# ]
# ///
"""
Standalone CPU-compatible text generation for the nanochat model.
This script requires NO nanochat installation - just torch and tiktoken!

Usage:
    uv run generate_cpu_standalone.py --model-dir /path/to/model --prompt "Hello"
"""
import argparse
import torch
import os
import json
import glob
import pickle
import math
from dataclasses import dataclass
import torch.nn as nn
import torch.nn.functional as F

# -----------------------------------------------------------------------------
# Minimal GPT implementation (copied from nanochat to make standalone)

@dataclass
class GPTConfig:
    sequence_len: int = 1024
    vocab_size: int = 50304
    n_layer: int = 12
    n_head: int = 6
    n_kv_head: int = 6
    n_embd: int = 768

def norm(x):
    # Parameter-free RMSNorm over the last (channel) dimension
    return F.rms_norm(x, (x.size(-1),))

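# Rotary position embeddings (RoPE): each (x1, x2) channel pair is rotated by a
# position-dependent angle. With cos/sin of shape (1, T, 1, head_dim//2), this
# broadcasts over batch and heads, and is equivalent to applying the 2x2
# rotation matrix [[cos, sin], [-sin, cos]] to each pair.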
def apply_rotary_emb(x, cos, sin):
    assert x.ndim == 4
    d = x.shape[3] // 2
    x1, x2 = x[..., :d], x[..., d:]
    y1 = x1 * cos + x2 * sin
    y2 = x1 * (-sin) + x2 * cos
    out = torch.cat([y1, y2], 3)
    out = out.to(x.dtype)
    return out

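# Grouped-query attention helper: K/V are projected to n_kv_head heads, and
# each head is repeated n_rep = n_head // n_kv_head times so the shapes line
# up with Q for scaled_dot_product_attention.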
def repeat_kv(x, n_rep):
    if n_rep == 1:
        return x
    bs, n_kv_heads, slen, head_dim = x.shape
    return (
        x[:, :, None, :, :]
        .expand(bs, n_kv_heads, n_rep, slen, head_dim)
        .reshape(bs, n_kv_heads * n_rep, slen, head_dim)
    )

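# Causal self-attention with rotary embeddings and QK RMS-norm (applied to
# q/k after rotation, as in the nanochat code this was copied from). The
# kv_cache branches below are carried over but unused here: this script
# always calls the model with kv_cache=None.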
class CausalSelfAttention(nn.Module):
    def __init__(self, config, layer_idx):
        super().__init__()
        self.layer_idx = layer_idx
        self.n_head = config.n_head
        self.n_kv_head = config.n_kv_head
        self.n_embd = config.n_embd
        self.head_dim = self.n_embd // self.n_head
        assert self.n_embd % self.n_head == 0
        assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
        self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
        self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)

    def forward(self, x, cos_sin, kv_cache):
        B, T, C = x.size()
        q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
        k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
        v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
        cos, sin = cos_sin
        q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
        q, k = norm(q), norm(k)  # QK norm
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        if kv_cache is not None:
            k, v = kv_cache.insert_kv(self.layer_idx, k, v)
        Tq = q.size(2)
        Tk = k.size(2)
        nrep = self.n_head // self.n_kv_head
        k, v = repeat_kv(k, nrep), repeat_kv(v, nrep)
        if kv_cache is None or Tq == Tk:
            # Full prefill (or no cache): plain causal attention
            y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        elif Tq == 1:
            # Single-token decode: the new token may attend to everything cached
            y = F.scaled_dot_product_attention(q, k, v, is_causal=False)
        else:
            # Chunked decode: attend fully to the cached prefix, causally within the chunk
            attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool, device=q.device)
            prefix_len = Tk - Tq
            if prefix_len > 0:
                attn_mask[:, :prefix_len] = True
            attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool, device=q.device))
            y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        y = y.transpose(1, 2).contiguous().view(B, T, -1)
        y = self.c_proj(y)
        return y

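# The MLP uses a squared-ReLU activation, relu(x)**2, rather than GELU, with a
# 4x expansion and no biases (matching the nanochat design).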
class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)

    def forward(self, x):
        x = self.c_fc(x)
        x = F.relu(x).square()
        x = self.c_proj(x)
        return x

class Block(nn.Module):
    def __init__(self, config, layer_idx):
        super().__init__()
        self.attn = CausalSelfAttention(config, layer_idx)
        self.mlp = MLP(config)

    def forward(self, x, cos_sin, kv_cache):
        x = x + self.attn(norm(x), cos_sin, kv_cache)
        x = x + self.mlp(norm(x))
        return x

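# The full model: token embedding -> n_layer pre-norm Blocks -> final norm ->
# untied lm_head. Rotary cos/sin tables are precomputed once for up to 10x the
# training sequence length and registered as non-persistent buffers, so they
# are not expected in the checkpoint's state_dict (which matters for the
# strict=True load below).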
class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict({
            "wte": nn.Embedding(config.vocab_size, config.n_embd),
            "h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
        })
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.rotary_seq_len = config.sequence_len * 10
        head_dim = config.n_embd // config.n_head
        cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
        self.register_buffer("cos", cos, persistent=False)
        self.register_buffer("sin", sin, persistent=False)
        self.transformer.wte.to(dtype=torch.bfloat16)

    def init_weights(self):
        self.apply(self._init_weights)
        torch.nn.init.zeros_(self.lm_head.weight)
        for block in self.transformer.h:
            torch.nn.init.zeros_(block.mlp.c_proj.weight)
            torch.nn.init.zeros_(block.attn.c_proj.weight)
        head_dim = self.config.n_embd // self.config.n_head
        cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
        self.cos, self.sin = cos, sin

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            fan_out = module.weight.size(0)
            fan_in = module.weight.size(1)
            std = 1.0 / math.sqrt(fan_in) * min(1.0, math.sqrt(fan_out / fan_in))
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=1.0)

    def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
        if device is None:
            device = self.transformer.wte.weight.device
        channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
        inv_freq = 1.0 / (base ** (channel_range / head_dim))
        t = torch.arange(seq_len, dtype=torch.float32, device=device)
        freqs = torch.outer(t, inv_freq)
        cos, sin = freqs.cos(), freqs.sin()
        cos, sin = cos.bfloat16(), sin.bfloat16()
        cos, sin = cos[None, :, None, :], sin[None, :, None, :]
        return cos, sin

    def forward(self, idx, targets=None, kv_cache=None):
        B, T = idx.size()
        assert T <= self.cos.size(1)
        # Offset into the rotary tables by the current cache position
        # (always 0 here, since this script never passes a kv_cache)
        T0 = 0 if kv_cache is None else kv_cache.get_pos()
        cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T]
        x = self.transformer.wte(idx)
        x = norm(x)
        for block in self.transformer.h:
            x = block(x, cos_sin, kv_cache)
        x = norm(x)
        # Softcap: squash the logits smoothly into (-15, 15)
        softcap = 15
        logits = self.lm_head(x)
        logits = softcap * torch.tanh(logits / softcap)
        return logits

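# A quick smoke test of the model definition alone, with random weights and no
# checkpoint (hypothetical snippet, not executed by this script):
#
#   cfg = GPTConfig(sequence_len=64, vocab_size=512, n_layer=2, n_head=2,
#                   n_kv_head=1, n_embd=64)
#   m = GPT(cfg)
#   m.init_weights()
#   m = m.float()  # __init__ casts wte to bfloat16; make dtypes uniform on CPU
#   m.eval()
#   with torch.inference_mode():
#       logits = m(torch.randint(0, 512, (1, 8)))  # -> shape (1, 8, 512)
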
# -----------------------------------------------------------------------------
# Main script

parser = argparse.ArgumentParser(description='Generate text with the model on CPU')
parser.add_argument('--model-dir', type=str, required=True, help='Path to model directory containing model_*.pt, meta_*.json, and tokenizer.pkl')
parser.add_argument('--prompt', type=str, default='Once upon a time', help='Prompt for generation')
parser.add_argument('--max-tokens', type=int, default=100, help='Maximum number of tokens to generate')
parser.add_argument('-t', '--temperature', type=float, default=0.8, help='Temperature for generation')
parser.add_argument('-k', '--top-k', type=int, default=50, help='Top-k sampling parameter')
args = parser.parse_args()

# Set device to CPU
device = torch.device("cpu")
print(f"Using device: {device}")

# Find the model file and meta file
model_files = glob.glob(os.path.join(args.model_dir, "model_*.pt"))
if not model_files:
    raise FileNotFoundError(f"No model files found in {args.model_dir}")
model_file = model_files[0]
meta_files = glob.glob(os.path.join(args.model_dir, "meta_*.json"))
if not meta_files:
    raise FileNotFoundError(f"No meta files found in {args.model_dir}")
meta_file = meta_files[0]
print(f"Loading model from {model_file}")
print(f"Loading metadata from {meta_file}")

# Load metadata
with open(meta_file, 'r') as f:
    meta = json.load(f)
model_config_kwargs = meta["model_config"]
print(f"Model config: {model_config_kwargs}")

# Build the model
model_config = GPTConfig(**model_config_kwargs)
with torch.device("meta"):
    model = GPT(model_config)
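# Building under torch.device("meta") creates the module structure without
# allocating weight storage; to_empty() below materializes empty tensors on
# the CPU, and load_state_dict(..., assign=True) swaps in the loaded ones.
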
# Load the model weights
print("Loading model weights (this may take a minute for a 2GB model)...")
model_data = torch.load(model_file, map_location=device, weights_only=True)
# Strip the torch.compile wrapper prefix from checkpoint keys. Note that
# removeprefix() is the right call: lstrip("_orig_mod.") would strip leading
# *characters* from that set, not the prefix string.
model_data = {k.removeprefix("_orig_mod."): v for k, v in model_data.items()}
# Convert all bfloat16 weights to float32 for CPU compatibility
print("Converting model to float32 for CPU...")
model_data = {k: v.float() if v.dtype == torch.bfloat16 else v for k, v in model_data.items()}
model.to_empty(device=device)
model.init_weights()
model.load_state_dict(model_data, strict=True, assign=True)
model.eval()
print("Model loaded successfully!")

# Load the tokenizer from the model directory
print("Loading tokenizer...")
tokenizer_path = os.path.join(args.model_dir, "tokenizer.pkl")
if not os.path.exists(tokenizer_path):
    raise FileNotFoundError(f"Tokenizer not found at {tokenizer_path}. Please ensure tokenizer.pkl is in the model directory.")
with open(tokenizer_path, "rb") as f:
    import tiktoken  # must be importable so pickle can reconstruct the Encoding
    enc = pickle.load(f)
print("Tokenizer loaded successfully!")

# Get special token IDs for chat formatting
try:
    try:
        bos_token_id = enc.encode_single_token("<|bos|>")
    except KeyError:
        bos_token_id = enc.encode_single_token("<|endoftext|>")  # fallback
    user_start_id = enc.encode_single_token("<|user_start|>")
    user_end_id = enc.encode_single_token("<|user_end|>")
    assistant_start_id = enc.encode_single_token("<|assistant_start|>")
    assistant_end_id = enc.encode_single_token("<|assistant_end|>")
    stop_tokens = {bos_token_id, assistant_end_id}
except KeyError as e:
    print(f"\nError: A required special token is missing from the tokenizer: {e}")
    print("This script is designed for nanochat models. The tokenizer might be incompatible.")
    exit(1)

# Encode the prompt using the proper chat format
prompt_tokens = enc.encode_ordinary(args.prompt)
input_ids = [bos_token_id, user_start_id] + prompt_tokens + [user_end_id, assistant_start_id]
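# The resulting sequence mirrors nanochat's chat template:
#   <|bos|> <|user_start|> ...prompt tokens... <|user_end|> <|assistant_start|>
# Generation then continues the assistant turn until <|assistant_end|> (or
# <|bos|>) is sampled.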
| print(f"\nPrompt: {args.prompt}") | |
| print(f"Formatted and encoded to {len(input_ids)} tokens") | |
| print() | |
# Generate
print("Generating...")
print("-" * 50)
print(args.prompt, end="", flush=True)
x = torch.tensor([input_ids], dtype=torch.long, device=device)
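# Note: each step below re-runs the full forward pass over the entire sequence
# (kv_cache=None), so total generation cost grows quadratically with length;
# simple, and fine for ~100 tokens on CPU.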
with torch.inference_mode():
    for _ in range(args.max_tokens):
        # Forward pass
        logits = model(x)
        # Get logits for the last token
        logits = logits[:, -1, :]  # (batch_size, vocab_size)
        # Apply temperature
        logits = logits / args.temperature
        # Apply top-k filtering: anything below the k-th largest logit gets
        # -inf, so softmax assigns it zero probability
        if args.top_k > 0:
            v, _ = torch.topk(logits, min(args.top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = -float('Inf')
        # Sample from the distribution
        probs = torch.nn.functional.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        # Stop if we sample a stop token (e.g., <|assistant_end|>); check
        # before printing so the special-token text doesn't leak into output
        if next_token.item() in stop_tokens:
            break
        # Decode and print
        token_str = enc.decode([next_token.item()])
        print(token_str, end="", flush=True)
        # Append to the sequence
        x = torch.cat([x, next_token], dim=1)

print()
print("-" * 50)
print("\nGeneration complete!")
