Train ImageNet using Gluon
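A single-file script that trains ImageNet classification models with MXNet Gluon. It supports symbolic, imperative, and hybrid execution modes, two record-based data pipelines, and a dummy-data benchmark path. A typical invocation might look like the following, where the script name, record/index paths, and model name are placeholders rather than anything mandated by the script:

python train_imagenet.py --train-rec train.rec --train-idx train.idx --val-rec val.rec --model resnet50_v1 --gpus 4 --mode hybrid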
import argparse, time
import logging

logging.basicConfig(level=logging.INFO)
fh = logging.FileHandler('training.log')
logger = logging.getLogger()
logger.addHandler(fh)

import mxnet as mx
from mxnet import gluon
from mxnet.gluon import nn
from mxnet.gluon.model_zoo import vision
from mxnet import autograd as ag
class DummyIter(mx.io.DataIter):
    def __init__(self, batch_size, data_shape, batches=5):
        super(DummyIter, self).__init__(batch_size)
        self.data_shape = (batch_size,) + data_shape
        self.label_shape = (batch_size,)
        self.provide_data = [('data', self.data_shape)]
        self.provide_label = [('softmax_label', self.label_shape)]
        self.batch = mx.io.DataBatch(data=[mx.nd.zeros(self.data_shape)],
                                     label=[mx.nd.zeros(self.label_shape)])
        self._batches = 0
        self.batches = batches

    def next(self):
        if self._batches < self.batches:
            self._batches += 1
            return self.batch
        else:
            self._batches = 0
            raise StopIteration

def dummy_iterator(batch_size, data_shape):
    return DummyIter(batch_size, data_shape), DummyIter(batch_size, data_shape)
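# DummyIter replays a fixed number of all-zero batches per epoch, so the
# --benchmark path below measures raw compute throughput without any disk
# I/O or image-decoding overhead.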
# CLI
parser = argparse.ArgumentParser(description='Train a model for ImageNet classification')
parser.add_argument('--train-rec', type=str, required=True,
                    help='training record file')
parser.add_argument('--val-rec', type=str, required=True,
                    help='validation record file')
parser.add_argument('--train-idx', type=str, required=True,
                    help='training index file')
parser.add_argument('--gpus', type=int, default=0,
                    help='number of GPUs to use')
parser.add_argument('--epochs', type=int, default=120,
                    help='number of total epochs')
parser.add_argument('--batch-size', type=int, default=256,
                    help='batch size')
parser.add_argument('--lr', type=float, default=0.1,
                    help='learning rate')
parser.add_argument('--momentum', type=float, default=0.9,
                    help='momentum')
parser.add_argument('--wd', type=float, default=1e-4,
                    help='weight decay')
parser.add_argument('--start-epoch', type=int, default=0,
                    help='starting epoch')
parser.add_argument('--resume', type=str, default='',
                    help='epoch to resume from; expects imagenet-<model>-<epoch>.params '
                         'saved by a previous run')
parser.add_argument('--seed', type=int, default=123,
                    help='random seed to use. Default=123.')
parser.add_argument('--benchmark', action='store_true',
                    help='whether to run benchmark (train on dummy data).')
parser.add_argument('--mode', type=str,
                    help='mode in which to train the model; options are symbolic, imperative, hybrid')
parser.add_argument('--iter', type=str,
                    help='type of iterator to use; "cc" selects mx.io.ImageRecordIter, '
                         'anything else selects mx.image.ImageIter')
parser.add_argument('--model', type=str, required=True,
                    help='type of model to use. see vision_model for options.')
parser.add_argument('--use_thumbnail', action='store_true',
                    help='use thumbnail or not in resnet. default is false.')
parser.add_argument('--batch-norm', action='store_true',
                    help='enable batch normalization or not in vgg. default is false.')
parser.add_argument('--pretrained', action='store_true',
                    help='enable using pretrained model from gluon.')
parser.add_argument('--log-interval', type=int, default=50,
                    help='number of batches to wait before logging.')
args = parser.parse_args()
logging.info(str(args))
mx.random.seed(args.seed)

ctx = [mx.gpu(i) for i in range(args.gpus)] if args.gpus > 0 else [mx.cpu()]

kwargs = {'ctx': ctx, 'pretrained': args.pretrained, 'classes': 1000}
if args.model.startswith('resnet'):
    kwargs['thumbnail'] = args.use_thumbnail
elif args.model.startswith('vgg'):
    kwargs['batch_norm'] = args.batch_norm
net = vision.get_model(args.model, **kwargs)

data_shape = (3, 224, 224)
if not args.benchmark:
    if args.iter == 'cc':
        train_iter = mx.io.ImageRecordIter(path_imgrec=args.train_rec, data_shape=data_shape,
                                           shuffle=True, mean_r=123.68, mean_g=116.28, mean_b=103.53,
                                           std_r=58.395, std_g=57.12, std_b=57.375,
                                           batch_size=args.batch_size, rand_crop=True,
                                           max_crop_size=480, min_crop_size=38, rand_mirror=True)
        val_iter = mx.io.ImageRecordIter(path_imgrec=args.val_rec, data_shape=data_shape,
                                         shuffle=False, mean_r=123.68, mean_g=116.28, mean_b=103.53,
                                         std_r=58.395, std_g=57.12, std_b=57.375,
                                         batch_size=args.batch_size, rand_crop=False, rand_mirror=False)
    else:
        train_iter = mx.image.ImageIter(args.batch_size, data_shape, path_imgrec=args.train_rec,
                                        path_imgidx=args.train_idx, shuffle=True, mean=True,
                                        std=True, rand_resize=True, rand_crop=True, rand_mirror=True)
        # train_iter = mx.io.PrefetchingIter(train_iter)
        val_iter = mx.image.ImageIter(args.batch_size, data_shape, path_imgrec=args.val_rec,
                                      shuffle=False, mean=True, std=True, resize=256)
        # val_iter = mx.io.PrefetchingIter(val_iter)
else:
    train_iter, val_iter = dummy_iterator(args.batch_size, data_shape)
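# Note: the 'cc' branch decodes and augments records in the C++ backend via
# mx.io.ImageRecordIter, which is typically faster; the default branch uses
# the Python-side mx.image.ImageIter, which is slower but easier to customize.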
def validate():
    metric = mx.metric.CompositeEvalMetric(['acc', mx.metric.TopKAccuracy(5)])
    val_iter.reset()
    for batch in val_iter:
        data = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0)
        label = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx, batch_axis=0)
        outputs = []
        for x in data:
            outputs.append(net(x))
        metric.update(label, outputs)
    return metric.get()
def train(epochs, start_epoch):
    if args.resume:
        start_epoch = int(args.resume)
        net.load_params('imagenet-%s-%d.params' % (args.model, start_epoch), ctx=ctx)
        logging.info('loaded from epoch %d', start_epoch)
    elif not args.pretrained:
        net.initialize(mx.init.Xavier(factor_type='out', magnitude=2), ctx=ctx)
    optimizer_params = {'learning_rate': args.lr, 'wd': args.wd, 'momentum': args.momentum}
    trainer = gluon.Trainer(net.collect_params(), 'sgd', optimizer_params)
    metric = mx.metric.CompositeEvalMetric(['acc', mx.metric.TopKAccuracy(5)])
    loss = gluon.loss.SoftmaxCrossEntropyLoss()

    for epoch in range(start_epoch, epochs):
        # Step schedule: divide the learning rate by 10 every 30 epochs.
        # Deriving it from the epoch number keeps resumed runs consistent.
        lr = args.lr / (10 ** (epoch // 30))
        if lr != optimizer_params['learning_rate']:
            optimizer_params['learning_rate'] = lr
            logging.info('Reduce learning rate to %f', lr)
            trainer = gluon.Trainer(net.collect_params(), 'sgd', optimizer_params)
        tic = time.time()
        train_iter.reset()
        metric.reset()
        btic = time.time()
        for i, batch in enumerate(train_iter):
            data = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0)
            label = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx, batch_axis=0)
            outputs = []
            Ls = []
            with ag.record():
                for x, y in zip(data, label):
                    z = net(x)
                    L = loss(z, y)
                    # store the loss and do backward after we have done forward
                    # on all GPUs for better speed on multiple GPUs.
                    Ls.append(L)
                    outputs.append(z)
            for L in Ls:
                L.backward()
            trainer.step(batch.data[0].shape[0])
            metric.update(label, outputs)
            if args.log_interval and not (i + 1) % args.log_interval:
                name, acc = metric.get()
                logging.info('[Epoch %d Batch %d] speed: %f samples/s, training: %s=%f, %s=%f' % (
                    epoch, i, args.batch_size / (time.time() - btic), name[0], acc[0], name[1], acc[1]))
            btic = time.time()
        name, acc = metric.get()
        logging.info('[Epoch %d] training: %s=%f, %s=%f' % (epoch, name[0], acc[0], name[1], acc[1]))
        logging.info('[Epoch %d] time cost: %f' % (epoch, time.time() - tic))
        name, val_acc = validate()
        logging.info('[Epoch %d] validation: %s=%f, %s=%f' % (epoch, name[0], val_acc[0], name[1], val_acc[1]))
        net.save_params('imagenet-%s-%d.params' % (args.model, epoch + 1))
if __name__ == '__main__':
    if args.mode == 'symbolic':
        data = mx.sym.var('data')
        out = net(data)
        softmax = mx.sym.SoftmaxOutput(out, name='softmax')
        mod = mx.mod.Module(softmax, context=ctx)
        optimizer_params = {'learning_rate': args.lr, 'wd': args.wd, 'momentum': args.momentum}
        mod.fit(train_iter, val_iter, num_epoch=args.epochs,
                batch_end_callback=mx.callback.Speedometer(args.batch_size, 1),
                optimizer='sgd',
                optimizer_params=optimizer_params,
                epoch_end_callback=mx.callback.do_checkpoint(args.model),
                initializer=mx.init.Xavier(factor_type='out', magnitude=2),
                eval_metric=['acc', mx.metric.TopKAccuracy(5)],
                validation_metric=['acc', mx.metric.TopKAccuracy(5)])
    else:
        if args.mode == 'hybrid':
            net.hybridize()
        train(args.epochs, args.start_epoch)
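# Checkpoints are written as imagenet-<model>-<epoch>.params after each epoch.
# A minimal sketch of reloading one for inference (the model name and epoch
# below are illustrative assumptions, not values produced by this script):
#
#   net = vision.get_model('resnet50_v1', classes=1000)
#   net.load_params('imagenet-resnet50_v1-120.params', ctx=mx.cpu())
#   prob = net(mx.nd.zeros((1, 3, 224, 224)))  # forward one dummy image -> shape (1, 1000)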