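# Visualize how the empirical Fisher information matrix of a small
# linear-Gaussian regression model changes during training: each step
# renders the (normalized) matrix to an image frame, the frames are
# written to fim_change.gif, and the loss curve is saved to losses.png.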
import torch
import imageio
import numpy as np
import seaborn
import matplotlib.pyplot as plt
import matplotlib
torch.manual_seed(1)
# data generation: y = Xw + b with fixed (random) true parameters
n_data = 10
n_dim = 5
true_weight = torch.rand([n_dim, 1])
true_bias = torch.rand(1)
xs = (torch.rand([n_data, n_dim]) * 5).float()   # inputs in [0, 5)
ys = torch.mm(xs, true_weight) + true_bias       # noiseless targets
# forward pass: predictive mean and the induced Gaussian likelihood
std = 1.
def forward(x, w, b):
    x = x.view(-1, n_dim)
    hat_y = torch.mm(x, w) + b
    dist = torch.distributions.normal.Normal(hat_y, std)
    return hat_y, dist
def fisher_info_matrix(xs, ys, w, b):
    # empirical Fisher: average the outer product of per-sample score vectors
    empirical_fisher = []
    for i in range(n_data):
        x, y = xs[i], ys[i]
        hat_y, dist = forward(x, w, b)
        log_prob = dist.log_prob(y)
        grads = torch.autograd.grad(log_prob, [w, b])
        flat_grads = torch.cat([grad.view(-1) for grad in grads])
        flat_grads = flat_grads.unsqueeze(1)
        fisher = flat_grads * flat_grads.transpose(1, 0)   # outer product of the score
        empirical_fisher.append(fisher)
    empirical_fisher = sum(empirical_fisher) / n_data
    return empirical_fisher
def show_fim(fim, i):
    fig, ax = plt.subplots(figsize=(5, 5))
    ax.matshow(fim / torch.max(fim))   # normalize so the colour scale is comparable across steps
    ax.set_title('Iteration {}'.format(i))
    #ax.grid()
    # return the plot as an image array
    fig.canvas.draw()  # draw the canvas, cache the renderer
    image = np.frombuffer(fig.canvas.tostring_rgb(), dtype='uint8')
    image = image.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    plt.close(fig)     # avoid accumulating open figures across iterations
    return image
def plot_loss(losses):
    plt.plot(losses)
    plt.title('loss')
    plt.savefig('losses.png')
# initialize weights
init_weight = torch.normal(torch.zeros(true_weight.size()), torch.ones(true_weight.size()) * 0.1)
init_bias = torch.zeros(true_bias.size())
weight = init_weight
bias = init_bias
weight.requires_grad = True
bias.requires_grad = True
# optimize loop
n_steps = 50
optim = torch.optim.SGD([weight, bias], 0.01, momentum=0.9, weight_decay=0.1)
imgs = []
losses = []
for i in range(n_steps):
    # snapshot the empirical Fisher information at the current parameters
    fim = fisher_info_matrix(xs, ys, weight, bias)
    img = show_fim(fim, i)
    imgs.append(img)
    # one maximum-likelihood step: minimize the negative log-likelihood
    optim.zero_grad()
    hat_y, dist = forward(xs, weight, bias)
    log_prob = dist.log_prob(ys)
    neg_log_prob = -log_prob
    loss = neg_log_prob.mean()
    loss.backward()
    losses.append(loss.item())
    optim.step()
imageio.mimsave('./fim_change.gif', imgs, fps=10)
plot_loss(losses)
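# ----------------------------------------------------------------------
# Optional sanity check (not part of the original gist; a minimal sketch
# under the same model assumptions): for a Gaussian likelihood
# y ~ N(x^T w + b, std^2), the Fisher information w.r.t. (w, b) has the
# closed form (1/n) * sum_i z_i z_i^T / std^2 with z_i = [x_i; 1].
# The empirical outer-product estimate above matches this in expectation
# when the targets are actually drawn from the model, so averaging it
# over many resampled targets at the true parameters should land close
# to the analytic matrix. The names below (analytic_fim, avg_fim,
# n_resamples) are illustrative only.
zs = torch.cat([xs, torch.ones(n_data, 1)], dim=1)            # z_i = [x_i; 1]
analytic_fim = torch.mm(zs.t(), zs) / (n_data * std ** 2)
true_weight.requires_grad_(True)
true_bias.requires_grad_(True)
avg_fim = torch.zeros_like(analytic_fim)
n_resamples = 200
for _ in range(n_resamples):
    with torch.no_grad():
        # draw targets from the assumed model so the estimator is unbiased
        noisy_ys = torch.mm(xs, true_weight) + true_bias + std * torch.randn(n_data, 1)
    avg_fim += fisher_info_matrix(xs, noisy_ys, true_weight, true_bias)
avg_fim /= n_resamples
print('max |avg empirical - analytic| Fisher entry:',
      (avg_fim - analytic_fim).abs().max().item())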