Example of computing Hessian of linear layer
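The check below builds a single bias-free linear layer, captures its input activations and output backprops with forward and backward hooks, and verifies two identities for the mean-squared-error loss: the weight gradient equals B Aᵀ / n, and the autograd Hessian of the loss with respect to W matches Jᵀ J / n, where J is the per-example output Jacobian assembled as a transposed Khatri-Rao product of activations and backprops.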
import random
from typing import List

import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torchvision import datasets

# Assumed setup (not shown in the original snippet): `u` is the author's utility
# module -- the assert messages below refer to util.zero_grad -- and the relevant
# helpers are reproduced in the utils section at the bottom of this file.
import util as u

# Assumed here: the original code defines `device` elsewhere.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def test():
    u.seed_random(1)

    data_width = 3
    targets_width = 2
    batch_size = 3
    dataset = TinyMNIST('/tmp', download=True, data_width=data_width, targets_width=targets_width, dataset_size=batch_size)
    trainloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)

    d1 = data_width ** 2      # visible (input) size
    d2 = targets_width ** 2   # output size
    n = batch_size
    model = Net([d1, d2])
    layer = model.layers[0]
    W = model.layers[0].weight

    skip_hooks = False

    def capture_activations(module, input, _output):
        if skip_hooks:
            return
        assert not hasattr(module, 'activations'), "Seeing results of previous autograd, call util.zero_grad to clear"
        assert len(input) == 1, "this works for single input layers only"
        setattr(module, "activations", input[0].detach())

    def capture_backprops(module: nn.Module, _input, output):
        if skip_hooks:
            return
        assert not hasattr(module, 'backprops'), "Seeing results of previous autograd, call util.zero_grad to clear"
        assert len(output) == 1, "this works for single variable layers only"
        setattr(module, "backprops", output[0])

    layer.register_forward_hook(capture_activations)
    layer.register_backward_hook(capture_backprops)

    def loss_fn(data, targets):
        err = data - targets.view(-1, data.shape[1])
        assert len(data) == batch_size
        return torch.sum(err * err) / 2 / len(data)

    # def unvec(x): return u.unvec(x, d)

    # Gradient
    data, targets = next(iter(trainloader))
    loss = loss_fn(model(data), targets)
    loss.backward()

    A = layer.activations.t()
    assert A.shape == (d1, n)

    # add factor of n here because backprop computes loss averaged over batch,
    # while we need per-example loss backprop
    B = layer.backprops.t() * n
    assert B.shape == (d2, n)

    u.check_close(B @ A.t() / n, W.grad)

    # Hessian
    skip_hooks = True
    loss = loss_fn(model(data), targets)
    H = u.hessian(loss, W)
    H = H.transpose(0, 1).transpose(2, 3).reshape(d1 * d2, d1 * d2)
    print(H)

    # compute B matrices
    Bs_t = []   # one matrix per output coordinate, storing backprops for current layer
    skip_hooks = False
    id_mat = torch.eye(d2)
    for out_idx in range(d2):
        u.zero_grad(model)
        output = model(data)
        _loss = loss_fn(output, targets)
        ei = id_mat[out_idx]
        bval = torch.stack([ei] * batch_size)
        output.backward(bval)
        Bs_t.append(layer.backprops)
    A_t = layer.activations

    # batch output Jacobian, each row corresponds to example,output pair
    Amat = torch.cat([A_t] * d2, dim=0)
    Bmat = torch.cat(Bs_t, dim=0)
    Jb = u.khatri_rao_t(Amat, Bmat)
    H2 = Jb.t() @ Jb / n
    u.check_close(H, H2)
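
# Added standalone sketch (not part of the original gist): for a single linear layer
# with weight W of shape (d2, d1) and loss L(W) = 1/(2n) * sum_i ||W a_i - y_i||^2,
# the Hessian with respect to the row-major flattening of W is kron(I_{d2}, A^T A / n),
# which is the same J^T J / n structure checked above. Uses the hessian() helper from
# the utils section; torch.kron needs PyTorch >= 1.8. Function name and dimensions are
# illustrative only.
def test_hessian_closed_form():
    torch.manual_seed(0)
    d1, d2, n = 4, 3, 5
    A = torch.randn(n, d1)                        # activations, one row per example
    Y = torch.randn(n, d2)                        # targets
    W = torch.randn(d2, d1, requires_grad=True)
    loss = ((A @ W.t() - Y) ** 2).sum() / 2 / n
    H_auto = hessian(loss, W).reshape(d2 * d1, d2 * d1)
    H_exact = torch.kron(torch.eye(d2), A.t() @ A / n)
    assert torch.allclose(H_auto, H_exact, atol=1e-5)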

# ---- utils ----
def khatri_rao(A: torch.Tensor, B: torch.Tensor):
    """Khatri-Rao product: the i'th column of the result is the Kronecker product of
    the i'th columns of A and B.

    Section 2.6 of Kolda, Tamara G., and Brett W. Bader. "Tensor decompositions and
    applications." SIAM Review 51.3 (2009): 455-500.
    """
    assert A.shape[1] == B.shape[1]
    # noinspection PyTypeChecker
    return torch.einsum("ik,jk->ijk", A, B).reshape(A.shape[0] * B.shape[0], A.shape[1])


def test_khatri_rao():
    A = torch.tensor([[1, 2], [3, 4]])
    B = torch.tensor([[5, 6], [7, 8]])
    C = torch.tensor([[5, 12], [7, 16],
                      [15, 24], [21, 32]])
    # check_equal is another helper from the same utility module, not reproduced here
    check_equal(khatri_rao(A, B), C)
def khatri_rao_t(A: torch.Tensor, B: torch.Tensor):
    """Transposed Khatri-Rao: inputs and outputs are transposed, i.e. the i'th row of
    the result is the Kronecker product of the i'th rows of A and B."""
    assert A.shape[0] == B.shape[0]
    # noinspection PyTypeChecker
    return torch.einsum("ki,kj->kij", A, B).reshape(A.shape[0], A.shape[1] * B.shape[1])


def jacobian(y: torch.Tensor, x: torch.Tensor, create_graph=False):
    """Jacobian of y with respect to x, shape y.shape + x.shape, computed one output
    element at a time with autograd."""
    jac = []
    flat_y = y.reshape(-1)
    grad_y = torch.zeros_like(flat_y)
    for i in range(len(flat_y)):
        grad_y[i] = 1.
        grad_x, = torch.autograd.grad(flat_y, x, grad_y, retain_graph=True, create_graph=create_graph)
        jac.append(grad_x.reshape(x.shape))
        grad_y[i] = 0.
    return torch.stack(jac).reshape(y.shape + x.shape)


def hessian(y: torch.Tensor, x: torch.Tensor):
    """Hessian of y with respect to x; for scalar y the shape is x.shape + x.shape."""
    return jacobian(jacobian(y, x, create_graph=True), x)
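
# Added sanity check (not part of the original gist): the Hessian of y = sum(x_i ** 2)
# is 2 * I, which exercises jacobian() and hessian() above on a tiny example.
def test_hessian():
    x = torch.randn(3, requires_grad=True)
    y = (x * x).sum()
    H = hessian(y, x)                  # y is a scalar, so H has shape (3, 3)
    assert torch.allclose(H, 2 * torch.eye(3))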
class TinyMNIST(datasets.MNIST):
    """Dataset for the autoencoder task: data is the image downsampled to
    data_width x data_width, target is the same image downsampled to
    targets_width x targets_width. Shapes: (dataset_size, 1, new_dim, new_dim)."""

    def __init__(self, root, data_width=4, targets_width=4, dataset_size=60000, download=True):
        super().__init__(root, download=download)
        self.data = self.data[:dataset_size, :, :]
        new_data = np.zeros((self.data.shape[0], data_width, data_width))
        new_targets = np.zeros((self.data.shape[0], targets_width, targets_width))
        for i in range(self.data.shape[0]):
            arr = self.data[i, :].numpy().astype(np.uint8)
            im = Image.fromarray(arr)
            im.thumbnail((data_width, data_width), Image.LANCZOS)   # Image.ANTIALIAS in older Pillow
            new_data[i, :, :] = np.array(im) / 255
            im = Image.fromarray(arr)
            im.thumbnail((targets_width, targets_width), Image.LANCZOS)
            new_targets[i, :, :] = np.array(im) / 255
        self.data = torch.from_numpy(new_data).float()
        self.data = self.data.unsqueeze(1)
        self.targets = torch.from_numpy(new_targets).float()
        self.targets = self.targets.unsqueeze(1)
        # Put both data and targets on the target device in advance
        self.data, self.targets = self.data.to(device), self.targets.to(device)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is the downsampled target image.
        """
        img, target = self.data[index], self.targets[index]
        return img, target
class Net(nn.Module):
    def __init__(self, d: List[int], nonlin=False):
        super().__init__()
        self.layers: List[nn.Module] = []
        self.all_layers: List[nn.Module] = []
        self.d: List[int] = d
        for i in range(len(d) - 1):
            linear = nn.Linear(d[i], d[i + 1], bias=False)
            setattr(linear, 'name', f'{i:02d}-linear')
            self.layers.append(linear)
            self.all_layers.append(linear)
            if nonlin:
                self.all_layers.append(nn.ReLU())
        self.predict = torch.nn.Sequential(*self.all_layers)

    def forward(self, x: torch.Tensor):
        x = x.reshape((-1, self.d[0]))
        return self.predict(x)


def seed_random(seed):
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
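
# Not in the original gist: convenience entry point, assuming the file is run as a script.
if __name__ == '__main__':
    test()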