MXNet vs. PyTorch benchmark, comparing imperative vs. symbolic and single vs. half precision
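The key toggle behind the imperative-vs-symbolic comparison is Gluon's hybridize(): a HybridSequential network executes eagerly by default, and calling hybridize() makes subsequent calls build and run a cached symbolic graph. A minimal sketch of that toggle (the layer width and input shape here are arbitrary, not the benchmark's):

from mxnet.gluon import nn
from mxnet import nd

net = nn.HybridSequential()
with net.name_scope():
    net.add(nn.Dense(4))
net.collect_params().initialize()

x = nd.ones((2, 1))
y_imperative = net(x)  # eager, op-by-op execution
net.hybridize()        # later calls compile to and run a cached symbolic graph
y_symbolic = net(x)

The script below benchmarks both modes at single and half precision, zero-initializing the weights in both frameworks so the two networks compute identical outputs.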
import torch
from torch import nn as ptnn
from torch.autograd import Variable
import mxnet as mx
from mxnet.gluon import nn as mxnn
from mxnet import nd, initializer
from enum import IntEnum
from time import time

use_cuda = torch.cuda.is_available()
# Run MXNet on the same device as PyTorch so the comparison is fair
ctx = mx.gpu(0) if use_cuda else mx.cpu()
fmt = ' {:<14} {:<15} {:<12} {:>5}'

class Framework(IntEnum):
    PYTORCH = 1
    MXNET = 2

def get_mxnet_network():
    net = mxnn.HybridSequential()
    with net.name_scope():
        net.add(mxnn.Dense(256, activation="relu"))
        net.add(mxnn.Dense(128, activation="relu"))
        net.add(mxnn.Dense(2))
    net.collect_params().initialize(init=initializer.Zero(), ctx=ctx)
    return net

def pytorch_weights_init(m):
    # Zero all weights and biases, matching the MXNet Zero initializer
    if isinstance(m, ptnn.Linear):
        ptnn.init.constant_(m.weight.data, 0)
        ptnn.init.constant_(m.bias.data, 0)

def get_pytorch_network():
    net = ptnn.Sequential()
    net.add_module('dense1', ptnn.Linear(1, 256))
    net.add_module('relu1', ptnn.ReLU())
    net.add_module('dense2', ptnn.Linear(256, 128))
    net.add_module('relu2', ptnn.ReLU())
    net.add_module('dense3', ptnn.Linear(128, 2))
    net.apply(pytorch_weights_init)
    return net

# Wait for pending computation to finish to make profiling more accurate
def block(framework):
    if framework == Framework.PYTORCH:
        if use_cuda:
            torch.cuda.synchronize()
    elif framework == Framework.MXNET:
        mx.nd.waitall()

def bench(net, x, framework):
    block(framework)
    start = time()
    for _ in range(1000):
        y = net(x)
    block(framework)
    return time() - start

def report(framework, paradigm, precision, value=None):
    # Print the time in milliseconds, or '---' when a configuration was skipped
    t = '%i' % (value * 1000) if value else '---'
    print(fmt.format(framework, paradigm, '%i bit' % precision, t))

# Input matrices
mx_x_32 = nd.ones((512, 1), ctx=ctx)
mx_x_16 = mx_x_32.astype('float16')
pt_x_32 = Variable(torch.ones((512, 1)))
if use_cuda:
    pt_x_32 = pt_x_32.cuda()
pt_x_16 = pt_x_32.half()

print()
print(' Device:', 'GPU' if use_cuda else 'CPU')
print('----------------------------------------------------')
print(fmt.format('Framework', 'Paradigm', 'Precision', 'Time'))
print('====================================================')

mx_net = get_mxnet_network()
report('MXNet', 'imperative', 32, bench(mx_net, mx_x_32, Framework.MXNET))
mx_net.cast('float16')
report('MXNet', 'imperative', 16, bench(mx_net, mx_x_16, Framework.MXNET))
mx_net.cast('float32')
# hybridize() switches Gluon from imperative to symbolic execution
mx_net.hybridize()
report('MXNet', 'symbolic', 32, bench(mx_net, mx_x_32, Framework.MXNET))
mx_net.cast('float16')
report('MXNet', 'symbolic', 16, bench(mx_net, mx_x_16, Framework.MXNET))

pt_net = get_pytorch_network()
if use_cuda:
    pt_net = pt_net.cuda()
report('PyTorch', 'imperative', 32, bench(pt_net, pt_x_32, Framework.PYTORCH))
# PyTorch half precision isn't supported on the CPU
pt_16 = bench(pt_net.half(), pt_x_16, Framework.PYTORCH) if use_cuda else None
report('PyTorch', 'imperative', 16, pt_16)
print('----------------------------------------------------')
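
For reference, the script prints a table shaped like the one below. The '<ms>' entries stand in for the measured millisecond timings; report() prints '---' for a skipped configuration, such as PyTorch half precision on the CPU:

 Device: CPU
----------------------------------------------------
 Framework      Paradigm        Precision     Time
====================================================
 MXNet          imperative      32 bit        <ms>
 MXNet          imperative      16 bit        <ms>
 MXNet          symbolic        32 bit        <ms>
 MXNet          symbolic        16 bit        <ms>
 PyTorch        imperative      32 bit        <ms>
 PyTorch        imperative      16 bit         ---
----------------------------------------------------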
Results on my MacBook Pro
Processor: 2.5 GHz Intel Core i7