import argparse
import os
import shutil
import time
import socket
import multiprocessing

import numpy as np
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
from torch.cuda.amp import autocast, GradScaler

# NVIDIA's fast alternative to PyTorch's default_collate: images stay as uint8
# (no transforms.ToTensor() on the CPU); conversion to float and normalisation
# happen later, on the GPU, in data_prefetcher.
def fast_collate(batch, memory_format):
    imgs = [sample[0] for sample in batch]
    targets = torch.tensor([sample[1] for sample in batch], dtype=torch.int64)
    w = imgs[0].size[0]
    h = imgs[0].size[1]
    tensor = torch.zeros((len(imgs), 3, h, w), dtype=torch.uint8).contiguous(memory_format=memory_format)
    for i, img in enumerate(imgs):
        nump_array = np.asarray(img, dtype=np.uint8)
        if nump_array.ndim < 3:
            nump_array = np.expand_dims(nump_array, axis=-1)
        nump_array = np.rollaxis(nump_array, 2)
        tensor[i] += torch.from_numpy(nump_array)
    return tensor, targets

def parse():
    model_names = sorted(
        name
        for name in models.__dict__
        if name.islower()
        and not name.startswith("__")
        and callable(models.__dict__[name])
    )

    parser = argparse.ArgumentParser(description="PyTorch ImageNet Training")
    parser.add_argument("data", metavar="DIR", default="", help="path to dataset")
    parser.add_argument(
        "--arch",
        "-a",
        metavar="ARCH",
        default="resnet50",
        choices=model_names,
        help="model architecture: " + " | ".join(model_names) + " (default: resnet50)",
    )
    parser.add_argument(
        "-j",
        "--workers",
        default=4,
        type=int,
        metavar="N",
        help="number of data loading workers per process/GPU (default: 4)",
    )
    parser.add_argument(
        "--epochs",
        default=5,
        type=int,
        metavar="N",
        help="number of total epochs to run",
    )
    parser.add_argument(
        "--start-epoch",
        default=0,
        type=int,
        metavar="N",
        help="manual epoch number (useful on restarts)",
    )
    parser.add_argument(
        "-b",
        "--batch-size",
        default=256,
        type=int,
        metavar="N",
        help="mini-batch size per process/GPU (default: 256)",
    )
    parser.add_argument(
        "--lr",
        "--learning-rate",
        default=0.1,
        type=float,
        metavar="LR",
        help="Initial learning rate. Will be scaled by <global batch size>/256: "
        "args.lr = args.lr * float(args.batch_size * args.world_size) / 256. "
        "A warmup schedule will also be applied over the first 5 epochs.",
    )
    parser.add_argument(
        "--momentum", default=0.9, type=float, metavar="M", help="momentum"
    )
    parser.add_argument(
        "--weight-decay",
        "--wd",
        default=1e-4,
        type=float,
        metavar="W",
        help="weight decay (default: 1e-4)",
    )
    parser.add_argument(
        "--print-freq",
        "-p",
        default=10,
        type=int,
        metavar="N",
        help="print frequency (default: 10)",
    )
    parser.add_argument(
        "--resume",
        default="",
        type=str,
        metavar="PATH",
        help="path to latest checkpoint (default: none)",
    )
    parser.add_argument(
        "-e",
        "--evaluate",
        dest="evaluate",
        action="store_true",
        help="evaluate model on validation set",
    )
    parser.add_argument(
        "--pretrained",
        dest="pretrained",
        action="store_true",
        help="use pre-trained model",
    )
    parser.add_argument(
        "--prof", default=-1, type=int, help="Only run this many iterations (for profiling)."
    )
    parser.add_argument("--deterministic", action="store_true")
    parser.add_argument("--local_rank", default=0, type=int)
    parser.add_argument("--channels-last", type=bool, default=True)
    parser.add_argument("--synthetic", action="store_true", help="Run on synthetic (fake) data")
    parser.add_argument(
        "--validate", action="store_true", help="Run validation during the training loop"
    )
    args = parser.parse_args()
    return args
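
# The script follows the torch.distributed.launch convention: each process
# receives --local_rank and reads WORLD_SIZE from the environment, so a typical
# multi-GPU launch would look something like
#   python -m torch.distributed.launch --nproc_per_node=<num_gpus> <this_script>.py <imagenet_dir>
# (script name and dataset path are placeholders).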

def main():
    global best_prec1, args

    args = parse()
    print("\nCUDNN VERSION: {}\n".format(torch.backends.cudnn.version()))
    print("\nNCCL VERSION: {}\n".format(torch.cuda.nccl.version()))
    print("\nCPU Count: {}\n".format(multiprocessing.cpu_count()))

    cudnn.benchmark = True
    best_prec1 = 0
    ngpus_per_node = torch.cuda.device_count()

    if args.deterministic:
        cudnn.benchmark = False
        cudnn.deterministic = True
        torch.manual_seed(args.local_rank)
        torch.set_printoptions(precision=10)

    args.distributed = False
    if "WORLD_SIZE" in os.environ:
        args.distributed = int(os.environ["WORLD_SIZE"]) > 1

    args.gpu = 0
    args.world_size = 1

    if args.distributed:
        print("Local rank {}".format(args.local_rank))
        args.gpu = args.local_rank
        torch.cuda.set_device(args.gpu)
        print("Use GPU: {} for training".format(args.gpu))
        torch.distributed.init_process_group(
            backend="nccl",
            init_method="env://"
        )
        args.world_size = torch.distributed.get_world_size()

    if args.channels_last:
        memory_format = torch.channels_last
    else:
        memory_format = torch.contiguous_format

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    # Scale learning rate based on global batch size
    args.lr = args.lr * float(args.batch_size * args.world_size) / 256.0
    optimizer = torch.optim.SGD(
        model.parameters(),
        args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
    )

    if args.distributed:
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            model = model.to(memory_format=memory_format)
            print("Batch per GPU: {}".format(args.batch_size))
            print("Total Batch on Node: {}".format(args.batch_size * ngpus_per_node))
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        else:
            model.cuda()
            model = model.to(memory_format=memory_format)
            model = torch.nn.parallel.DistributedDataParallel(model)
    else:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
        model = model.to(memory_format=memory_format)

    # define loss function (criterion)
    criterion = nn.CrossEntropyLoss().cuda()

    # Data loading code
    crop_size = 224
    val_size = 256

    traindir = os.path.join(args.data, "train")
    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose(
            [
                transforms.RandomResizedCrop(crop_size),
                transforms.RandomHorizontalFlip(),
            ]
        ),
    )
    print("Train dataset size: {}".format(len(train_dataset)))

    # The validation set is also needed when --evaluate is passed on its own
    if args.validate or args.evaluate:
        valdir = os.path.join(args.data, "val")
        val_dataset = datasets.ImageFolder(
            valdir,
            transforms.Compose(
                [
                    transforms.Resize(val_size),
                    transforms.CenterCrop(crop_size),
                ]
            ),
        )

    train_sampler = None
    val_sampler = None
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
        if args.validate or args.evaluate:
            val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)

    collate_fn = lambda b: fast_collate(b, memory_format)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=(train_sampler is None),
        num_workers=args.workers,
        pin_memory=True,
        sampler=train_sampler,
        collate_fn=collate_fn,
    )

    if args.validate or args.evaluate:
        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True,
            sampler=val_sampler,
            collate_fn=collate_fn,
        )

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    print("Starting training")
    stime = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        stime_epoch = time.time()

        train(train_loader, model, criterion, optimizer, epoch)

        # Validate after each epoch
        if args.validate:
            # evaluate on validation set
            prec1 = validate(val_loader, model, criterion)

            # remember best prec@1 and save checkpoint
            if args.local_rank == 0:
                is_best = prec1 > best_prec1
                best_prec1 = max(prec1, best_prec1)
                save_checkpoint(
                    {
                        "epoch": epoch + 1,
                        "arch": args.arch,
                        "state_dict": model.state_dict(),
                        "best_prec1": best_prec1,
                        "optimizer": optimizer.state_dict(),
                    },
                    is_best,
                )
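
# data_prefetcher overlaps host-to-device copies with GPU compute: the next
# batch is moved to the GPU on a side CUDA stream while the current batch is
# being processed, and the uint8 images produced by fast_collate are converted
# to float and normalised with the ImageNet mean/std directly on the GPU.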

class data_prefetcher():
    def __init__(self, loader):
        self.loader = iter(loader)
        self.stream = torch.cuda.Stream()
        self.mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).cuda().view(1, 3, 1, 1)
        self.std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).cuda().view(1, 3, 1, 1)
        self.preload()

    def preload(self):
        try:
            self.next_input, self.next_target = next(self.loader)
        except StopIteration:
            self.next_input = None
            self.next_target = None
            return
        with torch.cuda.stream(self.stream):
            self.next_input = self.next_input.cuda(non_blocking=True)
            self.next_target = self.next_target.cuda(non_blocking=True)
            self.next_input = self.next_input.float()
            self.next_input = self.next_input.sub_(self.mean).div_(self.std)

    def next(self):
        torch.cuda.current_stream().wait_stream(self.stream)
        input = self.next_input
        target = self.next_target
        if input is not None:
            input.record_stream(torch.cuda.current_stream())
        if target is not None:
            target.record_stream(torch.cuda.current_stream())
        self.preload()
        return input, target
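
# Mixed precision uses the native torch.cuda.amp API: the forward pass and loss
# run under autocast, while GradScaler scales the loss before backward() and
# unscales the gradients before the optimizer step.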

def train(train_loader, model, criterion, optimizer, epoch):
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode
    model.train()
    end = time.time()

    prefetcher = data_prefetcher(train_loader)
    input, target = prefetcher.next()
    scaler = GradScaler()
    i = 0
    while input is not None:
        i += 1
        adjust_learning_rate(optimizer, epoch, i, len(train_loader))

        with autocast():
            output = model(input)
            loss = criterion(output, target)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

        if i % args.print_freq == 0:
            prec1, prec5 = accuracy(output.data, target, topk=(1, 5))

            # Average loss and accuracy across processes for logging
            if args.distributed:
                reduced_loss = reduce_tensor(loss.data)
                prec1 = reduce_tensor(prec1)
                prec5 = reduce_tensor(prec5)
            else:
                reduced_loss = loss.data

            # to_python_float incurs a host<->device sync
            losses.update(to_python_float(reduced_loss), input.size(0))
            top1.update(to_python_float(prec1), input.size(0))
            top5.update(to_python_float(prec5), input.size(0))

            torch.cuda.synchronize()
            batch_time.update((time.time() - end) / args.print_freq)
            end = time.time()

            if args.local_rank == 0:
                print('Epoch: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Speed {3:.3f} ({4:.3f})\t'
                      'Loss {loss.val:.10f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                          epoch, i, len(train_loader),
                          args.world_size * args.batch_size / batch_time.val,
                          args.world_size * args.batch_size / batch_time.avg,
                          batch_time=batch_time,
                          loss=losses, top1=top1, top5=top5))

        input, target = prefetcher.next()

def validate(val_loader, model, criterion):
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()
    end = time.time()

    prefetcher = data_prefetcher(val_loader)
    input, target = prefetcher.next()
    i = 0
    while input is not None:
        i += 1

        # compute output
        with torch.no_grad():
            output = model(input)
            loss = criterion(output, target)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))

        if args.distributed:
            reduced_loss = reduce_tensor(loss.data)
            prec1 = reduce_tensor(prec1)
            prec5 = reduce_tensor(prec5)
        else:
            reduced_loss = loss.data

        losses.update(to_python_float(reduced_loss), input.size(0))
        top1.update(to_python_float(prec1), input.size(0))
        top5.update(to_python_float(prec5), input.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # TODO: Change timings to mirror train().
        if args.local_rank == 0 and i % args.print_freq == 0:
            print('Test: [{0}/{1}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Speed {2:.3f} ({3:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                      i, len(val_loader),
                      args.world_size * args.batch_size / batch_time.val,
                      args.world_size * args.batch_size / batch_time.avg,
                      batch_time=batch_time, loss=losses,
                      top1=top1, top5=top5))

        input, target = prefetcher.next()

    print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
          .format(top1=top1, top5=top5))
    return top1.avg
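
# Checkpointing: only rank 0 calls save_checkpoint (see main()); the latest
# state is written to checkpoint.pth.tar and the best-so-far model (by top-1
# accuracy on the validation set) is copied to model_best.pth.tar.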

def save_checkpoint(state, is_best, filename="checkpoint.pth.tar"):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, "model_best.pth.tar")

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
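
# Learning-rate schedule: the base LR (already scaled by the global batch size
# in main()) is decayed 10x every 30 epochs, with an extra 10x drop from epoch
# 80, plus a linear warmup over the first 5 epochs.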

def adjust_learning_rate(optimizer, epoch, step, len_epoch):
    """LR schedule that should yield 76% converged accuracy with batch size 256"""
    factor = epoch // 30
    if epoch >= 80:
        factor = factor + 1
    lr = args.lr * (0.1 ** factor)

    # Warmup
    if epoch < 5:
        lr = lr * float(1 + step + epoch * len_epoch) / (5. * len_epoch)

    # if args.local_rank == 0:
    #     print("epoch = {}, step = {}, lr = {}".format(epoch, step, lr))

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
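
# reduce_tensor averages a metric across all distributed processes, so the
# logged loss/accuracy reflect the whole global batch rather than one GPU.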

def reduce_tensor(tensor):
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= args.world_size
    return rt

def to_python_float(t):
    if hasattr(t, 'item'):
        return t.item()
    else:
        return t[0]

if __name__ == "__main__":
    main()