N8python / pretrain.py
Created January 15, 2025 23:41
Simple character-level pretraining in MLX. Reaches roughly a billion tokens/day for an 18M-parameter model on one M3 Max.
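For scale, the quoted daily throughput can be converted to a per-second rate. This is a minimal arithmetic sketch; the helper name `tokens_per_second` is illustrative and not part of the script, and the one-billion figure is the approximate number from the description, not a measurement.

```python
SECONDS_PER_DAY = 24 * 60 * 60  # 86,400 seconds


def tokens_per_second(tokens_per_day: float) -> float:
    """Convert a daily token throughput into tokens per second."""
    return tokens_per_day / SECONDS_PER_DAY


if __name__ == "__main__":
    # Approximate figure from the description: about one billion tokens/day.
    rate = tokens_per_second(1e9)
    print(f"{rate:,.0f} tokens/sec")
```

At roughly a billion tokens per day, this works out to a little under 12,000 characters processed per second, since the model is character-level and each token is one character.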
import json
import random
import mlx.optimizers as optim
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from tqdm import tqdm
import time
from datetime import datetime