Last active
March 21, 2019 05:17
-
-
Save andrewliao11/ec4967d8f4570d6192756b17dd9665b8 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import imageio | |
import numpy as np | |
import seaborn | |
import matplotlib.pyplot as plt | |
import matplotlib | |
# Synthetic linear-regression data, y = x @ W + b. The seed is fixed so the
# generated problem (and therefore every run of the script) is reproducible.
torch.manual_seed(1)
n_data, n_dim = 10, 5
true_weight = torch.rand([n_dim, 1])
true_bias = torch.rand(1)
# torch.rand samples U[0, 1), so the inputs lie in [0, 5).
xs = 5 * torch.rand([n_data, n_dim])
ys = xs.mm(true_weight) + true_bias
# Forward pass: a linear model whose output is the mean of a Gaussian
# likelihood with this fixed standard deviation.
std = 1.


def forward(x, w, b):
    """Predict with the linear model and wrap the prediction in a Normal.

    Args:
        x: input tensor; viewed as rows of ``n_dim`` features (a single
            sample therefore becomes a 1-row batch).
        w: weight matrix multiplied on the right of the input rows
            (presumably ``(n_dim, 1)`` in this script -- confirm at callers).
        b: bias, added (with broadcasting) to the prediction.

    Returns:
        ``(hat_y, dist)`` where ``hat_y = x @ w + b`` and ``dist`` is
        ``Normal(hat_y, std)`` using the module-level ``std``.
    """
    rows = x.view(-1, n_dim)
    prediction = rows.mm(w) + b
    return prediction, torch.distributions.Normal(prediction, std)
def fisher_info_matrix(xs, ys, w, b):
    """Return the empirical Fisher information matrix of the model at (w, b).

    For each sample (x_i, y_i), the gradient g_i of log p(y_i | x_i; w, b)
    with respect to the flattened parameter vector [w, b] is computed. The
    outer products are then averaged: F = (1/N) * sum_i g_i g_i^T.

    Args:
        xs: inputs, one sample per entry along the first dimension.
        ys: targets, aligned with ``xs`` along the first dimension.
        w: weight tensor; must have ``requires_grad=True``.
        b: bias tensor; must have ``requires_grad=True``.

    Returns:
        A square tensor of side ``w.numel() + b.numel()``.

    Raises:
        ValueError: if there are no samples, or ``xs`` and ``ys`` have a
            different number of samples.
    """
    # Use the real sample count rather than a global, so the function works
    # for any dataset size and never silently skips or over-indexes rows.
    n_samples = len(xs)
    if n_samples == 0:
        raise ValueError("fisher_info_matrix needs at least one sample")
    if len(ys) != n_samples:
        raise ValueError(
            "xs and ys must have the same number of samples "
            "({} != {})".format(n_samples, len(ys)))

    outer_products = []
    for x, y in zip(xs, ys):
        _, dist = forward(x, w, b)
        # Reduce to a true scalar so autograd.grad needs no grad_outputs.
        log_prob = dist.log_prob(y).sum()
        # autograd.grad returns fresh gradients without touching w.grad/b.grad,
        # so this does not interfere with the optimizer's backward pass.
        grads = torch.autograd.grad(log_prob, [w, b])
        flat_grads = torch.cat([grad.reshape(-1) for grad in grads])
        # Outer product g g^T via broadcasting a column against a row.
        outer_products.append(flat_grads[:, None] * flat_grads[None, :])

    empirical_fisher = torch.stack(outer_products).mean(dim=0)
    return empirical_fisher
def show_fim(fim, i):
    """Render a Fisher information matrix as an RGB image array.

    The matrix is scaled by its maximum (when non-zero) and drawn with
    ``matshow`` on an off-screen 5x5-inch figure titled with the iteration.

    Args:
        fim: square matrix (torch tensor or array-like) to visualise.
        i: iteration number shown in the title.

    Returns:
        ``uint8`` numpy array of shape ``(height, width, 3)``, suitable as
        one frame for ``imageio.mimsave``.
    """
    if torch.is_tensor(fim):
        values = fim.detach().cpu().numpy()
    else:
        values = np.asarray(fim)
    peak = values.max()
    # Dividing by a zero maximum would turn the whole frame into NaNs.
    if peak != 0:
        values = values / peak

    fig, ax = plt.subplots(figsize=(5, 5))
    try:
        ax.matshow(values)
        ax.set_title('Iteration {}'.format(i))
        fig.canvas.draw()  # render so the pixel buffer is populated
        # buffer_rgba() replaces tostring_rgb() (removed in matplotlib 3.10)
        # and reports the true pixel size, which also holds on HiDPI canvases.
        # Copy before closing: the buffer belongs to the figure's renderer.
        image = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
    finally:
        # Close each frame's figure; otherwise every iteration leaks one.
        plt.close(fig)
    return image
def plot_loss(losses, path='losses.png'):
    """Plot the per-step training loss and save it as an image.

    A dedicated figure is created and closed here. Using the implicit pyplot
    "current figure" would draw onto whatever figure is still open (e.g. a
    leftover FIM frame) instead of a clean canvas.

    Args:
        losses: sequence of scalar loss values, one per optimisation step.
        path: output image file; defaults to ``'losses.png'``.
    """
    fig, ax = plt.subplots()
    try:
        ax.plot(losses)
        ax.set_title('loss')
        fig.savefig(path)
    finally:
        plt.close(fig)
# Initialise the trainable parameters: Gaussian weights (mean 0, std 0.1)
# shaped like the true weights, and a zero bias shaped like the true bias.
init_weight = torch.normal(
    torch.zeros(true_weight.size()),
    0.1 * torch.ones(true_weight.size()),
)
init_bias = torch.zeros(true_bias.size())
# The trained parameters are the same tensor objects as the initial ones.
weight, bias = init_weight, init_bias
weight.requires_grad_(True)
bias.requires_grad_(True)
# Optimisation loop: before each step, snapshot the empirical FIM as a GIF
# frame. Then minimise the mean negative log-likelihood with SGD
# (lr 0.01, momentum 0.9, weight decay 0.1).
n_steps = 50
optim = torch.optim.SGD([weight, bias], 0.01, momentum=0.9, weight_decay=0.1)
imgs = []
losses = []
for i in range(n_steps):
    fim = fisher_info_matrix(xs, ys, weight, bias)
    # show_fim takes the step index for the frame title; calling it without
    # `i` raises TypeError.
    img = show_fim(fim, i)
    imgs.append(img)

    optim.zero_grad()
    hat_y, dist = forward(xs, weight, bias)
    log_prob = dist.log_prob(ys)
    neg_log_prob = -log_prob
    loss = neg_log_prob.mean()
    loss.backward()
    losses.append(loss.item())
    optim.step()

imageio.mimsave('./fim_change.gif', imgs, fps=10)
plot_loss(losses)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment