Created
December 29, 2025 14:11
-
-
Save chottokun/31c82ce969054dcb56ba6b743e674e8f to your computer and use it in GitHub Desktop.
mnist_ultrafast_v0-2.ipynb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "nbformat": 4, | |
| "nbformat_minor": 0, | |
| "metadata": { | |
| "colab": { | |
| "provenance": [], | |
| "gpuType": "T4", | |
| "authorship_tag": "ABX9TyOK2sun+GEk145cVaBk/FW4", | |
| "include_colab_link": true | |
| }, | |
| "kernelspec": { | |
| "name": "python3", | |
| "display_name": "Python 3" | |
| }, | |
| "language_info": { | |
| "name": "python" | |
| }, | |
| "accelerator": "GPU" | |
| }, | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "view-in-github", | |
| "colab_type": "text" | |
| }, | |
| "source": [ | |
| "<a href=\"https://colab.research.google.com/gist/chottokun/31c82ce969054dcb56ba6b743e674e8f/mnist_ultrafast_v0-2.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "VQ954pdocFbu", | |
| "outputId": "39125f6a-cdbb-403c-c277-663784c43260" | |
| }, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "device: cuda\n", | |
| "gpu: Tesla T4\n", | |
| "\n", | |
| "Loading data...\n", | |
| "X_bench: (5000000, 64)\n", | |
| "\n", | |
| "1. Training Teacher...\n", | |
| "\n", | |
| "2. Training Student (MSE Distillation, 200 Epochs)...\n", | |
| "Student trained.\n", | |
| "\n", | |
| "--- BENCHMARK ---\n", | |
| "Compile OK.\n", | |
| "Warming up...\n", | |
| "Running 50 iterations...\n" | |
| ] | |
| }, | |
| { | |
| "output_type": "stream", | |
| "name": "stderr", | |
| "text": [ | |
| "/usr/local/lib/python3.12/dist-packages/torch/_inductor/cudagraph_trees.py:2450: UserWarning: Unable to hit fast path of CUDAGraphs because of pending, uninvoked backwards. Consider running with torch.no_grad() or using torch.compiler.cudagraph_mark_step_begin() before each model invocation\n", | |
| " warnings.warn(\n" | |
| ] | |
| }, | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "\n", | |
| "Throughput (Wall): 238,384,478 samples/sec\n", | |
| "Throughput (GPU) : 614,920,146 samples/sec\n", | |
| "\n", | |
| "Checking Accuracy...\n", | |
| "Accuracy: 91.78% (Target: ≥96%)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "import time\n", | |
| "import copy\n", | |
| "import gc\n", | |
| "import torch\n", | |
| "import torch.nn as nn\n", | |
| "import torch.nn.functional as F\n", | |
| "import torchvision.datasets as datasets\n", | |
# -----------------------
# Device / Performance toggles
# -----------------------
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"device: {device}")
if device.type == "cuda":
    print(f"gpu: {torch.cuda.get_device_name(0)}")
    # Let cuDNN pick the fastest algorithms for fixed input shapes.
    torch.backends.cudnn.benchmark = True
    # Allow TF32 matmuls where supported (guarded for older torch builds).
    _set_matmul_prec = getattr(torch, "set_float32_matmul_precision", None)
    if _set_matmul_prec is not None:
        _set_matmul_prec("high")

# -----------------------
# Constants
# -----------------------
MNIST_MEAN = 0.1307  # canonical MNIST normalization statistics
MNIST_STD = 0.3081
SEED = 42
INPUT_DIM = 8 * 8    # images are downscaled to 8x8 before the MLP
STUDENT_DIM = 16
TEACHER_DIM = 256
LEAKY_SLOPE = 0.1

torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
| "\n", | |
| "# -----------------------\n", | |
| "# Models\n", | |
| "# -----------------------\n", | |
class MLP(nn.Module):
    """Training-time MLP: Linear -> BatchNorm1d -> LeakyReLU -> Linear (10 logits)."""

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim, bias=True)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 10, bias=True)

    def forward(self, x):
        # Flatten any trailing dims to a feature vector.
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        x = self.bn1(x)
        x = F.leaky_relu(x, negative_slope=LEAKY_SLOPE)
        return self.fc2(x)


