Created May 12, 2025 07:47
-
-
Save MagedSaeed/5458660ac7e9e002a4157afa0bb7e6c8 to your computer and use it in GitHub Desktop.
train_mnist.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import time | |
import torch | |
import torch.nn as nn | |
import torch.optim as optim | |
import torchvision | |
import torchvision.transforms as transforms | |
from torch.utils.data import DataLoader | |
print("=" * 80)
print("MNIST Training on SLURM")
print("=" * 80)


def print_slurm_info(environ=None):
    """Print the SLURM job details found in *environ*, if any.

    Nothing is printed when ``SLURM_JOB_ID`` is absent (not running under
    SLURM). Variables that SLURM only exports for some submission options
    (e.g. ``SLURM_CPUS_PER_TASK`` requires ``--cpus-per-task``) are shown as
    ``N/A`` instead of crashing the job with ``KeyError``.

    Args:
        environ: Mapping of environment variables; defaults to ``os.environ``.
    """
    if environ is None:
        environ = os.environ
    if 'SLURM_JOB_ID' not in environ:
        return
    print(f"SLURM Job ID: {environ['SLURM_JOB_ID']}")
    print(f"SLURM Job Name: {environ.get('SLURM_JOB_NAME', 'N/A')}")
    print(f"SLURM Allocated Nodes: {environ.get('SLURM_JOB_NODELIST', 'N/A')}")
    print(f"SLURM Allocated CPUs: {environ.get('SLURM_CPUS_PER_TASK', 'N/A')}")
    if 'SLURM_GPUS' in environ:
        print(f"SLURM Allocated GPUs: {environ['SLURM_GPUS']}")


# Print SLURM environment variables if running on SLURM
print_slurm_info()
# Device configuration: use CUDA when PyTorch reports it available, else the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

if device.type == 'cuda':
    # GPU details for device index 0; byte counts are reported in MiB.
    # The "Cached" line reports torch.cuda.memory_reserved().
    gpu_name = torch.cuda.get_device_name(0)
    print(f"CUDA Device: {gpu_name}")
    print(f"CUDA Version: {torch.version.cuda}")
    bytes_per_mb = 1024 ** 2
    print(f"CUDA Memory Allocated: {torch.cuda.memory_allocated(0) / bytes_per_mb:.1f} MB")
    print(f"CUDA Memory Cached: {torch.cuda.memory_reserved(0) / bytes_per_mb:.1f} MB")
# Hyperparameters
num_epochs, batch_size, learning_rate = 5, 64, 0.001

# Echo the configuration so it is captured in the job log.
print("\nHyperparameters:")
for hp_label, hp_value in (
    ("Epochs", num_epochs),
    ("Batch Size", batch_size),
    ("Learning Rate", learning_rate),
):
    print(f"- {hp_label}: {hp_value}")
# MNIST dataset: images -> tensors, then per-channel Normalize(mean, std).
mnist_normalize = transforms.Normalize(mean=(0.1307,), std=(0.3081,))
transform = transforms.Compose([transforms.ToTensor(), mnist_normalize])

print("\nLoading MNIST dataset...")
start_time = time.time()

# Options shared by the train and test splits (downloaded into ./data if absent).
mnist_common = dict(root='./data', transform=transform, download=True)

# Load training data
train_dataset = torchvision.datasets.MNIST(train=True, **mnist_common)
print(f"Training dataset loaded: {len(train_dataset)} images")

# Load test data
test_dataset = torchvision.datasets.MNIST(train=False, **mnist_common)
print(f"Test dataset loaded: {len(test_dataset)} images")

load_elapsed = time.time() - start_time
print(f"Dataset loading time: {load_elapsed:.2f} seconds")
# Data loaders
print("\nCreating data loaders...")

use_cuda = device.type == 'cuda'
if use_cuda:
    # Never start more loader workers than CPUs this process may run on.
    # Under SLURM the affinity mask typically reflects the job's CPU
    # allocation, so a small allocation is not oversubscribed.
    try:
        usable_cpus = len(os.sched_getaffinity(0))
    except AttributeError:  # sched_getaffinity is Linux-only
        usable_cpus = os.cpu_count() or 1
    loader_workers = min(4, usable_cpus)
else:
    loader_workers = 0

