Skip to content

Instantly share code, notes, and snippets.

@chottokun
Created December 29, 2025 14:11
Show Gist options
  • Select an option

  • Save chottokun/31c82ce969054dcb56ba6b743e674e8f to your computer and use it in GitHub Desktop.

Select an option

Save chottokun/31c82ce969054dcb56ba6b743e674e8f to your computer and use it in GitHub Desktop.
mnist_ultrafast_v0-2.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"gpuType": "T4",
"authorship_tag": "ABX9TyOK2sun+GEk145cVaBk/FW4",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/chottokun/31c82ce969054dcb56ba6b743e674e8f/mnist_ultrafast_v0-2.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "VQ954pdocFbu",
"outputId": "39125f6a-cdbb-403c-c277-663784c43260"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"device: cuda\n",
"gpu: Tesla T4\n",
"\n",
"Loading data...\n",
"X_bench: (5000000, 64)\n",
"\n",
"1. Training Teacher...\n",
"\n",
"2. Training Student (MSE Distillation, 200 Epochs)...\n",
"Student trained.\n",
"\n",
"--- BENCHMARK ---\n",
"Compile OK.\n",
"Warming up...\n",
"Running 50 iterations...\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.12/dist-packages/torch/_inductor/cudagraph_trees.py:2450: UserWarning: Unable to hit fast path of CUDAGraphs because of pending, uninvoked backwards. Consider running with torch.no_grad() or using torch.compiler.cudagraph_mark_step_begin() before each model invocation\n",
" warnings.warn(\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"Throughput (Wall): 238,384,478 samples/sec\n",
"Throughput (GPU) : 614,920,146 samples/sec\n",
"\n",
"Checking Accuracy...\n",
"Accuracy: 91.78% (Target: ≥96%)\n"
]
}
],
"source": [
"import time\n",
"import copy\n",
"import gc\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torchvision.datasets as datasets\n",
"# -----------------------\n",
"# Device / Performance toggles\n",
"# -----------------------\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"print(f\"device: {device}\")\n",
"if device.type == \"cuda\":\n",
" print(f\"gpu: {torch.cuda.get_device_name(0)}\")\n",
" torch.backends.cudnn.benchmark = True\n",
" if hasattr(torch, \"set_float32_matmul_precision\"):\n",
" torch.set_float32_matmul_precision(\"high\")\n",
"\n",
"# -----------------------\n",
"# Constants\n",
"# -----------------------\n",
"MNIST_MEAN = 0.1307\n",
"MNIST_STD = 0.3081\n",
"SEED = 42\n",
"INPUT_DIM = 8 * 8\n",
"STUDENT_DIM = 16\n",
"TEACHER_DIM = 256\n",
"LEAKY_SLOPE = 0.1\n",
"\n",
"torch.manual_seed(SEED)\n",
"if torch.cuda.is_available():\n",
" torch.cuda.manual_seed_all(SEED)\n",
"\n",
"# -----------------------\n",
"# Models\n",
"# -----------------------\n",
"class MLP(nn.Module):\n",
" def __init__(self, input_dim, hidden_dim):\n",
" super().__init__()\n",
" self.fc1 = nn.Linear(input_dim, hidden_dim, bias=True)\n",
" self.bn1 = nn.BatchNorm1d(hidden_dim)\n",
" self.fc2 = nn.Linear(hidden_dim, 10, bias=True)\n",
"\n",
" def forward(self, x):\n",
" x = x.view(x.size(0), -1)\n",
" x = self.fc1(x)\n",
" x = self.bn1(x)\n",
" x = F.leaky_relu(x, negative_slope=LEAKY_SLOPE)\n",
" return self.fc2(x)\n",
"\n",
"class InferenceMLP(nn.Module):\n",
" def __init__(self, input_dim, hidden_dim):\n",
" super().__init__()\n",
" self.fc1 = nn.Linear(input_dim, hidden_dim, bias=True)\n",
" self.fc2 = nn.Linear(hidden_dim, 10, bias=True)\n",
"\n",
" def forward(self, x):\n",
" x = x.view(x.size(0), -1)\n",
" x = self.fc1(x)\n",
" x = F.leaky_relu(x, negative_slope=LEAKY_SLOPE, inplace=True)\n",
" return self.fc2(x)\n",
"\n",
"def fuse_model(train_model: MLP) -> InferenceMLP:\n",
" fused = InferenceMLP(train_model.fc1.in_features, train_model.fc1.out_features)\n",
" fused.fc2.weight.data = train_model.fc2.weight.data.clone()\n",
" fused.fc2.bias.data = train_model.fc2.bias.data.clone()\n",
"\n",
" w = train_model.fc1.weight.data\n",
" b = train_model.fc1.bias.data\n",
" mean = train_model.bn1.running_mean\n",
" var = train_model.bn1.running_var\n",
" gamma = train_model.bn1.weight.data\n",
" beta = train_model.bn1.bias.data\n",
" eps = train_model.bn1.eps\n",
" scale = gamma / torch.sqrt(var + eps)\n",
" shift = beta - mean * scale\n",
"\n",
" fused.fc1.weight.data = w * scale.unsqueeze(1)\n",
" fused.fc1.bias.data = b * scale + shift\n",
" return fused\n",
"\n",
"def process_data(data: torch.Tensor, device: torch.device, use_fp16: bool = False) -> torch.Tensor:\n",
" x = data.float() / 255.0\n",
" x = x.unsqueeze(1)\n",
" x = F.interpolate(x, size=(8, 8), mode='area')\n",
" x = x.view(-1, INPUT_DIM)\n",
" x = (x - MNIST_MEAN) / MNIST_STD\n",
" if device.type == \"cuda\" and use_fp16:\n",
" x = x.to(device, dtype=torch.float16, non_blocking=True)\n",
" else:\n",
" x = x.to(device, dtype=torch.float32, non_blocking=True)\n",
" return x.contiguous()\n",
"\n",
"def cuda_event_time_ms(fn, iters: int) -> float:\n",
" start = torch.cuda.Event(enable_timing=True)\n",
" end = torch.cuda.Event(enable_timing=True)\n",
" torch.cuda.synchronize()\n",
" start.record()\n",
" for _ in range(iters):\n",
" fn()\n",
" end.record()\n",
" end.synchronize()\n",
" return float(start.elapsed_time(end))\n",
"\n",
"def wall_time_sec(fn, iters: int, is_cuda: bool) -> float:\n",
" if is_cuda:\n",
" torch.cuda.synchronize()\n",
" t0 = time.perf_counter()\n",
" for _ in range(iters):\n",
" fn()\n",
" if is_cuda:\n",
" torch.cuda.synchronize()\n",
" return float(time.perf_counter() - t0)\n",
"\n",
"# -----------------------\n",
"# Main\n",
"# -----------------------\n",
"print(\"\\nLoading data...\")\n",
"train_ds = datasets.MNIST(root='./data', train=True, download=True)\n",
"test_ds = datasets.MNIST(root='./data', train=False)\n",
"\n",
"X_train = process_data(train_ds.data, device, use_fp16=False)\n",
"y_train = train_ds.targets.to(device, non_blocking=True)\n",
"X_test = process_data(test_ds.data, device, use_fp16=(device.type == \"cuda\"))\n",
"y_test = test_ds.targets.to(device, non_blocking=True)\n",
"\n",
"# 25M Saturation\n",
"# NOTE: The original BENCH_MULTIPLIER of 2500 caused out-of-memory\n",
"# errors in the execution environment. It has been reduced to a stable value.\n",
"BENCH_MULTIPLIER = 500\n",
"X_bench = X_test.repeat(BENCH_MULTIPLIER, 1)\n",
"\n",
"print(f\"X_bench: {tuple(X_bench.shape)}\")\n",
"\n",
"BATCH_SIZE = 4096\n",
"indices = torch.arange(X_train.size(0), device=device)\n",
"loss_fn_ce = nn.CrossEntropyLoss()\n",
"loss_fn_mse = nn.MSELoss()\n",
"\n",
"# 1. Teacher\n",
"print(\"\\n1. Training Teacher...\")\n",
"teacher = MLP(INPUT_DIM, TEACHER_DIM).to(device)\n",
"opt_teacher = torch.optim.Adam(teacher.parameters(), lr=0.01)\n",
"teacher.train()\n",
"for epoch in range(50):\n",
" perm = indices[torch.randperm(indices.size(0))]\n",
" for i in range(0, X_train.size(0), BATCH_SIZE):\n",
" idx = perm[i : i + BATCH_SIZE]\n",
" data = X_train[idx]\n",
" target = y_train[idx]\n",
" opt_teacher.zero_grad(set_to_none=True)\n",
" logits = teacher(data)\n",
" loss = loss_fn_ce(logits, target)\n",
" loss.backward()\n",
" opt_teacher.step()\n",
"teacher.eval()\n",
"\n",
"# 2. Student\n",
"print(\"\\n2. Training Student (MSE Distillation, 200 Epochs)...\")\n",
"student = MLP(INPUT_DIM, STUDENT_DIM).to(device)\n",
"opt_student = torch.optim.Adam(student.parameters(), lr=0.01)\n",
"scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt_student, T_max=200)\n",
"student.train()\n",
"\n",
"for epoch in range(200):\n",
" perm = indices[torch.randperm(indices.size(0))]\n",
" for i in range(0, X_train.size(0), BATCH_SIZE):\n",
" idx = perm[i : i + BATCH_SIZE]\n",
" data = X_train[idx]\n",
" target = y_train[idx]\n",
"\n",
" with torch.no_grad():\n",
" teacher_logits = teacher(data)\n",
"\n",
" opt_student.zero_grad(set_to_none=True)\n",
" student_logits = student(data)\n",
"\n",
" loss_ce = loss_fn_ce(student_logits, target)\n",
" loss_mse = loss_fn_mse(student_logits, teacher_logits)\n",
" loss = 0.5 * loss_ce + 0.5 * loss_mse\n",
"\n",
" loss.backward()\n",
" opt_student.step()\n",
" scheduler.step()\n",
"\n",
"print(\"Student trained.\")\n",
"\n",
"# Fuse\n",
"model = fuse_model(student).to(device)\n",
"if device.type == \"cuda\":\n",
" model.half()\n",
"model.eval()\n",
"\n",
"# Cleanup\n",
"del teacher, opt_teacher, opt_student, student, loss_fn_ce, loss_fn_mse, indices, perm\n",
"gc.collect()\n",
"if device.type == \"cuda\":\n",
" torch.cuda.empty_cache()\n",
"\n",
"# Benchmark\n",
"print(\"\\n--- BENCHMARK ---\")\n",
"inference_model = model\n",
"if device.type == \"cuda\":\n",
"    try:\n",
"        # \"max-autotune\" lets Inductor handle CUDA-graph capture internally.\n",
"        inference_model = torch.compile(model, mode=\"max-autotune\", fullgraph=True)\n",
"        print(\"Compile OK.\")\n",
"    except Exception as e:\n",
"        # Fall back to the eager model so the benchmark still runs.\n",
"        print(f\"torch.compile failed: {e}\")\n",
"\n",
"# FIX: the timed iterations must also run under inference_mode. Previously\n",
"# only the warmup was wrapped, so every timed call recorded autograd state\n",
"# and hit the cudagraphs 'pending, uninvoked backwards' slow path (see the\n",
"# UserWarning in the cell output), distorting both throughput numbers.\n",
"def target_fn():\n",
"    \"\"\"One benchmark step: a full forward pass over X_bench, no autograd.\"\"\"\n",
"    with torch.inference_mode():\n",
"        inference_model(X_bench)\n",
"\n",
"# Warmup: trigger compilation / autotuning outside the timed region.\n",
"print(\"Warming up...\")\n",
"for _ in range(10):\n",
"    target_fn()\n",
"if device.type == \"cuda\":\n",
"    torch.cuda.synchronize()\n",
"\n",
"num_runs = 50\n",
"total_samples = X_bench.size(0) * num_runs\n",
"print(f\"Running {num_runs} iterations...\")\n",
"\n",
"# Host wall-clock throughput (includes launch overhead).\n",
"wall_s = wall_time_sec(target_fn, num_runs, device.type == \"cuda\")\n",
"print(f\"\\nThroughput (Wall): {total_samples / wall_s:,.0f} samples/sec\")\n",
"\n",
"# Device-side throughput via CUDA events (excludes most host overhead).\n",
"if device.type == \"cuda\":\n",
"    gpu_s = cuda_event_time_ms(target_fn, num_runs) / 1000.0\n",
"    print(f\"Throughput (GPU) : {total_samples / gpu_s:,.0f} samples/sec\")\n",
"\n",
"# Accuracy check on the (unrepeated) test set with the uncompiled model.\n",
"print(\"\\nChecking Accuracy...\")\n",
"with torch.inference_mode():\n",
"    acc = (model(X_test).argmax(dim=1) == y_test).float().mean() * 100\n",
"print(f\"Accuracy: {acc:.2f}% (Target: ≥96%)\")\n"
]
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment