Bayesian Dark Knowledge: A modification of https://github.com/google/jax/blob/master/examples/mnist_classifier.py implementing https://arxiv.org/abs/1506.04416
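In outline: a teacher network is trained with stochastic gradient Langevin dynamics (SGLD), which turns SGD into an approximate posterior sampler by injecting Gaussian noise into every step, and a student network is trained with plain SGD to match the teacher's predictive distribution. As a sketch, the update rule implemented by the `sgld` optimizer below is the SGLD update of Welling & Teh (2011):

    \theta_{t+1} = \theta_t - \tfrac{\epsilon_t}{2}\,\nabla_\theta L(\theta_t) + \eta_t, \qquad \eta_t \sim \mathcal{N}(0,\, \epsilon_t I)

that is, half a gradient step plus zero-mean Gaussian noise whose variance equals the step size \epsilon_t.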
import itertools
import time

import numpy.random as npr

import jax
import jax.numpy as np
from jax import jit, grad, random
from jax.experimental import stax
from jax.experimental.stax import Dense, Relu, LogSoftmax
from jax.experimental.optimizers import (optimizer, sgd, make_schedule,
                                         exponential_decay, l2_norm)

from datasets import get_mnist
@optimizer
def sgld(step_size):
  """Construct optimizer triple for stochastic gradient Langevin dynamics.

  Args:
    step_size: positive scalar, or a callable representing a step size schedule
      that maps the iteration index to a positive scalar.

  Returns:
    An (init_fun, update_fun, get_params) triple.
  """
  step_size = make_schedule(step_size)
  def init(x0):
    return x0
  def update(i, g, x):
    # Half a gradient step plus the Langevin term: Gaussian noise drawn
    # independently for every parameter entry, with variance equal to the
    # step size.
    noise = random.normal(random.PRNGKey(i), np.shape(x))
    return x - step_size(i) / 2 * g + np.sqrt(step_size(i)) * noise
  def get_params(x):
    return x
  return init, update, get_params
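
# A minimal sketch of exercising the sgld triple on a toy quadratic
# (hypothetical usage, not part of the original gist):
#   toy_init, toy_update, toy_get_params = sgld(1e-2)
#   state = toy_init(np.zeros(3))
#   for t in range(100):
#     # The gradient of 0.5 * ||x||^2 is x itself, so the iterate performs
#     # noisy gradient descent and hovers around the origin.
#     state = toy_update(t, toy_get_params(state), state)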
# Classification loss (cross entropy)
def loss(params, batch, regularize=True):
  inputs, targets = batch
  preds = predict(params, inputs)
  loss_ = -np.mean(np.sum(preds * targets, axis=1))
  if regularize: loss_ += l2_norm(params)
  return loss_
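
# A tiny sanity check of the loss (hypothetical values, not in the original):
# predict returns log-probabilities, so the product with one-hot targets picks
# out the log-likelihood of the true class.
#   logp = np.log(np.array([[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]]))
#   onehot = np.array([[1., 0., 0.], [0., 1., 0.]])
#   -np.mean(np.sum(logp * onehot, axis=1))  # mean of -log(0.7) and -log(0.8)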
# Student classification loss (cross entropy with the teacher's predictions)
def student_loss(params, batch, params_teacher, regularize=False):
  inputs, _ = batch  # Ignore the targets!
  teacher_preds = predict(params_teacher, inputs)    # log-probabilities
  student_preds = student_predict(params, inputs)    # log-probabilities
  # Cross entropy H(p, q) = -sum_k p(k) * log q(k), so the teacher's
  # log-probabilities must be exponentiated first.
  loss_ = -np.mean(np.sum(np.exp(teacher_preds) * student_preds, axis=1))
  if regularize: loss_ += l2_norm(params)
  return loss_
def accuracy(params, batch):
  inputs, targets = batch
  target_class = np.argmax(targets, axis=1)
  predicted_class = np.argmax(predict(params, inputs), axis=1)
  return np.mean(predicted_class == target_class)

def accuracy_student(params, batch):
  "We don't care how well the student is predicting the teacher for now, just the real thing!"
  inputs, targets = batch
  target_class = np.argmax(targets, axis=1)
  predicted_class = np.argmax(student_predict(params, inputs), axis=1)
  return np.mean(predicted_class == target_class)
def accuracy_student_teacher(params, params_student, batch):
  "Fraction of examples on which the student and the teacher predict the same class."
  inputs, _ = batch
  predicted_class_student = np.argmax(student_predict(params_student, inputs), axis=1)
  predicted_class_teacher = np.argmax(predict(params, inputs), axis=1)
  return np.mean(predicted_class_student == predicted_class_teacher)
init_random_params, predict = stax.serial(
    Dense(400), Relu,
    Dense(400), Relu,
    Dense(10), LogSoftmax)

init_random_params_student, student_predict = stax.serial(
    Dense(400), Relu,
    Dense(400), Relu,
    Dense(10), LogSoftmax)
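
# A quick shape sketch (hypothetical usage, not part of the original gist):
# both networks are 784 -> 400 -> 400 -> 10 MLPs whose apply functions return
# per-class log-probabilities, e.g.
#   _, demo_params = init_random_params(random.PRNGKey(1), (-1, 28 * 28))
#   predict(demo_params, np.ones((5, 28 * 28))).shape  # -> (5, 10)
# Here the student shares the teacher's architecture; the BDK paper also
# allows the student to differ from the teacher.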
if __name__ == "__main__":
  # An exponentially decaying schedule also works, e.g.
  # step_size = exponential_decay(0.001, 10, 0.99)
  step_size = 5 * 10**-6  # constant step size from the BDK paper

  def key_function(key):  # (currently unused) infinite stream of fresh PRNG keys
    while True:
      key, subkey = jax.random.split(key)
      yield key

  num_epochs = 2000
  batch_size = 100

  train_images, train_labels, test_images, test_labels = get_mnist()
  num_train = train_images.shape[0]
  num_complete_batches, leftover = divmod(num_train, batch_size)
  num_batches = num_complete_batches + bool(leftover)
  def data_stream():
    rng = npr.RandomState(0)
    while True:
      perm = rng.permutation(num_train)
      for i in range(num_batches):
        batch_idx = perm[i * batch_size:(i + 1) * batch_size]
        yield train_images[batch_idx], train_labels[batch_idx]
  batches = data_stream()
  key = random.PRNGKey(0)

  opt_init, opt_update, get_params = sgld(step_size)
  opt_init_student, opt_update_student, get_params_student = sgd(step_size)
  @jit
  def update(i, opt_state, batch):
    params = get_params(opt_state)
    return opt_update(i, grad(loss)(params, batch), opt_state)

  @jit
  def update_student(i, opt_state_student, batch, params_teacher):
    params_student = get_params_student(opt_state_student)
    # Differentiate w.r.t. the student's own parameters, not the teacher's.
    return opt_update_student(i, grad(student_loss)(params_student, batch, params_teacher),
                              opt_state_student)
  _, init_params = init_random_params(key, (-1, 28 * 28))
  opt_state = opt_init(init_params)
  itercount = itertools.count()

  _, init_student_params = init_random_params_student(key, (-1, 28 * 28))
  opt_state_student = opt_init_student(init_student_params)
  itercount_student = itertools.count()
print("\nStarting training...") | |
for epoch in range(num_epochs): | |
start_time = time.time() | |
for _ in range(num_batches): | |
opt_state = update(next(itercount), opt_state, next(batches)) | |
params = get_params(opt_state) # Teacher params | |
params_teacher = params | |
for _ in range(num_batches): | |
opt_state_student = update_student(next(itercount_student), opt_state_student, next(batches), params_teacher) | |
params_student = get_params_student(opt_state_student) | |
epoch_time = time.time() - start_time | |
train_acc = accuracy(params, (train_images, train_labels)) | |
test_acc = accuracy(params, (test_images, test_labels)) | |
train_acc_student = accuracy_student(params_student, (train_images, train_labels)) | |
test_acc_student = accuracy_student(params_student, (test_images, test_labels)) | |
train_acc_student_teacher = accuracy_student_teacher(params, params_student, (train_images, train_labels)) | |
test_acc_student_teacher = accuracy_student_teacher(params, params_student, (test_images, test_labels)) | |
print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time)) | |
print("Training set accuracy - teacher {}".format(train_acc)) | |
print("Test set accuracy {} - teacher".format(test_acc)) | |
print("Training set accuracy - student {}".format(train_acc_student)) | |
print("Test set accuracy {} - student".format(test_acc_student)) | |
print("Training set accuracy - student == teacher {}".format(train_acc_student_teacher)) | |
print("Test set accuracy {} - student == teacher".format(test_acc_student_teacher)) | |