class InferenceMLP(nn.Module):
    """Inference-time MLP with BatchNorm already folded into fc1 (see fuse_model)."""

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim, bias=True)
        self.fc2 = nn.Linear(hidden_dim, 10, bias=True)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        x = F.leaky_relu(x, negative_slope=LEAKY_SLOPE, inplace=True)
        return self.fc2(x)


def fuse_model(train_model: MLP) -> InferenceMLP:
    """Fold train_model's BatchNorm1d into its first Linear layer.

    In eval mode, BN computes y = gamma * (z - mean) / sqrt(var + eps) + beta
    = scale * z + shift with scale = gamma / sqrt(var + eps) and
    shift = beta - mean * scale. Therefore fc1' = (scale * W, scale * b + shift)
    reproduces BN(fc1(x)) exactly using the running statistics.

    Returns a new InferenceMLP on the same device as train_model's parameters.
    """
    fused = InferenceMLP(train_model.fc1.in_features, train_model.fc1.out_features)
    fused = fused.to(train_model.fc1.weight.device)
    # Plain tensor copies under no_grad instead of rebinding `.data`,
    # which bypasses autograd bookkeeping and is discouraged.
    with torch.no_grad():
        fused.fc2.weight.copy_(train_model.fc2.weight)
        fused.fc2.bias.copy_(train_model.fc2.bias)

        bn = train_model.bn1
        scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
        shift = bn.bias - bn.running_mean * scale

        fused.fc1.weight.copy_(train_model.fc1.weight * scale.unsqueeze(1))
        fused.fc1.bias.copy_(train_model.fc1.bias * scale + shift)
    return fused
| "\n", | |
| "def process_data(data: torch.Tensor, device: torch.device, use_fp16: bool = False) -> torch.Tensor:\n", | |
| " x = data.float() / 255.0\n", | |
| " x = x.unsqueeze(1)\n", | |
| " x = F.interpolate(x, size=(8, 8), mode='area')\n", | |
| " x = x.view(-1, INPUT_DIM)\n", | |
| " x = (x - MNIST_MEAN) / MNIST_STD\n", | |
| " if device.type == \"cuda\" and use_fp16:\n", | |
| " x = x.to(device, dtype=torch.float16, non_blocking=True)\n", | |
| " else:\n", | |
| " x = x.to(device, dtype=torch.float32, non_blocking=True)\n", | |
| " return x.contiguous()\n", | |
| "\n", | |
def cuda_event_time_ms(fn, iters: int) -> float:
    """Return GPU-side elapsed time in milliseconds for `iters` calls of `fn`.

    Uses CUDA events so only device work is measured; synchronizes before
    recording so previously queued kernels are excluded.
    """
    begin = torch.cuda.Event(enable_timing=True)
    finish = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    begin.record()
    for _ in range(iters):
        fn()
    finish.record()
    finish.synchronize()
    return float(begin.elapsed_time(finish))
| "\n", | |
def wall_time_sec(fn, iters: int, is_cuda: bool) -> float:
    """Return host-side wall time in seconds for `iters` calls of `fn`.

    When `is_cuda` is True the device is synchronized before and after, so
    the measurement includes all queued GPU work.
    """
    def _maybe_sync():
        if is_cuda:
            torch.cuda.synchronize()

    _maybe_sync()
    t_begin = time.perf_counter()
    done = 0
    while done < iters:
        fn()
        done += 1
    _maybe_sync()
    return float(time.perf_counter() - t_begin)
| "\n", | |
# -----------------------
# Main
# -----------------------
print("\nLoading data...")
train_ds = datasets.MNIST(root='./data', train=True, download=True)
test_ds = datasets.MNIST(root='./data', train=False)

# Preprocess once on-device; test set goes fp16 on GPU for the benchmark.
gpu_fp16 = device.type == "cuda"
X_train = process_data(train_ds.data, device, use_fp16=False)
y_train = train_ds.targets.to(device, non_blocking=True)
X_test = process_data(test_ds.data, device, use_fp16=gpu_fp16)
y_test = test_ds.targets.to(device, non_blocking=True)

# Benchmark set: tile the test set to saturate the GPU.
# NOTE: a multiplier of 2500 (~25M samples) exhausted memory in this
# environment, so it is held at a stable 500 (5M samples).
BENCH_MULTIPLIER = 500
X_bench = X_test.repeat(BENCH_MULTIPLIER, 1)

print(f"X_bench: {tuple(X_bench.shape)}")

BATCH_SIZE = 4096
indices = torch.arange(X_train.size(0), device=device)
loss_fn_ce = nn.CrossEntropyLoss()
loss_fn_mse = nn.MSELoss()
| "\n", | |
# 1. Teacher
print("\n1. Training Teacher...")
teacher = MLP(INPUT_DIM, TEACHER_DIM).to(device)
opt_teacher = torch.optim.Adam(teacher.parameters(), lr=0.01)
teacher.train()
# 50 epochs of plain cross-entropy on shuffled full-batch permutations.
for _ in range(50):
    perm = indices[torch.randperm(indices.size(0))]
    for start in range(0, X_train.size(0), BATCH_SIZE):
        batch = perm[start : start + BATCH_SIZE]
        opt_teacher.zero_grad(set_to_none=True)
        loss = loss_fn_ce(teacher(X_train[batch]), y_train[batch])
        loss.backward()
        opt_teacher.step()
teacher.eval()
| "\n", | |
# 2. Student
print("\n2. Training Student (MSE Distillation, 200 Epochs)...")
student = MLP(INPUT_DIM, STUDENT_DIM).to(device)
opt_student = torch.optim.Adam(student.parameters(), lr=0.01)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt_student, T_max=200)
student.train()

for _ in range(200):
    perm = indices[torch.randperm(indices.size(0))]
    for start in range(0, X_train.size(0), BATCH_SIZE):
        batch = perm[start : start + BATCH_SIZE]
        data, target = X_train[batch], y_train[batch]

        # Teacher provides soft targets; no gradients needed through it.
        with torch.no_grad():
            teacher_logits = teacher(data)

        opt_student.zero_grad(set_to_none=True)
        student_logits = student(data)

        # Equal-weight blend of hard-label CE and teacher-matching MSE.
        loss = (
            0.5 * loss_fn_ce(student_logits, target)
            + 0.5 * loss_fn_mse(student_logits, teacher_logits)
        )
        loss.backward()
        opt_student.step()
    # Cosine schedule stepped once per epoch (T_max=200 matches 200 epochs).
    scheduler.step()

print("Student trained.")
| "\n", | |
# Fuse BatchNorm into fc1 for inference; fp16 weights on GPU.
model = fuse_model(student).to(device)
if device.type == "cuda":
    model.half()
model.eval()

# Release training-only objects before the large benchmark allocations.
del teacher, opt_teacher, opt_student, student, loss_fn_ce, loss_fn_mse, indices, perm
gc.collect()
if device.type == "cuda":
    torch.cuda.empty_cache()
| "\n", | |
# Benchmark
print("\n--- BENCHMARK ---")
inference_model = model
if device.type == "cuda":
    try:
        # "max-autotune" handles CUDA graphs / kernel autotuning internally.
        inference_model = torch.compile(model, mode="max-autotune", fullgraph=True)
        print("Compile OK.")
    except Exception as e:
        print(f"torch.compile failed: {e}")

# Warmup (also triggers compilation on the first call).
print("Warming up...")
with torch.inference_mode():
    for _ in range(10):
        _ = inference_model(X_bench)
if device.type == "cuda":
    torch.cuda.synchronize()

num_runs = 50
total_samples = X_bench.size(0) * num_runs
print(f"Running {num_runs} iterations...")

# BUGFIX: the timed calls previously ran WITHOUT inference_mode/no_grad
# (only the warmup had it), so every iteration built an autograd graph.
# That inflates the wall time and defeats the CUDA-graphs fast path — see
# the "pending, uninvoked backwards ... Consider running with
# torch.no_grad()" UserWarning emitted during the run.
def target_fn():
    with torch.inference_mode():
        return inference_model(X_bench)

wall_s = wall_time_sec(target_fn, num_runs, device.type == "cuda")
print(f"\nThroughput (Wall): {total_samples / wall_s:,.0f} samples/sec")

if device.type == "cuda":
    gpu_s = cuda_event_time_ms(target_fn, num_runs) / 1000.0
    print(f"Throughput (GPU) : {total_samples / gpu_s:,.0f} samples/sec")

print("\nChecking Accuracy...")
with torch.inference_mode():
    acc = (model(X_test).argmax(dim=1) == y_test).float().mean() * 100
print(f"Accuracy: {acc:.2f}% (Target: ≥96%)")
| ] | |
| } | |
| ] | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment