import itertools
import time

import numpy.random as npr

import jax
import jax.numpy as np
from jax import jit, grad, random
from jax.experimental import stax
from jax.experimental.optimizers import optimizer, sgd, make_schedule, exponential_decay, l2_norm
from jax.experimental.stax import Dense, Relu, LogSoftmax

from datasets import get_mnist
@optimizer
def sgld(step_size):
    """Construct optimizer triple for stochastic gradient Langevin dynamics (SGLD).

    Args:
      step_size: positive scalar, or a callable representing a step size schedule
        that maps the iteration index to a positive scalar.

    Returns:
      An (init_fun, update_fun, get_params) triple.
    """
    step_size = make_schedule(step_size)
    def init(x0):
        return x0
    def update(i, g, x):
        # SGLD step: gradient half-step plus per-parameter Gaussian noise with
        # variance step_size(i). Folding the iteration index into the PRNG key
        # keeps the update a pure function of (i, g, x).
        noise = random.normal(random.PRNGKey(i), x.shape)
        return x - step_size(i) / 2 * g + np.sqrt(step_size(i)) * noise
    def get_params(x):
        return x
    return init, update, get_params
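# A minimal sketch (not part of the original gist; names are illustrative) of the
# (init_fun, update_fun, get_params) API that sgld returns, exercised on a toy
# quadratic objective. The training loop in __main__ below uses the same pattern.
def _sgld_demo(num_steps=10):
    def toy_loss(x):
        return np.sum(x ** 2)
    toy_init, toy_update, toy_get_params = sgld(1e-3)
    state = toy_init(np.ones(3))
    for i in range(num_steps):
        state = toy_update(i, grad(toy_loss)(toy_get_params(state)), state)
    return toy_get_params(state)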
# Classification loss (cross entropy; predict returns log-probabilities)
def loss(params, batch, regularize=True):
    inputs, targets = batch
    preds = predict(params, inputs)
    loss_ = -np.mean(np.sum(preds * targets, axis=1))
    if regularize: loss_ += l2_norm(params)
    return loss_
# Student classification loss (cross entropy against the teacher's predictions)
def student_loss(params, batch, params_teacher, regularize=False):
    inputs, _ = batch  # Ignore the targets!
    # Both networks end in LogSoftmax, so exponentiate the teacher's
    # log-probabilities to get the target distribution for the cross entropy.
    teacher_probs = np.exp(predict(params_teacher, inputs))
    student_preds = student_predict(params, inputs)
    loss_ = -np.mean(np.sum(teacher_probs * student_preds, axis=1))
    if regularize: loss_ += l2_norm(params)
    return loss_
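# Worked example (illustrative, not from the original gist): with a 3-class teacher
# distribution p = [0.7, 0.2, 0.1] and student log-probabilities
# q_log = log([0.6, 0.3, 0.1]), the per-example term above is
#   -sum(p * q_log) = -(0.7*log(0.6) + 0.2*log(0.3) + 0.1*log(0.1)) ≈ 0.83,
# i.e. the cross entropy between the teacher's and the student's predicted distributions.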
def accuracy(params, batch):
    inputs, targets = batch
    target_class = np.argmax(targets, axis=1)
    predicted_class = np.argmax(predict(params, inputs), axis=1)
    return np.mean(predicted_class == target_class)
def accuracy_student(params, batch):
    "We don't care how well the student is predicting the teacher for now, just the real thing!"
    inputs, targets = batch
    target_class = np.argmax(targets, axis=1)
    predicted_class = np.argmax(student_predict(params, inputs), axis=1)
    return np.mean(predicted_class == target_class)
def accuracy_student_teacher(params, params_student, batch):
    "Agreement between the student's and the teacher's predicted classes."
    inputs, _ = batch
    predicted_class_student = np.argmax(student_predict(params_student, inputs), axis=1)
    predicted_class_teacher = np.argmax(predict(params, inputs), axis=1)
    return np.mean(predicted_class_student == predicted_class_teacher)
# Teacher network
init_random_params, predict = stax.serial(
    Dense(400), Relu,
    Dense(400), Relu,
    Dense(10), LogSoftmax)

# Student network (same architecture)
init_random_params_student, student_predict = stax.serial(
    Dense(400), Relu,
    Dense(400), Relu,
    Dense(10), LogSoftmax)
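# A minimal sketch (assumption, not from the original gist; names are illustrative)
# of the stax.serial API used above: it returns an (init_fun, apply_fun) pair.
# init_fun takes a PRNG key and an input shape and returns (output_shape, params);
# apply_fun maps (params, inputs) to the network outputs (here, log-probabilities).
def _stax_demo():
    demo_key = random.PRNGKey(42)
    out_shape, demo_params = init_random_params(demo_key, (-1, 28 * 28))  # out_shape == (-1, 10)
    dummy_inputs = np.zeros((2, 28 * 28))
    log_probs = predict(demo_params, dummy_inputs)  # shape (2, 10); rows sum to 1 after np.exp
    return out_shape, log_probs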
if __name__ == "__main__":
    # Constant step size from the BDK paper; an exponential-decay schedule such as
    # exponential_decay(0.001, 10, 0.99) is an alternative.
    step_size = 5 * 10 ** -6

    def key_function(key):
        # Unused helper: an infinite stream of fresh PRNG keys.
        while True:
            key, subkey = jax.random.split(key)
            yield key

    num_epochs = 2000
    batch_size = 100

    train_images, train_labels, test_images, test_labels = get_mnist()
    num_train = train_images.shape[0]
    num_complete_batches, leftover = divmod(num_train, batch_size)
    num_batches = num_complete_batches + bool(leftover)

    def data_stream():
        rng = npr.RandomState(0)
        while True:
            perm = rng.permutation(num_train)
            for i in range(num_batches):
                batch_idx = perm[i * batch_size:(i + 1) * batch_size]
                yield train_images[batch_idx], train_labels[batch_idx]

    batches = data_stream()

    key = random.PRNGKey(0)
    opt_init, opt_update, get_params = sgld(step_size)                          # teacher: SGLD
    opt_init_student, opt_update_student, get_params_student = sgd(step_size)   # student: plain SGD
    @jit
    def update(i, opt_state, batch):
        params = get_params(opt_state)
        return opt_update(i, grad(loss)(params, batch), opt_state)

    @jit
    def update_student(i, opt_state_student, batch, params_teacher):
        params_student = get_params_student(opt_state_student)
        return opt_update_student(i, grad(student_loss)(params_student, batch, params_teacher), opt_state_student)
    _, init_params = init_random_params(key, (-1, 28 * 28))
    opt_state = opt_init(init_params)
    itercount = itertools.count()

    # Use a different key for the student so it does not start with the teacher's
    # exact initial weights.
    key, student_key = random.split(key)
    _, init_student_params = init_random_params_student(student_key, (-1, 28 * 28))
    opt_state_student = opt_init_student(init_student_params)
    itercount_student = itertools.count()
print("\nStarting training...")
for epoch in range(num_epochs):
start_time = time.time()
for _ in range(num_batches):
opt_state = update(next(itercount), opt_state, next(batches))
params = get_params(opt_state) # Teacher params
params_teacher = params
for _ in range(num_batches):
opt_state_student = update_student(next(itercount_student), opt_state_student, next(batches), params_teacher)
params_student = get_params_student(opt_state_student)
epoch_time = time.time() - start_time
train_acc = accuracy(params, (train_images, train_labels))
test_acc = accuracy(params, (test_images, test_labels))
train_acc_student = accuracy_student(params_student, (train_images, train_labels))
test_acc_student = accuracy_student(params_student, (test_images, test_labels))
train_acc_student_teacher = accuracy_student_teacher(params, params_student, (train_images, train_labels))
test_acc_student_teacher = accuracy_student_teacher(params, params_student, (test_images, test_labels))
print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
print("Training set accuracy - teacher {}".format(train_acc))
print("Test set accuracy {} - teacher".format(test_acc))
print("Training set accuracy - student {}".format(train_acc_student))
print("Test set accuracy {} - student".format(test_acc_student))
print("Training set accuracy - student == teacher {}".format(train_acc_student_teacher))
print("Test set accuracy {} - student == teacher".format(test_acc_student_teacher))