# Settings shared by both loaders. pin_memory (CUDA only) lets host->GPU
# copies use page-locked memory, which can speed up transfers.
loader_kwargs = dict(
    batch_size=batch_size,
    num_workers=loader_workers,
    pin_memory=use_cuda,
)
train_loader = DataLoader(dataset=train_dataset, shuffle=True, **loader_kwargs)
test_loader = DataLoader(dataset=test_dataset, shuffle=False, **loader_kwargs)
print(f"Train batches: {len(train_loader)}")
print(f"Test batches: {len(test_loader)}")
# Define the model
class NeuralNet(nn.Module):
    """Fully connected classifier: 28*28 inputs -> 128 -> 64 -> 10 logits.

    ``nn.Flatten`` (default ``start_dim=1``) turns an image batch of shape
    ``(batch, 1, 28, 28)`` into ``(batch, 784)`` before the linear stack.
    ReLU follows every linear layer except the last, so raw logits are returned.
    """

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        # Attribute names and layer order determine the state_dict keys
        # ("linear_stack.0.weight", ...); keep them stable for saved weights.
        widths = (28 * 28, 128, 64, 10)
        layers = []
        for idx, (fan_in, fan_out) in enumerate(zip(widths, widths[1:])):
            if idx:
                layers.append(nn.ReLU())
            layers.append(nn.Linear(fan_in, fan_out))
        self.linear_stack = nn.Sequential(*layers)

    def forward(self, x):
        """Return unnormalized class scores for the batch ``x``."""
        return self.linear_stack(self.flatten(x))
print("\nInitializing neural network...")
model = NeuralNet().to(device)
print("Model architecture:")
print(model)

# Total parameters: count every parameter, and separately those with
# requires_grad set.
total_params = trainable_params = 0
for param in model.parameters():
    total_params += param.numel()
    if param.requires_grad:
        trainable_params += param.numel()
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# Training loop
print("\nStarting training...")
total_steps = len(train_loader)
# Steps between loss reports; also the window the reported loss is averaged
# over, so the modulus and the divisor can never drift apart.
log_interval = 100
training_start = time.time()
# Explicit training mode (affects layers such as dropout/batch-norm).
model.train()
for epoch in range(num_epochs):
    epoch_start = time.time()
    running_loss = 0.0
    for i, (images, labels) in enumerate(train_loader):
        # Move tensors to the configured device; non_blocking lets the copy
        # overlap with compute when the loader yields pinned memory.
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if (i + 1) % log_interval == 0:
            avg_loss = running_loss / log_interval
            running_loss = 0.0
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{total_steps}], Loss: {avg_loss:.4f}')

    epoch_time = time.time() - epoch_start
    print(f"Epoch {epoch+1} completed in {epoch_time:.2f} seconds")

training_time = time.time() - training_start
print(f"\nTraining completed in {training_time:.2f} seconds")
# Test the model
print("\nEvaluating model on test data...")
test_start = time.time()
num_classes = 10
# Inference mode for layers such as dropout/batch-norm.
model.eval()
with torch.no_grad():
    # Per-class tallies are accumulated as tensors with bincount rather than a
    # per-sample Python loop (which cost one .item() sync per image).
    class_correct_counts = torch.zeros(num_classes, dtype=torch.long, device=device)
    class_total_counts = torch.zeros(num_classes, dtype=torch.long, device=device)
    for images, labels in test_loader:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        outputs = model(images)
        predicted = outputs.argmax(dim=1)
        # No .squeeze(): on a final batch of size 1 it produced a 0-d tensor
        # that could not be indexed per sample.
        correct = predicted == labels
        class_total_counts += torch.bincount(labels, minlength=num_classes)
        class_correct_counts += torch.bincount(labels[correct], minlength=num_classes)

# Plain Python ints, same shape as before: one entry per digit.
class_correct = class_correct_counts.tolist()
class_total = class_total_counts.tolist()
n_correct = sum(class_correct)
n_samples = sum(class_total)

# Overall accuracy (an empty test set reports 0.0 instead of dividing by zero)
acc = 100.0 * n_correct / n_samples if n_samples else 0.0
print(f'Overall Accuracy: {acc:.2f}%')

# Per-class accuracy
print("\nPer-class Accuracy:")
for digit, (hits, seen) in enumerate(zip(class_correct, class_total)):
    if seen:
        print(f'Digit {digit}: {100.0 * hits / seen:.2f}%')
    else:
        print(f'Digit {digit}: N/A (no test samples)')

test_time = time.time() - test_start
print(f"\nEvaluation completed in {test_time:.2f} seconds")
# Save the model: only the learned parameters (state_dict) are written.
model_path = 'mnist_model.pth'
torch.save(model.state_dict(), model_path)
print(f"\nModel saved to {os.path.abspath(model_path)}")

print("\nTraining summary:")
# NOTE: "Total runtime" is training + evaluation time only; dataset loading
# and model setup are not included in the figure.
summary_lines = (
    f"- Total runtime: {training_time + test_time:.2f} seconds",
    f"- Training time: {training_time:.2f} seconds",
    f"- Evaluation time: {test_time:.2f} seconds",
    f"- Final accuracy: {acc:.2f}%",
    f"- Dataset: MNIST ({len(train_dataset)} training, {len(test_dataset)} test)",
    f"- Model parameters: {total_params:,}",
    f"- Device used: {device}",
)
print("\n".join(summary_lines))
print("\nTraining completed successfully!")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.