simple requeueable slurm job
model.py
import argparse
from pathlib import Path

import matplotlib
# Select a non-interactive backend so matplotlib works on headless compute nodes.
matplotlib.use('agg')
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.tensorboard
import torchvision
import torchvision.transforms as transforms


# **kwargs below is used to allow additional, unused keyword arguments to be passed to the train function.
def train(lr, n_epochs, logdir=None, checkpoint=None, **kwargs):
    print(f'We are using learning rate {lr}.')
    print(f'We are using n_epochs {n_epochs}.')

    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )
    trainset = torchvision.datasets.CIFAR10(
        # Try to use the datasets in /data/localhost/not-backed-up/datasets-ziz-all.
        # See [URL with discussion about how we organize datasets, TBC] for details.
        root="/data/localhost/not-backed-up/datasets-ziz-all/torchvision/CIFAR10",
        train=True, download=False, transform=transform
    )
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=10, shuffle=True)
    testset = torchvision.datasets.CIFAR10(
        root="/data/localhost/not-backed-up/datasets-ziz-all/torchvision/CIFAR10",
        train=False, download=False, transform=transform
    )
    testloader = torch.utils.data.DataLoader(testset, batch_size=10, shuffle=False)

    net = Net()
    net.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9)

    if logdir:
        # If logdir doesn't exist, we create it.
        # We expect logdir to be an absolute path, somewhere in /data/ziz/not-backed-up/scratch/$USER.
        Path(logdir).mkdir(parents=True, exist_ok=True)
        # logdir is on a network filesystem (NFS), a drive shared from ziz to the compute nodes zizgpu0x,
        # so it is slow to read and write, and we don't want to write large files (e.g. checkpoints) to it,
        # or write to it too often.
        # However, it is a convenient place for tensorboard logs (just the scalar metrics, not images etc.),
        # because it lets us monitor the progress of all of our experiments, across all compute nodes
        # zizgpu0x, by running a single tensorboard service on ziz (see the sketch after this file).
        writer = torch.utils.tensorboard.SummaryWriter(log_dir=logdir)

    if checkpoint:
        # Path() handles both relative (wrt logdir) and absolute checkpoint paths.
        checkpoint_path = Path(logdir, checkpoint)
        checkpoint = torch.load(checkpoint_path)
        net.load_state_dict(checkpoint['net.state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer.state_dict'])
        # If checkpoint['epoch'] == i, the last checkpoint was made at the end of epoch i,
        # so we restart from epoch i + 1.
        start_epoch = checkpoint['epoch'] + 1
        print(f'Restarting training from epoch {start_epoch} of checkpoint: {checkpoint_path.absolute()}')
    else:
        start_epoch = 0

    num_batches = len(trainloader)
    for epoch in range(start_epoch, n_epochs):  # loop over the dataset multiple times
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data[0].to(device), data[1].to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            if logdir:
                # Log the value of the loss, using a global step so curves don't overlap across epochs.
                writer.add_scalar('loss/train', loss.item(), epoch * num_batches + i)

            # print statistics
            running_loss += loss.item()
            if i % 1000 == 0 and i != 0:  # print every 1000 mini-batches
                print(f"Epoch: {epoch}. Steps: {i}/{num_batches}. Loss: {running_loss / 1000}")
                running_loss = 0.0

        if logdir:
            # Save a checkpoint after every epoch.
            torch.save({'net.state_dict': net.state_dict(),
                        'optimizer.state_dict': optimizer.state_dict(),
                        'epoch': epoch},
                       Path(logdir, 'latest_checkpoint.torch'))

    # At the end of training, evaluate the accuracy on the test set and save it to logdir/results.torch.
    correct = 0
    total = 0
    # We don't need to compute gradients when evaluating performance on the test set.
    with torch.no_grad():
        for data in testloader:
            inputs, labels = data[0].to(device), data[1].to(device)
            outputs = net(inputs)
            predicted = torch.argmax(outputs.data, dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print('Accuracy on the test set: %d%%' % (100 * correct / total))
    if logdir:
        # Save the results to logdir, on the central storage.
        torch.save({'accuracy': correct / total}, Path(logdir, 'results.torch'))

    print("Finished training.")


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--lr', type=float)
    parser.add_argument('--n_epochs', type=int)
    parser.add_argument('--checkpoint', type=str, default=None,
                        help="Absolute or relative (wrt the specified --logdir) path to the checkpoint.")
    # On our cluster this is intended to live on one of the network-filesystem (NFS) drives from ziz,
    # which is available on all zizgpu0x nodes, i.e. at /data/ziz/not-backed-up/scratch/$USER.
    parser.add_argument('--logdir', type=str, default=None,
                        help="Absolute or relative (wrt the working directory) log directory.")
    args, unknown_args = parser.parse_known_args()
    print(f"Unrecognized args: {unknown_args}")
    train(**args.__dict__)
    # Equivalent to:
    # train(lr=args.lr, n_epochs=args.n_epochs, logdir=args.logdir, checkpoint=args.checkpoint)
    # train(lr=args.__dict__['lr'], n_epochs=args.__dict__['n_epochs'], ...)
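The comments in train() rely on being able to monitor all experiments from a single tensorboard service running on ziz, pointed at the shared scratch directory. A minimal sketch of what that could look like, run on ziz itself; the port, the ssh host alias and the exact --logdir are illustrative assumptions, not something this gist prescribes:

# Run on ziz; tensorboard picks up all event files under --logdir recursively.
tensorboard --logdir /data/ziz/not-backed-up/scratch/$USER --port 6006
# From your own machine, forward the port and browse the dashboards at http://localhost:6006.
ssh -N -L 6006:localhost:6006 ziz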
simple_requeueable_job.sh
#!/bin/bash
# This script is a working example which allows you to do `scontrol requeue JOBID`.
# This example builds on top of your understanding of `single_training.sh`.
# Usage: `sbatch --output=/data/ziz/not-backed-up/scratch/$USER/slurm-%j.o --error=/data/ziz/not-backed-up/scratch/$USER/slurm-%j.o /data/ziz/not-backed-up/software/ziz_toolkit/slurm_gpu_examples/example_basic/simple_requeueable_job.sh`

#SBATCH --job-name=simple_requeueable_job
#SBATCH --partition=ziz-gpu
#SBATCH --gres=gpu:1
#SBATCH --cpus-per-task=1
#SBATCH --time=14-00:00:00
#SBATCH --mem=5G
#SBATCH --ntasks=1

# THE NEW PART IN simple_requeueable_job.sh
# This allows your job to be requeued.
#SBATCH --requeue
# This setting makes sure that once your job is restarted it doesn't overwrite the --output and --error
# logs from before the restart, but just appends to them.
#SBATCH --open-mode=append

export PATH_TO_CONDA="/data/ziz/not-backed-up/software/ziz_toolkit/miniconda3"

# Activate the conda virtual environment.
source $PATH_TO_CONDA/bin/activate example_environment

# Just to make sure the log directory exists.
mkdir -p /data/ziz/not-backed-up/scratch/$USER/ziz_toolkit/slurm_gpu_examples/example_basic/logs

# If this is a restart, add "--checkpoint ..." to the python command.
if [[ $SLURM_RESTART_COUNT -gt 0 ]]; then
    echo "Restart count: $SLURM_RESTART_COUNT"
    export CHECKPOINT="--checkpoint /data/ziz/not-backed-up/scratch/$USER/ziz_toolkit/slurm_gpu_examples/example_basic/logs/latest_checkpoint.torch"
fi

# Print the command we are about to run (useful to have in the slurm log).
echo "python
/data/ziz/not-backed-up/software/ziz_toolkit/slurm_gpu_examples/example_basic/model.py
--lr 0.02
--n_epochs 10
--logdir /data/ziz/not-backed-up/scratch/$USER/ziz_toolkit/slurm_gpu_examples/example_basic/logs
$CHECKPOINT"

python -u /data/ziz/not-backed-up/software/ziz_toolkit/slurm_gpu_examples/example_basic/model.py \
    --lr 0.02 \
    --n_epochs 10 \
    --logdir /data/ziz/not-backed-up/scratch/$USER/ziz_toolkit/slurm_gpu_examples/example_basic/logs \
    $CHECKPOINT

echo "Job completed."