@chottokun
Created December 29, 2025 13:50
MNIST_ultrafast_v0改1.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"gpuType": "T4",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/chottokun/edf6ed823e63285e57c02546575621c9/mnist_ultrafast_v0-1.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "YKzoICGezCO1",
"outputId": "1a664af4-5cb2-4f7f-ae86-1d854fa7890d"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"device: cuda\n",
"gpu: Tesla T4\n",
"\n",
"Loading data...\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.12/dist-packages/torch/__init__.py:1617: UserWarning: Please use the new API settings to control TF32 behavior, such as torch.backends.cudnn.conv.fp32_precision = 'tf32' or torch.backends.cuda.matmul.fp32_precision = 'ieee'. Old settings, e.g, torch.backends.cuda.matmul.allow_tf32 = True, torch.backends.cudnn.allow_tf32 = True, allowTF32CuDNN() and allowTF32CuBLAS() will be deprecated after Pytorch 2.9. Please see https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices (Triggered internally at /pytorch/aten/src/ATen/Context.cpp:80.)\n",
" _C._set_float32_matmul_precision(precision)\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"X_train_T: (60000, 64)\n",
"X_train_S: (60000, 64)\n",
"\n",
"1. Training Strong Teacher (Input=64, Hidden=1024) for 50 Epochs...\n",
"\n",
"2. Training Student (Input=64, Hidden=32) with KL Distillation (150 Epochs)...\n",
"Student trained.\n",
"Memory cleared for benchmark.\n",
"X_bench: (25000000, 64)\n",
"\n",
"--- BENCHMARK ---\n",
"Compile OK.\n",
"Graph OK.\n",
"Running 50 iterations (Batch=25000000)... \n",
"\n",
"Throughput (GPU) : 480,837,544 samples/sec\n",
"Throughput (Wall): 478,374,459 samples/sec\n",
"\n",
"Checking Accuracy (on 8x8 input)...\n",
"Accuracy: 96.79% (Target: ≥96%)\n"
]
}
],
"source": [
"import time\n",
"import copy\n",
"import gc\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torchvision.datasets as datasets\n",
"\n",
"# -----------------------\n",
"# Device / Performance toggles\n",
"# -----------------------\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"print(f\"device: {device}\")\n",
"if device.type == \"cuda\":\n",
" print(f\"gpu: {torch.cuda.get_device_name(0)}\")\n",
" torch.backends.cudnn.benchmark = True\n",
" if hasattr(torch, \"set_float32_matmul_precision\"):\n",
" torch.set_float32_matmul_precision(\"high\")\n",
"\n",
"# -----------------------\n",
"# Constants\n",
"# -----------------------\n",
"MNIST_MEAN = 0.1307\n",
"MNIST_STD = 0.3081\n",
"SEED = 42\n",
"\n",
"# Optimization: Teacher sees 8x8 (64 features)\n",
"TEACHER_INPUT_DIM = 8 * 8\n",
"\n",
"# Optimization: Student sees 8x8 (64 features)\n",
"# 64 features (128 bytes) is perfectly memory aligned.\n",
"# FASTER than 48 features due to cache line alignment.\n",
"STUDENT_INPUT_DIM = 8 * 8\n",
"\n",
"# Optimization: Hidden Dim 32\n",
"# 32 is perfectly aligned to Warp Size (32 threads).\n",
"STUDENT_HIDDEN = 32\n",
"TEACHER_HIDDEN = 1024\n",
"\n",
"torch.manual_seed(SEED)\n",
"if torch.cuda.is_available():\n",
" torch.cuda.manual_seed_all(SEED)\n",
"\n",
"# -----------------------\n",
"# Models\n",
"# -----------------------\n",
"class TeacherMLP(nn.Module):\n",
" def __init__(self, input_dim, hidden_dim):\n",
" super().__init__()\n",
" self.fc1 = nn.Linear(input_dim, hidden_dim, bias=True)\n",
" self.bn1 = nn.BatchNorm1d(hidden_dim)\n",
" self.fc2 = nn.Linear(hidden_dim, 10, bias=True)\n",
"\n",
" def forward(self, x):\n",
" x = x.view(x.size(0), -1)\n",
" x = self.fc1(x)\n",
" x = self.bn1(x)\n",
" x = F.silu(x)\n",
" return self.fc2(x)\n",
"\n",
"class StudentMLP(nn.Module):\n",
" def __init__(self, input_dim, hidden_dim):\n",
" super().__init__()\n",
" self.fc1 = nn.Linear(input_dim, hidden_dim, bias=True)\n",
" self.bn1 = nn.BatchNorm1d(hidden_dim)\n",
" self.fc2 = nn.Linear(hidden_dim, 10, bias=True)\n",
"\n",
" def forward(self, x):\n",
" x = x.view(x.size(0), -1)\n",
" x = self.fc1(x)\n",
" x = self.bn1(x)\n",
" # ReLU is hardware-efficient (max(0,x))\n",
" x = F.relu(x)\n",
" return self.fc2(x)\n",
"\n",
"class InferenceMLP(nn.Module):\n",
" def __init__(self, input_dim, hidden_dim):\n",
" super().__init__()\n",
" self.fc1 = nn.Linear(input_dim, hidden_dim, bias=True)\n",
" self.fc2 = nn.Linear(hidden_dim, 10, bias=True)\n",
"\n",
" def forward(self, x):\n",
" x = x.view(x.size(0), -1)\n",
" x = self.fc1(x)\n",
" x = F.relu(x, inplace=True)\n",
" return self.fc2(x)\n",
"\n",
"def fuse_model(train_model: StudentMLP) -> InferenceMLP:\n",
" fused = InferenceMLP(train_model.fc1.in_features, train_model.fc1.out_features)\n",
" fused.fc2.weight.data = train_model.fc2.weight.data.clone()\n",
" fused.fc2.bias.data = train_model.fc2.bias.data.clone()\n",
"\n",
" w = train_model.fc1.weight.data\n",
" b = train_model.fc1.bias.data\n",
" mean = train_model.bn1.running_mean\n",
" var = train_model.bn1.running_var\n",
" gamma = train_model.bn1.weight.data\n",
" beta = train_model.bn1.bias.data\n",
" eps = train_model.bn1.eps\n",
" scale = gamma / torch.sqrt(var + eps)\n",
" shift = beta - mean * scale\n",
"\n",
" fused.fc1.weight.data = w * scale.unsqueeze(1)\n",
" fused.fc1.bias.data = b * scale + shift\n",
" return fused\n",
"\n",
"# -----------------------\n",
"# Data Processing\n",
"# -----------------------\n",
"def process_data(data: torch.Tensor, h: int, w: int, device: torch.device, use_fp16: bool = False) -> torch.Tensor:\n",
" \"\"\"Process raw data to specific HxW size\"\"\"\n",
" x = data.float() / 255.0\n",
" x = x.unsqueeze(1)\n",
"\n",
" # Center Crop 20x20 first\n",
" x = x[:, :, 4:24, 4:24]\n",
" # Resize to target size\n",
" x = F.interpolate(x, size=(h, w), mode='area')\n",
"\n",
" x = x.view(-1, h * w)\n",
" x = (x - MNIST_MEAN) / MNIST_STD\n",
" if device.type == \"cuda\" and use_fp16:\n",
" x = x.to(device, dtype=torch.float16, non_blocking=True)\n",
" else:\n",
" x = x.to(device, dtype=torch.float32, non_blocking=True)\n",
" return x.contiguous()\n",
"\n",
"class CUDAGraphRunner:\n",
" def __init__(self, model: nn.Module, example_x: torch.Tensor, warmup: int = 20):\n",
" self.model = model.eval()\n",
" self.static_x = example_x.clone()\n",
" with torch.inference_mode():\n",
" for _ in range(warmup):\n",
" _ = self.model(self.static_x)\n",
" torch.cuda.synchronize()\n",
" self.g = torch.cuda.CUDAGraph()\n",
" torch.cuda.synchronize()\n",
" with torch.inference_mode():\n",
" with torch.cuda.graph(self.g):\n",
" self.static_y = self.model(self.static_x)\n",
" torch.cuda.synchronize()\n",
" @torch.inference_mode()\n",
" def replay(self):\n",
" self.g.replay()\n",
" return self.static_y\n",
"\n",
"def cuda_event_time_ms(fn, iters: int) -> float:\n",
" start = torch.cuda.Event(enable_timing=True)\n",
" end = torch.cuda.Event(enable_timing=True)\n",
" torch.cuda.synchronize()\n",
" start.record()\n",
" for _ in range(iters):\n",
" fn()\n",
" end.record()\n",
" end.synchronize()\n",
" return float(start.elapsed_time(end))\n",
"\n",
"def wall_time_sec(fn, iters: int, is_cuda: bool) -> float:\n",
" if is_cuda:\n",
" torch.cuda.synchronize()\n",
" t0 = time.perf_counter()\n",
" for _ in range(iters):\n",
" fn()\n",
" if is_cuda:\n",
" torch.cuda.synchronize()\n",
" return float(time.perf_counter() - t0)\n",
"\n",
"# -----------------------\n",
"# Main\n",
"# -----------------------\n",
"# 1. Clear Memory\n",
"gc.collect()\n",
"torch.cuda.empty_cache()\n",
"\n",
"print(\"\\nLoading data...\")\n",
"train_ds = datasets.MNIST(root='./data', train=True, download=True)\n",
"test_ds = datasets.MNIST(root='./data', train=False)\n",
"\n",
"# Teacher Data: 8x8 (64 features)\n",
"X_train_T = process_data(train_ds.data, 8, 8, device, use_fp16=False)\n",
"# Student Data: 8x8 (64 features)\n",
"X_train_S = process_data(train_ds.data, 8, 8, device, use_fp16=False)\n",
"\n",
"y_train = train_ds.targets.to(device, non_blocking=True)\n",
"\n",
"# Test Data (Student scale)\n",
"X_test_S = process_data(test_ds.data, 8, 8, device, use_fp16=(device.type == \"cuda\"))\n",
"y_test = test_ds.targets.to(device, non_blocking=True)\n",
"\n",
"print(f\"X_train_T: {tuple(X_train_T.shape)}\")\n",
"print(f\"X_train_S: {tuple(X_train_S.shape)}\")\n",
"\n",
"BATCH_SIZE = 4096\n",
"indices = torch.arange(X_train_T.size(0), device=device)\n",
"loss_fn_ce = nn.CrossEntropyLoss()\n",
"loss_fn_kl = nn.KLDivLoss(reduction='batchmean')\n",
"\n",
"# 1. Teacher (8x8)\n",
"print(\"\\n1. Training Strong Teacher (Input=64, Hidden=1024) for 50 Epochs...\")\n",
"teacher = TeacherMLP(TEACHER_INPUT_DIM, TEACHER_HIDDEN).to(device)\n",
"opt_teacher = torch.optim.Adam(teacher.parameters(), lr=0.01)\n",
"teacher.train()\n",
"for epoch in range(50):\n",
" perm = indices[torch.randperm(indices.size(0))]\n",
" for i in range(0, X_train_T.size(0), BATCH_SIZE):\n",
" idx = perm[i : i + BATCH_SIZE]\n",
" data = X_train_T[idx]\n",
" target = y_train[idx]\n",
" opt_teacher.zero_grad(set_to_none=True)\n",
" logits = teacher(data)\n",
" loss = loss_fn_ce(logits, target)\n",
" loss.backward()\n",
" opt_teacher.step()\n",
"teacher.eval()\n",
"\n",
"# 2. Student (8x8) - Distillation\n",
"print(\"\\n2. Training Student (Input=64, Hidden=32) with KL Distillation (150 Epochs)...\")\n",
"student = StudentMLP(STUDENT_INPUT_DIM, STUDENT_HIDDEN).to(device)\n",
"opt_student = torch.optim.Adam(student.parameters(), lr=0.01)\n",
"scheduler = torch.optim.lr_scheduler.OneCycleLR(\n",
" opt_student,\n",
" max_lr=0.01,\n",
" steps_per_epoch=len(X_train_T)//BATCH_SIZE+1,\n",
" epochs=150\n",
")\n",
"student.train()\n",
"\n",
"T = 4.0\n",
"ALPHA = 0.5\n",
"\n",
"for epoch in range(150):\n",
" perm = indices[torch.randperm(indices.size(0))]\n",
" for i in range(0, X_train_T.size(0), BATCH_SIZE):\n",
" idx = perm[i : i + BATCH_SIZE]\n",
"\n",
" data_t = X_train_T[idx]\n",
" data_s = X_train_S[idx]\n",
" target = y_train[idx]\n",
"\n",
" with torch.no_grad():\n",
" teacher_logits = teacher(data_t)\n",
"\n",
" opt_student.zero_grad(set_to_none=True)\n",
" student_logits = student(data_s)\n",
"\n",
" # KL Loss Distillation (Proven better than MSE for accuracy here)\n",
" loss_ce = loss_fn_ce(student_logits, target)\n",
" loss_kl = loss_fn_kl(F.log_softmax(student_logits/T, dim=1), F.softmax(teacher_logits/T, dim=1)) * (T*T)\n",
"\n",
" loss = (1. - ALPHA) * loss_ce + ALPHA * loss_kl\n",
"\n",
" loss.backward()\n",
" opt_student.step()\n",
" scheduler.step()\n",
"\n",
"print(\"Student trained.\")\n",
"\n",
"# Fuse\n",
"model = fuse_model(student).to(device)\n",
"model.half().eval()\n",
"\n",
"# Cleanup (CRITICAL)\n",
"del teacher, opt_teacher, opt_student, student, loss_fn_ce, loss_fn_kl, indices, perm\n",
"del X_train_T, X_train_S, y_train\n",
"gc.collect()\n",
"torch.cuda.empty_cache()\n",
"print(\"Memory cleared for benchmark.\")\n",
"\n",
"# 25M Saturation (Safe Batch Size)\n",
"BENCH_MULTIPLIER = 2500\n",
"# Use Half directly to save memory\n",
"X_bench = X_test_S.half().repeat(BENCH_MULTIPLIER, 1)\n",
"print(f\"X_bench: {tuple(X_bench.shape)}\")\n",
"\n",
"# Benchmark\n",
"print(\"\\n--- BENCHMARK ---\")\n",
"compiled_model = None\n",
"try:\n",
" # max-autotune-no-cudagraphs helps fusion\n",
" compiled_model = torch.compile(model, mode=\"max-autotune-no-cudagraphs\", fullgraph=True)\n",
" with torch.inference_mode():\n",
" _ = compiled_model(X_bench)\n",
" torch.cuda.synchronize()\n",
" print(\"Compile OK.\")\n",
"except:\n",
" compiled_model = None\n",
"\n",
"inference_model = compiled_model if compiled_model is not None else model\n",
"\n",
"runner = None\n",
"try:\n",
" runner = CUDAGraphRunner(inference_model, X_bench, warmup=10)\n",
" print(\"Graph OK.\")\n",
"except:\n",
" pass\n",
"\n",
"num_runs = 50\n",
"total_samples = X_bench.size(0) * num_runs\n",
"print(f\"Running {num_runs} iterations (Batch={X_bench.size(0)})... \")\n",
"\n",
"target_fn = runner.replay if runner else lambda: inference_model(X_bench)\n",
"for _ in range(5):\n",
" target_fn()\n",
"torch.cuda.synchronize()\n",
"\n",
"gpu_s = cuda_event_time_ms(target_fn, num_runs) / 1000.0\n",
"wall_s = wall_time_sec(target_fn, num_runs, True)\n",
"\n",
"print(f\"\\nThroughput (GPU) : {total_samples / gpu_s:,.0f} samples/sec\")\n",
"print(f\"Throughput (Wall): {total_samples / wall_s:,.0f} samples/sec\")\n",
"\n",
"print(\"\\nChecking Accuracy (on 8x8 input)...\")\n",
"with torch.inference_mode():\n",
" acc = (model(X_test_S.half()).argmax(dim=1) == y_test).float().mean() * 100\n",
"print(f\"Accuracy: {acc:.2f}% (Target: ≥96%)\")\n"
]
}
]
}