Trains an MNIST classifier using Synthetic Gradients. See Google DeepMind paper @ arxiv.org/abs/1608.05343.
""" Trains an MNIST classifier using Synthetic Gradients. See Google DeepMind paper @ arxiv.org/abs/1608.05343. """ | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import matplotlib.cm as cm | |
from tensorflow.examples.tutorials.mnist import input_data # just use tensorflow's mnist api | |
mnist = input_data.read_data_sets('MNIST_data', one_hot=False) | |
# hyperparameters
global_step = 0
batch_size = 10
lr, slr = 2e-2, 1e-4 # learning rates for the model and the synthetic gradient model
D = 28*28 # input dimensionality
h1_size, h2_size = 100, 10 # hidden layer size and output layer size (10 classes)
synth_step = 10 # ratio of model updates to synthetic gradient updates
# initialize mnist model
model = {}
model['W1'] = np.random.randn(D,h1_size) / np.sqrt(h1_size) # Xavier initialization: paper @ goo.gl/HYyVa7
model['b1'] = np.random.randn(1,h1_size) / np.sqrt(h1_size)
model['W2'] = np.random.randn(h1_size,h2_size) / np.sqrt(h2_size)
model['b2'] = np.random.randn(1,h2_size) / np.sqrt(h2_size)

# model for predicting synthetic gradients: each parameter k gets a linear map from that
# layer's activations to a flattened gradient of shape np.prod(model[k].shape)
smodel = { k : np.random.randn(v.shape[1], np.prod(v.shape)) * 1e-2 / np.sqrt(np.prod(v.shape)) \
           for k,v in model.iteritems() } # the 1e-2 scale factor helps...synthetic grads are weird
def softmax(x):
    maxes = np.amax(x, axis=1, keepdims=True)
    e = np.exp(x - maxes) # improves numerics
    return e / np.sum(e, axis=1, keepdims=True)
def forward(X, model):
    # evaluate class scores, [N x K]
    hs = [] # we'll need the h's for computing gradients
    h1 = np.dot(X, model['W1']) + model['b1'] ; hs.append(h1)
    h1[h1<0] = 0 # relu activation (in place, so hs[0] holds the post-relu values)
    h2 = np.dot(h1, model['W2']) + model['b2'] ; hs.append(h2)
    h2[h2<0] = 0 # relu activation
    probs = softmax(h2)
    return probs, hs
def backward(X, y, probs, hs, model):
    grads = { k : np.zeros_like(v) for k,v in model.iteritems() } # dict to hold gradients
    probs[range(batch_size),y] -= 1 # backprop through softmax
    dh2 = probs / batch_size
    # second hidden layer
    grads['W2'] = np.dot(hs[0].T, dh2)
    grads['b2'] = np.sum(dh2, axis=0, keepdims=True)
    # first hidden layer
    dh1 = np.dot(dh2, model['W2'].T)
    dh1[hs[0] <= 0] = 0 # backprop through relu
    grads['W1'] = np.dot(X.T, dh1)
    grads['b1'] = np.sum(dh1, axis=0, keepdims=True)
    return grads
# forward propagate synthetic gradient model
def sforward(hs, smodel, model):
    synthetic_grads = { k : np.zeros_like(v) for k,v in smodel.iteritems() }
    synthetic_grads['W1'] = np.dot(hs[0], smodel['W1'])
    synthetic_grads['b1'] = np.dot(hs[0], smodel['b1'])
    synthetic_grads['W2'] = np.dot(hs[1], smodel['W2'])
    synthetic_grads['b2'] = np.dot(hs[1], smodel['b2'])
    return synthetic_grads
# backward propagate synthetic gradient model
def sbackward(hs, ds, smodel):
    sgrads = { k : np.zeros_like(v) for k,v in smodel.iteritems() }
    sgrads['W2'] = np.dot(hs[1].T, ds['W2'])
    sgrads['b2'] = np.dot(hs[1].T, ds['b2'])
    sgrads['W1'] = np.dot(hs[0].T, ds['W1'])
    sgrads['b1'] = np.dot(hs[0].T, ds['b1'])
    return sgrads
# evaluate test set accuracy
def eval_model(model):
    X = mnist.test.images
    y = mnist.test.labels
    hidden_layer = np.maximum(0, np.dot(X, model['W1']) + model['b1'])
    scores = np.dot(hidden_layer, model['W2']) + model['b2']
    predicted_class = np.argmax(scores, axis=1)
    return np.mean(predicted_class == y), predicted_class
loss_history = []
smoothing_factor = 0.95
for i in xrange(global_step, 1000):
    X, y = mnist.train.next_batch(batch_size)
    probs, hs = forward(X, model)
    synthetic_grads = sforward(hs, smodel, model) # compute synthetic gradients

    # synthetic gradient model updates
    if i % synth_step == 0:
        # compute the mnist model loss
        y_logprobs = -np.log(probs[range(batch_size),y])
        loss = np.sum(y_logprobs)/batch_size
        if i == 0: smooth_loss = loss
        grads = backward(X, y, probs, hs, model) # data gradients
        # compute the synthetic gradient loss
        ds = {k : -v.ravel() + synthetic_grads[k] for (k,v) in grads.iteritems()}
        sloss = np.sum([np.sum(slr*v*v) for v in ds.itervalues()])/batch_size
        sgrads = sbackward(hs, ds, smodel)
        smodel = {k : smodel[k] - slr*sgrads[k] for k in sgrads} # update smodel parameters
        smooth_loss = smoothing_factor*smooth_loss + (1-smoothing_factor)*loss
        loss_history.append((i,smooth_loss))

    # estimate gradients using synthetic gradient model
    est_grad = {k : np.reshape(np.sum(v, axis=0), model[k].shape)/batch_size for k,v in synthetic_grads.iteritems()}
    model = {k : model[k] - lr*v for (k,v) in est_grad.iteritems()} # update model using estimated gradient

    # boring book-keeping
    if (i+1) % 500 == 0: print "iteration {}: test accuracy {:.3f}".format(i, eval_model(model)[0])
    if i % 100 == 0: print "\titeration {}: smooth_loss {:.3f}, synth_loss {:.3f}".format(i, smooth_loss, sloss)
    global_step += 1
# plot W2 gradients: actual (backprop) vs. synthetic (estimated) side by side
plt.figure(0, figsize=(12,16))
plt.subplot(121) ; plt.title("Actual gradient W2")
plt.imshow(grads['W2'].T[:,:50], cmap=cm.gray)
plt.subplot(122) ; plt.title("Synthetic gradient W2")
plt.imshow(est_grad['W2'].T[:,:50], cmap=cm.gray)
plt.show()

# plot loss over time
plt.figure(1)
plt.title("Smoothed loss") ; plt.xlabel("training steps") ; plt.ylabel("loss")
train_steps, smoothed_losses = zip(*loss_history)
plt.plot(train_steps, smoothed_losses)
plt.show()
# sample from test set with model
acc, predicted_class = eval_model(model)
X = mnist.test.images ; y = mnist.test.labels
r = np.random.randint(1000)
for i in range(0 + r, 6 + r, 3):
    img1 = np.reshape(X[i+0,:], (28,28))
    img2 = np.reshape(X[i+1,:], (28,28))
    img3 = np.reshape(X[i+2,:], (28,28))
    plt.figure(2+i, figsize=(5,2))
    plt.subplot(131) ; plt.title("predicted: " + str(predicted_class[i]))
    plt.imshow(img1, cmap=cm.gray)
    plt.subplot(132) ; plt.title("predicted: " + str(predicted_class[i+1]))
    plt.imshow(img2, cmap=cm.gray)
    plt.subplot(133) ; plt.title("predicted: " + str(predicted_class[i+2]))
    plt.imshow(img3, cmap=cm.gray)
    plt.show()
Thanks a ton for the code.
Has the bug been fixed yet?
EDIT: I made a mistake in my implementation of synthetic gradients. Runtime is not the issue I imagined it was, because the synthetic gradient models should target the gradients of each layer's output activations rather than the actual gradients on the weights. I'm in the process of fixing this in my code. For now, just know that there's a bug in this implementation.
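For reference, here is a minimal sketch of what that corrected target would look like, assuming a single hidden layer and a linear synthetic gradient module as in the DeepMind paper: the module maps a layer's activations to a predicted dL/dh of shape (batch, layer_size), and the layer turns that prediction into weight updates with its own local backward pass. All names below (sg_W, predict_dh1, local_update, train_sg_module) are illustrative and not part of the gist above.

import numpy as np

# illustrative sketch, not the gist's code: a linear synthetic-gradient module
# for the first hidden layer predicts dL/dh1 (shape [batch, h1_size]) from h1,
# instead of predicting a flattened gradient over the W1 matrix
h1_size = 100
sg_W = np.random.randn(h1_size, h1_size) * 1e-2  # synthetic gradient module parameters
sg_b = np.zeros((1, h1_size))

def predict_dh1(h1):
    # estimated gradient of the loss w.r.t. the layer's (post-relu) activations
    return np.dot(h1, sg_W) + sg_b

def local_update(X, h1, model, lr):
    # use the predicted dL/dh1 to update W1/b1 with the layer's own backward pass,
    # without waiting for the true gradient from the layers above
    dh1 = predict_dh1(h1)
    dh1 = dh1 * (h1 > 0)  # backprop through the relu
    model['W1'] -= lr * np.dot(X.T, dh1)
    model['b1'] -= lr * np.sum(dh1, axis=0, keepdims=True)

def train_sg_module(h1, true_dh1, slr):
    # regress the module toward the true dL/dh1 whenever backprop provides it
    global sg_W, sg_b
    err = predict_dh1(h1) - true_dh1  # gradient of 0.5*||pred - true||^2 w.r.t. pred
    sg_W -= slr * np.dot(h1.T, err)
    sg_b -= slr * np.sum(err, axis=0, keepdims=True)

The point of this formulation is that the synthetic gradient output stays the size of the layer's activations rather than the size of the weight matrix, which is the mismatch the EDIT above describes.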