Created
December 29, 2025 13:50
-
-
Save chottokun/edf6ed823e63285e57c02546575621c9 to your computer and use it in GitHub Desktop.
MNIST_ultrafast_v0改1.ipynb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "nbformat": 4, | |
| "nbformat_minor": 0, | |
| "metadata": { | |
| "colab": { | |
| "provenance": [], | |
| "gpuType": "T4", | |
| "include_colab_link": true | |
| }, | |
| "kernelspec": { | |
| "name": "python3", | |
| "display_name": "Python 3" | |
| }, | |
| "language_info": { | |
| "name": "python" | |
| }, | |
| "accelerator": "GPU" | |
| }, | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "view-in-github", | |
| "colab_type": "text" | |
| }, | |
| "source": [ | |
| "<a href=\"https://colab.research.google.com/gist/chottokun/edf6ed823e63285e57c02546575621c9/mnist_ultrafast_v0-1.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "YKzoICGezCO1", | |
| "outputId": "1a664af4-5cb2-4f7f-ae86-1d854fa7890d" | |
| }, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "device: cuda\n", | |
| "gpu: Tesla T4\n", | |
| "\n", | |
| "Loading data...\n" | |
| ] | |
| }, | |
| { | |
| "output_type": "stream", | |
| "name": "stderr", | |
| "text": [ | |
| "/usr/local/lib/python3.12/dist-packages/torch/__init__.py:1617: UserWarning: Please use the new API settings to control TF32 behavior, such as torch.backends.cudnn.conv.fp32_precision = 'tf32' or torch.backends.cuda.matmul.fp32_precision = 'ieee'. Old settings, e.g, torch.backends.cuda.matmul.allow_tf32 = True, torch.backends.cudnn.allow_tf32 = True, allowTF32CuDNN() and allowTF32CuBLAS() will be deprecated after Pytorch 2.9. Please see https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices (Triggered internally at /pytorch/aten/src/ATen/Context.cpp:80.)\n", | |
| " _C._set_float32_matmul_precision(precision)\n" | |
| ] | |
| }, | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "X_train_T: (60000, 64)\n", | |
| "X_train_S: (60000, 64)\n", | |
| "\n", | |
| "1. Training Strong Teacher (Input=64, Hidden=1024) for 50 Epochs...\n", | |
| "\n", | |
| "2. Training Student (Input=64, Hidden=32) with KL Distillation (150 Epochs)...\n", | |
| "Student trained.\n", | |
| "Memory cleared for benchmark.\n", | |
| "X_bench: (25000000, 64)\n", | |
| "\n", | |
| "--- BENCHMARK ---\n", | |
| "Compile OK.\n", | |
| "Graph OK.\n", | |
| "Running 50 iterations (Batch=25000000)... \n", | |
| "\n", | |
| "Throughput (GPU) : 480,837,544 samples/sec\n", | |
| "Throughput (Wall): 478,374,459 samples/sec\n", | |
| "\n", | |
| "Checking Accuracy (on 8x8 input)...\n", | |
| "Accuracy: 96.79% (Target: ≥96%)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "import time\n", | |
| "import copy\n", | |
| "import gc\n", | |
| "import torch\n", | |
| "import torch.nn as nn\n", | |
| "import torch.nn.functional as F\n", | |
| "import torchvision.datasets as datasets\n", | |
| "\n", | |
| "# -----------------------\n", | |
| "# Device / Performance toggles\n", | |
| "# -----------------------\n", | |
| "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", | |
| "print(f\"device: {device}\")\n", | |
| "if device.type == \"cuda\":\n", | |
| " print(f\"gpu: {torch.cuda.get_device_name(0)}\")\n", | |
| " # Let cuDNN autotune kernel choices for the fixed input shapes used below.\n", | |
| " torch.backends.cudnn.benchmark = True\n", | |
| " if hasattr(torch, \"set_float32_matmul_precision\"):\n", | |
| " # \"high\" permits TF32 matmuls on Ampere+ GPUs (speed over exact fp32).\n", | |
| " torch.set_float32_matmul_precision(\"high\")\n", | |
| "\n", | |
| "# -----------------------\n", | |
| "# Constants\n", | |
| "# -----------------------\n", | |
| "# Standard per-pixel normalization statistics for 28x28 MNIST.\n", | |
| "MNIST_MEAN = 0.1307\n", | |
| "MNIST_STD = 0.3081\n", | |
| "SEED = 42\n", | |
| "\n", | |
| "# Optimization: Teacher sees 8x8 (64 features)\n", | |
| "TEACHER_INPUT_DIM = 8 * 8\n", | |
| "\n", | |
| "# Optimization: Student sees 8x8 (64 features)\n", | |
| "# 64 features (128 bytes) is perfectly memory aligned.\n", | |
| "# FASTER than 48 features due to cache line alignment.\n", | |
| "STUDENT_INPUT_DIM = 8 * 8\n", | |
| "\n", | |
| "# Optimization: Hidden Dim 32\n", | |
| "# 32 is perfectly aligned to Warp Size (32 threads).\n", | |
| "STUDENT_HIDDEN = 32\n", | |
| "TEACHER_HIDDEN = 1024\n", | |
| "\n", | |
| "# Seed CPU and (all) CUDA RNGs for reproducibility.\n", | |
| "torch.manual_seed(SEED)\n", | |
| "if torch.cuda.is_available():\n", | |
| " torch.cuda.manual_seed_all(SEED)\n", | |
| "\n", | |
| "# -----------------------\n", | |
| "# Models\n", | |
| "# -----------------------\n", | |
| "class TeacherMLP(nn.Module):\n", | |
| " \"\"\"Large teacher network: Linear -> BatchNorm1d -> SiLU -> Linear (10 logits).\n", | |
| "\n", | |
| " Used only during training to produce soft targets for distillation;\n", | |
| " it is deleted before the benchmark phase.\n", | |
| " \"\"\"\n", | |
| " def __init__(self, input_dim, hidden_dim):\n", | |
| " super().__init__()\n", | |
| " self.fc1 = nn.Linear(input_dim, hidden_dim, bias=True)\n", | |
| " self.bn1 = nn.BatchNorm1d(hidden_dim)\n", | |
| " self.fc2 = nn.Linear(hidden_dim, 10, bias=True)\n", | |
| "\n", | |
| " def forward(self, x):\n", | |
| " # Flatten trailing dims to (N, input_dim); returns raw class logits.\n", | |
| " x = x.view(x.size(0), -1)\n", | |
| " x = self.fc1(x)\n", | |
| " x = self.bn1(x)\n", | |
| " x = F.silu(x)\n", | |
| " return self.fc2(x)\n", | |
| "\n", | |
| "class StudentMLP(nn.Module):\n", | |
| " \"\"\"Small student network; same topology as the teacher but ReLU activation.\n", | |
| "\n", | |
| " Trained with KL distillation, then folded into a BN-free InferenceMLP\n", | |
| " via fuse_model() for deployment.\n", | |
| " \"\"\"\n", | |
| " def __init__(self, input_dim, hidden_dim):\n", | |
| " super().__init__()\n", | |
| " self.fc1 = nn.Linear(input_dim, hidden_dim, bias=True)\n", | |
| " self.bn1 = nn.BatchNorm1d(hidden_dim)\n", | |
| " self.fc2 = nn.Linear(hidden_dim, 10, bias=True)\n", | |
| "\n", | |
| " def forward(self, x):\n", | |
| " # Flatten trailing dims to (N, input_dim); returns raw class logits.\n", | |
| " x = x.view(x.size(0), -1)\n", | |
| " x = self.fc1(x)\n", | |
| " x = self.bn1(x)\n", | |
| " # ReLU is hardware-efficient (max(0,x))\n", | |
| " x = F.relu(x)\n", | |
| " return self.fc2(x)\n", | |
| "\n", | |
| "class InferenceMLP(nn.Module):\n", | |
| " \"\"\"Deploy-time network with no BatchNorm layer.\n", | |
| "\n", | |
| " Its fc1 weights are expected to come from fuse_model(), which folds the\n", | |
| " student's BatchNorm affine transform into the linear layer.\n", | |
| " \"\"\"\n", | |
| " def __init__(self, input_dim, hidden_dim):\n", | |
| " super().__init__()\n", | |
| " self.fc1 = nn.Linear(input_dim, hidden_dim, bias=True)\n", | |
| " self.fc2 = nn.Linear(hidden_dim, 10, bias=True)\n", | |
| "\n", | |
| " def forward(self, x):\n", | |
| " x = x.view(x.size(0), -1)\n", | |
| " x = self.fc1(x)\n", | |
| " x = F.relu(x, inplace=True)\n", | |
| " return self.fc2(x)\n", | |
| "\n", | |
| "def fuse_model(train_model: StudentMLP) -> InferenceMLP:\n", | |
| " \"\"\"Fold the student's bn1 into fc1 and return a BN-free InferenceMLP.\n", | |
| "\n", | |
| " BN(z) = gamma * (z - mean) / sqrt(var + eps) + beta, so with\n", | |
| " scale = gamma / sqrt(var + eps) and shift = beta - mean * scale\n", | |
| " the fused layer is W' = W * scale (per output row), b' = b * scale + shift.\n", | |
| " Uses the BN running statistics, so the student must be fully trained.\n", | |
| " \"\"\"\n", | |
| " fused = InferenceMLP(train_model.fc1.in_features, train_model.fc1.out_features)\n", | |
| " # fc2 sits after the BN/activation and is unaffected by fusion; copy verbatim.\n", | |
| " fused.fc2.weight.data = train_model.fc2.weight.data.clone()\n", | |
| " fused.fc2.bias.data = train_model.fc2.bias.data.clone()\n", | |
| "\n", | |
| " w = train_model.fc1.weight.data\n", | |
| " b = train_model.fc1.bias.data\n", | |
| " mean = train_model.bn1.running_mean\n", | |
| " var = train_model.bn1.running_var\n", | |
| " gamma = train_model.bn1.weight.data\n", | |
| " beta = train_model.bn1.bias.data\n", | |
| " eps = train_model.bn1.eps\n", | |
| " scale = gamma / torch.sqrt(var + eps)\n", | |
| " shift = beta - mean * scale\n", | |
| "\n", | |
| " # unsqueeze(1) broadcasts the per-feature scale across each weight row.\n", | |
| " fused.fc1.weight.data = w * scale.unsqueeze(1)\n", | |
| " fused.fc1.bias.data = b * scale + shift\n", | |
| " return fused\n", | |
| "\n", | |
| "# -----------------------\n", | |
| "# Data Processing\n", | |
| "# -----------------------\n", | |
| "def process_data(data: torch.Tensor, h: int, w: int, device: torch.device, use_fp16: bool = False) -> torch.Tensor:\n", | |
| " \"\"\"Convert raw uint8 MNIST images into flat, normalized feature rows.\n", | |
| "\n", | |
| " Steps: scale to [0,1], center-crop the 28x28 image to 20x20, area-resize\n", | |
| " to (h, w), flatten to (N, h*w), then standardize with the stock 28x28\n", | |
| " MNIST statistics (reused after resizing -- an approximation, not\n", | |
| " re-estimated for the smaller resolution).\n", | |
| " Returns a contiguous tensor on `device`; fp16 only when use_fp16 and CUDA.\n", | |
| " \"\"\"\n", | |
| " x = data.float() / 255.0\n", | |
| " # Add a channel dim: (N, 28, 28) -> (N, 1, 28, 28) for F.interpolate.\n", | |
| " x = x.unsqueeze(1)\n", | |
| "\n", | |
| " # Center Crop 20x20 first\n", | |
| " x = x[:, :, 4:24, 4:24]\n", | |
| " # Resize to target size\n", | |
| " x = F.interpolate(x, size=(h, w), mode='area')\n", | |
| "\n", | |
| " x = x.view(-1, h * w)\n", | |
| " x = (x - MNIST_MEAN) / MNIST_STD\n", | |
| " if device.type == \"cuda\" and use_fp16:\n", | |
| " x = x.to(device, dtype=torch.float16, non_blocking=True)\n", | |
| " else:\n", | |
| " x = x.to(device, dtype=torch.float32, non_blocking=True)\n", | |
| " return x.contiguous()\n", | |
| "\n", | |
| "class CUDAGraphRunner:\n", | |
| " \"\"\"Capture one inference pass into a CUDA graph and replay it cheaply.\n", | |
| "\n", | |
| " The graph is bound to fixed buffers: replay() re-executes the captured\n", | |
| " kernel sequence on self.static_x and rewrites self.static_y in place,\n", | |
| " eliminating per-call Python/launch overhead.\n", | |
| " \"\"\"\n", | |
| " def __init__(self, model: nn.Module, example_x: torch.Tensor, warmup: int = 20):\n", | |
| " self.model = model.eval()\n", | |
| " # Private clone: the capture binds to this exact tensor's memory.\n", | |
| " self.static_x = example_x.clone()\n", | |
| " with torch.inference_mode():\n", | |
| " # Warm-up runs let lazy init / autotune caches settle before capture.\n", | |
| " for _ in range(warmup):\n", | |
| " _ = self.model(self.static_x)\n", | |
| " torch.cuda.synchronize()\n", | |
| " self.g = torch.cuda.CUDAGraph()\n", | |
| " # Sync on both sides of capture: no other work may be in flight.\n", | |
| " torch.cuda.synchronize()\n", | |
| " with torch.inference_mode():\n", | |
| " with torch.cuda.graph(self.g):\n", | |
| " self.static_y = self.model(self.static_x)\n", | |
| " torch.cuda.synchronize()\n", | |
| " @torch.inference_mode()\n", | |
| " def replay(self):\n", | |
| " # Re-launches the captured kernels; the output lands in static_y.\n", | |
| " self.g.replay()\n", | |
| " return self.static_y\n", | |
| "\n", | |
| "def cuda_event_time_ms(fn, iters: int) -> float:\n", | |
| " start = torch.cuda.Event(enable_timing=True)\n", | |
| " end = torch.cuda.Event(enable_timing=True)\n", | |
| " torch.cuda.synchronize()\n", | |
| " start.record()\n", | |
| " for _ in range(iters):\n", | |
| " fn()\n", | |
| " end.record()\n", | |
| " end.synchronize()\n", | |
| " return float(start.elapsed_time(end))\n", | |
| "\n", | |
| "def wall_time_sec(fn, iters: int, is_cuda: bool) -> float:\n", | |
| " if is_cuda:\n", | |
| " torch.cuda.synchronize()\n", | |
| " t0 = time.perf_counter()\n", | |
| " for _ in range(iters):\n", | |
| " fn()\n", | |
| " if is_cuda:\n", | |
| " torch.cuda.synchronize()\n", | |
| " return float(time.perf_counter() - t0)\n", | |
| "\n", | |
| "# -----------------------\n", | |
| "# Main\n", | |
| "# -----------------------\n", | |
| "# 1. Clear Memory\n", | |
| "gc.collect()\n", | |
| "torch.cuda.empty_cache()\n", | |
| "\n", | |
| "print(\"\\nLoading data...\")\n", | |
| "train_ds = datasets.MNIST(root='./data', train=True, download=True)\n", | |
| "test_ds = datasets.MNIST(root='./data', train=False)\n", | |
| "\n", | |
| "# Teacher Data: 8x8 (64 features)\n", | |
| "X_train_T = process_data(train_ds.data, 8, 8, device, use_fp16=False)\n", | |
| "# Student Data: 8x8 (64 features) -- currently identical preprocessing to the\n", | |
| "# teacher's; kept as a separate tensor so the two pipelines could diverge.\n", | |
| "X_train_S = process_data(train_ds.data, 8, 8, device, use_fp16=False)\n", | |
| "\n", | |
| "y_train = train_ds.targets.to(device, non_blocking=True)\n", | |
| "\n", | |
| "# Test Data (Student scale) -- fp16 on GPU to match the fused half-precision model.\n", | |
| "X_test_S = process_data(test_ds.data, 8, 8, device, use_fp16=(device.type == \"cuda\"))\n", | |
| "y_test = test_ds.targets.to(device, non_blocking=True)\n", | |
| "\n", | |
| "print(f\"X_train_T: {tuple(X_train_T.shape)}\")\n", | |
| "print(f\"X_train_S: {tuple(X_train_S.shape)}\")\n", | |
| "\n", | |
| "BATCH_SIZE = 4096\n", | |
| "# Reusable index buffer for on-GPU shuffling (no DataLoader overhead).\n", | |
| "indices = torch.arange(X_train_T.size(0), device=device)\n", | |
| "loss_fn_ce = nn.CrossEntropyLoss()\n", | |
| "loss_fn_kl = nn.KLDivLoss(reduction='batchmean')\n", | |
| "\n", | |
| "# 1. Teacher (8x8)\n", | |
| "print(\"\\n1. Training Strong Teacher (Input=64, Hidden=1024) for 50 Epochs...\")\n", | |
| "teacher = TeacherMLP(TEACHER_INPUT_DIM, TEACHER_HIDDEN).to(device)\n", | |
| "opt_teacher = torch.optim.Adam(teacher.parameters(), lr=0.01)\n", | |
| "teacher.train()\n", | |
| "for epoch in range(50):\n", | |
| " # Fresh random permutation each epoch; batches are GPU-resident slices.\n", | |
| " perm = indices[torch.randperm(indices.size(0))]\n", | |
| " for i in range(0, X_train_T.size(0), BATCH_SIZE):\n", | |
| " idx = perm[i : i + BATCH_SIZE]\n", | |
| " data = X_train_T[idx]\n", | |
| " target = y_train[idx]\n", | |
| " opt_teacher.zero_grad(set_to_none=True)\n", | |
| " logits = teacher(data)\n", | |
| " loss = loss_fn_ce(logits, target)\n", | |
| " loss.backward()\n", | |
| " opt_teacher.step()\n", | |
| "# Freeze BN statistics before the teacher serves as a distillation target.\n", | |
| "teacher.eval()\n", | |
| "\n", | |
| "# 2. Student (8x8) - Distillation\n", | |
| "print(\"\\n2. Training Student (Input=64, Hidden=32) with KL Distillation (150 Epochs)...\")\n", | |
| "student = StudentMLP(STUDENT_INPUT_DIM, STUDENT_HIDDEN).to(device)\n", | |
| "opt_student = torch.optim.Adam(student.parameters(), lr=0.01)\n", | |
| "scheduler = torch.optim.lr_scheduler.OneCycleLR(\n", | |
| " opt_student,\n", | |
| " max_lr=0.01,\n", | |
| " steps_per_epoch=len(X_train_T)//BATCH_SIZE+1,\n", | |
| " epochs=150\n", | |
| ")\n", | |
| "student.train()\n", | |
| "\n", | |
| "# Distillation hyperparameters: softmax temperature and CE/KL mixing weight.\n", | |
| "T = 4.0\n", | |
| "ALPHA = 0.5\n", | |
| "\n", | |
| "for epoch in range(150):\n", | |
| " perm = indices[torch.randperm(indices.size(0))]\n", | |
| " for i in range(0, X_train_T.size(0), BATCH_SIZE):\n", | |
| " idx = perm[i : i + BATCH_SIZE]\n", | |
| "\n", | |
| " data_t = X_train_T[idx]\n", | |
| " data_s = X_train_S[idx]\n", | |
| " target = y_train[idx]\n", | |
| "\n", | |
| " # Teacher provides fixed soft targets; no gradient flows through it.\n", | |
| " with torch.no_grad():\n", | |
| " teacher_logits = teacher(data_t)\n", | |
| "\n", | |
| " opt_student.zero_grad(set_to_none=True)\n", | |
| " student_logits = student(data_s)\n", | |
| "\n", | |
| " # KL Loss Distillation (Proven better than MSE for accuracy here)\n", | |
| " # T*T restores the gradient magnitude after temperature-softening.\n", | |
| " loss_ce = loss_fn_ce(student_logits, target)\n", | |
| " loss_kl = loss_fn_kl(F.log_softmax(student_logits/T, dim=1), F.softmax(teacher_logits/T, dim=1)) * (T*T)\n", | |
| "\n", | |
| " loss = (1. - ALPHA) * loss_ce + ALPHA * loss_kl\n", | |
| "\n", | |
| " loss.backward()\n", | |
| " opt_student.step()\n", | |
| " # OneCycleLR is stepped once per batch (matches steps_per_epoch above).\n", | |
| " scheduler.step()\n", | |
| "\n", | |
| "print(\"Student trained.\")\n", | |
| "\n", | |
| "# Fuse\n", | |
| "# Deploy copy: BatchNorm folded into fc1, then cast to fp16 for inference.\n", | |
| "model = fuse_model(student).to(device)\n", | |
| "model.half().eval()\n", | |
| "\n", | |
| "# Cleanup (CRITICAL)\n", | |
| "# Free training tensors/optimizers so the 25M-row benchmark batch fits in VRAM.\n", | |
| "del teacher, opt_teacher, opt_student, student, loss_fn_ce, loss_fn_kl, indices, perm\n", | |
| "del X_train_T, X_train_S, y_train\n", | |
| "gc.collect()\n", | |
| "torch.cuda.empty_cache()\n", | |
| "print(\"Memory cleared for benchmark.\")\n", | |
| "\n", | |
| "# 25M Saturation (Safe Batch Size)\n", | |
| "# 10,000 test rows x 2500 = 25,000,000-sample benchmark batch.\n", | |
| "BENCH_MULTIPLIER = 2500\n", | |
| "# Use Half directly to save memory\n", | |
| "X_bench = X_test_S.half().repeat(BENCH_MULTIPLIER, 1)\n", | |
| "print(f\"X_bench: {tuple(X_bench.shape)}\")\n", | |
| "\n", | |
| "# Benchmark\n", | |
| "print(\"\\n--- BENCHMARK ---\")\n", | |
| "compiled_model = None\n", | |
| "try:\n", | |
| " # max-autotune-no-cudagraphs helps fusion\n", | |
| " compiled_model = torch.compile(model, mode=\"max-autotune-no-cudagraphs\", fullgraph=True)\n", | |
| " with torch.inference_mode():\n", | |
| " _ = compiled_model(X_bench)\n", | |
| " torch.cuda.synchronize()\n", | |
| " print(\"Compile OK.\")\n", | |
| "except:\n", | |
| " compiled_model = None\n", | |
| "\n", | |
| "inference_model = compiled_model if compiled_model is not None else model\n", | |
| "\n", | |
| "runner = None\n", | |
| "try:\n", | |
| " runner = CUDAGraphRunner(inference_model, X_bench, warmup=10)\n", | |
| " print(\"Graph OK.\")\n", | |
| "except:\n", | |
| " pass\n", | |
| "\n", | |
| "num_runs = 50\n", | |
| "total_samples = X_bench.size(0) * num_runs\n", | |
| "print(f\"Running {num_runs} iterations (Batch={X_bench.size(0)})... \")\n", | |
| "\n", | |
| "target_fn = runner.replay if runner else lambda: inference_model(X_bench)\n", | |
| "for _ in range(5):\n", | |
| " target_fn()\n", | |
| "torch.cuda.synchronize()\n", | |
| "\n", | |
| "gpu_s = cuda_event_time_ms(target_fn, num_runs) / 1000.0\n", | |
| "wall_s = wall_time_sec(target_fn, num_runs, True)\n", | |
| "\n", | |
| "print(f\"\\nThroughput (GPU) : {total_samples / gpu_s:,.0f} samples/sec\")\n", | |
| "print(f\"Throughput (Wall): {total_samples / wall_s:,.0f} samples/sec\")\n", | |
| "\n", | |
| "print(\"\\nChecking Accuracy (on 8x8 input)...\")\n", | |
| "# Evaluate the eager fp16 fused model on the fp16 test set; acc is a 0-dim tensor.\n", | |
| "with torch.inference_mode():\n", | |
| " acc = (model(X_test_S.half()).argmax(dim=1) == y_test).float().mean() * 100\n", | |
| "print(f\"Accuracy: {acc:.2f}% (Target: ≥96%)\")\n" | |
| ] | |
| } | |
| ] | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment