Last active
May 30, 2018 10:17
-
-
Save ilkarman/37a4d5f44f25a4e023572a954a6b258f to your computer and use it in GitHub Desktop.
Chainer multi-node training on Azure BatchAI
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import logging | |
import os | |
from os import path | |
import numpy as np | |
import pandas as pd | |
import multiprocessing | |
import random | |
from toolz import pipe | |
from timer import Timer | |
from PIL import Image | |
from chainercv import transforms | |
import chainer | |
import chainer.cuda | |
from chainer import training | |
from chainer.training import extensions | |
import resnet50 | |
from mpi4py import MPI | |
import chainermn | |
# Image size used for the resize / random-crop step of the input pipeline.
_WIDTH = 224
_HEIGHT = 224

# Optimiser and schedule settings.
_LR = 0.001
_EPOCHS = 1
_BATCHSIZE = 64

# Caffe-style normalisation: per-channel RGB mean to subtract, then a scale.
_IMAGENET_RGB_MEAN_CAFFE = np.array([123.68, 116.78, 103.94], dtype=np.float32)
_IMAGENET_SCALE_FACTOR_CAFFE = 0.017

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Distributed training settings
parser = argparse.ArgumentParser(description='Chainer ResNet Example')
parser.add_argument('--communicator', default='hierarchical')
# NOTE: the command line is parsed at import time; main() reads ``args``.
args = parser.parse_args()
def _append_path_to(data_path, data_series):
    """Return a new series with *data_path* joined in front of every entry."""
    def _with_prefix(filename):
        return path.join(data_path, filename)

    return data_series.apply(_with_prefix)
def _load_training(data_dir):
    """Read ``train.csv`` in *data_dir*; prefix ``filenames`` with the ``train`` folder.

    Returns a DataFrame equal to the CSV contents except that every value in
    the ``filenames`` column has ``<data_dir>/train`` joined in front of it.
    """
    csv_file = path.join(data_dir, 'train.csv')
    frame = pd.read_csv(csv_file)
    image_dir = path.join(data_dir, 'train')
    full_paths = _append_path_to(image_dir, frame.filenames)
    return frame.assign(filenames=full_paths)
def _load_validation(data_dir):
    """Read ``validation.csv`` in *data_dir*; prefix ``filenames`` with ``validation``.

    Returns a DataFrame equal to the CSV contents except that every value in
    the ``filenames`` column has ``<data_dir>/validation`` joined in front of it.
    """
    csv_file = path.join(data_dir, 'validation.csv')
    frame = pd.read_csv(csv_file)
    image_dir = path.join(data_dir, 'validation')
    full_paths = _append_path_to(image_dir, frame.filenames)
    return frame.assign(filenames=full_paths)
def _create_data_fn(train_path, test_path):
    """Load image paths and zero-based integer labels for both data splits.

    Args:
        train_path: Directory passed to ``_load_training``.
        test_path: Directory passed to ``_load_validation``.

    Returns:
        Tuple ``(train_X, train_labels, validation_X, validation_labels)``.
        The ``*_X`` items are the ``filenames`` columns as arrays; the label
        items are the ``num_id`` columns shifted down by one.
    """
    logger.info('Reading training data info')
    train_df = _load_training(train_path)
    logger.info('Reading validation data info')
    validation_df = _load_validation(test_path)

    # File paths
    train_X = train_df['filenames'].values
    validation_X = validation_df['filenames'].values

    # Integer class ids (not one-hot). The ids in the CSV start at 1, so shift
    # them to start from 0. The subtraction is done out of place: ``.values``
    # can be a view of the DataFrame (read-only under pandas Copy-on-Write),
    # so ``-=`` would either mutate the frame or raise.
    train_labels = train_df['num_id'].values - 1
    validation_labels = validation_df['num_id'].values - 1

    return train_X, train_labels, validation_X, validation_labels
class ImageNet(chainer.dataset.DatasetMixin):
    """Dataset that loads images from disk and returns (image, label) pairs.

    Images are returned as float32 arrays in (c, h, w) layout after mean
    subtraction and scaling, sized ``_HEIGHT`` x ``_WIDTH``. Labels are
    returned as int32 arrays.

    Args:
        img_locs: Indexable collection of image file paths.
        labels: Indexable collection of labels, aligned with ``img_locs``.
        augmentation: Only tested against ``None``. When not ``None`` a random
            crop plus random horizontal flip is applied; otherwise the image
            is resized.
    """

    def __init__(self, img_locs, labels, augmentation=None):
        self.img_locs, self.labels = img_locs, labels
        self.augmentation = augmentation
        self.imagenet_mean = _IMAGENET_RGB_MEAN_CAFFE
        self.imagenet_scaling = _IMAGENET_SCALE_FACTOR_CAFFE
        logger.info("Loaded %d labels and %d images",
                    len(self.labels), len(self.img_locs))

    def __len__(self):
        return len(self.img_locs)

    def get_example(self, idx):
        """Return the preprocessed image and label stored at position *idx*."""
        im_file = self.img_locs[idx]
        # Close the file handle as soon as the pixels are decoded; convert()
        # reads the data into a new image that no longer needs the file.
        with Image.open(im_file) as im_src:
            im_rgb = im_src.convert('RGB')
        im = self._apply_data_preprocessing(im_rgb)
        label = self.labels[idx]
        if self.augmentation is not None:
            im = self._apply_data_augmentation(im)
        else:
            im = transforms.resize(im, size=(_HEIGHT, _WIDTH))
        return np.array(im, dtype=np.float32), np.array(label, dtype=np.int32)

    def _apply_data_preprocessing(self, rgb_im):
        """Convert a PIL RGB image to a normalised float32 (c, h, w) array."""
        im = np.asarray(rgb_im, dtype=np.float32)
        # (h, w, c) to (c, h, w)
        im = im.transpose(2, 0, 1)
        # Caffe normalisation
        im -= self.imagenet_mean[:, None, None]
        im *= self.imagenet_scaling
        return im

    def _apply_data_augmentation(self, im):
        """Random-crop to ``_HEIGHT`` x ``_WIDTH`` and randomly flip horizontally."""
        _, height, width = im.shape
        if height < _HEIGHT or width < _WIDTH:
            # random_crop rejects images smaller than the crop window, so
            # enlarge them just enough (aspect ratio preserved) to fit it.
            factor = max(_HEIGHT / float(height), _WIDTH / float(width))
            target = (max(_HEIGHT, int(round(height * factor))),
                      max(_WIDTH, int(round(width * factor))))
            im = transforms.resize(im, size=target)
        im = transforms.random_crop(im, size=(_HEIGHT, _WIDTH))
        # random_flip flips nothing unless a direction is enabled explicitly.
        im = transforms.random_flip(im, x_random=True)
        return im
class TestModeEvaluator(extensions.Evaluator):
    """Evaluator that switches the target model's ``train`` flag off while evaluating."""

    def evaluate(self):
        """Run the parent evaluation with ``model.train`` set to ``False``.

        Returns:
            Whatever ``extensions.Evaluator.evaluate`` returns.
        """
        model = self.get_target('main')
        model.train = False
        try:
            return super(TestModeEvaluator, self).evaluate()
        finally:
            # Switch back to training mode even when evaluation raises, so a
            # failed validation pass cannot leave the model stuck in test mode.
            model.train = True
def main():
    """Train and evaluate ResNet-50 across all workers using ChainerMN.

    Worker 0 reads the dataset listing from the directories named by the
    ``AZ_BATCHAI_INPUT_TRAIN`` and ``AZ_BATCHAI_INPUT_TEST`` environment
    variables; the data is then scattered to every worker. Reporting
    extensions are attached on worker 0 only.
    """
    # Prepare ChainerMN communicator; the rank inside the node is used as
    # the GPU device id.
    comm = chainermn.create_communicator(args.communicator)
    device = comm.intra_rank

    if comm.mpi_comm.rank == 0:
        rule = '=========================================='
        print(rule)
        print('Num process (COMM_WORLD): {}'.format(MPI.COMM_WORLD.Get_size()))
        print('Using {} communicator'.format(args.communicator))
        print('Num Minibatch-size: {}'.format(_BATCHSIZE))
        print('Num epoch: {}'.format(_EPOCHS))
        print(rule)

    model = resnet50.ResNet50()
    if device >= 0:
        chainer.cuda.get_device(device).use()
        model.to_gpu()

    # Create a multi node optimizer from a standard Chainer optimizer.
    base_optimizer = chainer.optimizers.MomentumSGD(lr=_LR, momentum=0.9)
    optimizer = chainermn.create_multi_node_optimizer(base_optimizer, comm)
    optimizer.setup(model)

    # Split and distribute the dataset. Only worker 0 loads the whole dataset;
    # it is then evenly split and distributed to all workers.
    train, val = None, None
    if comm.rank == 0:
        train_X, train_y, valid_X, valid_y = _create_data_fn(
            os.getenv('AZ_BATCHAI_INPUT_TRAIN'),
            os.getenv('AZ_BATCHAI_INPUT_TEST'))
        # Augmentation is left off: for now some size issue for random-crop.
        train = ImageNet(train_X, train_y)
        val = ImageNet(valid_X, valid_y)
    train = chainermn.scatter_dataset(train, comm, shuffle=True)
    val = chainermn.scatter_dataset(val, comm)

    # Batches are produced by worker processes.
    n_loader_processes = 24
    train_iter = chainer.iterators.MultiprocessIterator(
        train, _BATCHSIZE, n_processes=n_loader_processes)
    val_iter = chainer.iterators.MultiprocessIterator(
        val, _BATCHSIZE, repeat=False, n_processes=n_loader_processes)

    # Set up a trainer (no checkpointing for now).
    updater = training.StandardUpdater(train_iter, optimizer, device=device)
    trainer = training.Trainer(updater, (_EPOCHS, 'epoch'))
    val_interval = (1, 'epoch')
    log_interval = (1, 'epoch')

    # Create a multi node evaluator from an evaluator.
    evaluator = chainermn.create_multi_node_evaluator(
        TestModeEvaluator(val_iter, model, device=device), comm)
    trainer.extend(evaluator, trigger=val_interval)

    # Some display and output extensions are necessary only for one worker.
    # (Otherwise, there would just be repeated outputs.)
    if comm.rank == 0:
        report_columns = [
            'epoch', 'iteration', 'main/loss', 'validation/main/loss',
            'main/accuracy', 'validation/main/accuracy', 'elapsed_time'
        ]
        reporting = (
            (extensions.dump_graph('main/loss'), {}),
            (extensions.LogReport(trigger=log_interval), {}),
            (extensions.observe_lr(), {'trigger': log_interval}),
            (extensions.PrintReport(report_columns), {'trigger': log_interval}),
            (extensions.ProgressBar(update_interval=10), {}),
        )
        for extension, extend_kwargs in reporting:
            trainer.extend(extension, **extend_kwargs)

    trainer.run()


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment