Created July 18, 2018 05:02
import torch


class DynamicNet(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        super(DynamicNet, self).__init__()
        # Shared backbone feeding two alternative output heads.
        self.backbone = torch.nn.Linear(D_in, H)
        self.head1 = torch.nn.Linear(H, D_out)
        self.head2 = torch.nn.Linear(H, D_out)

    def forward(self, x, use_head1=True):
        h = self.backbone(x).clamp(min=0)  # ReLU via clamp
        if use_head1:
            return self.head1(h)
        else:
            return self.head2(h)


N, D_in, H, D_out = 8, 10, 30, 2
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

model = DynamicNet(D_in, H, D_out)
# `size_average=False` is deprecated; `reduction='sum'` is the equivalent.
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)

MAKE_ZERO_GRADS_NONE = True

for t in range(5):
    # head1 is used only on the first iteration; head2 on all later ones.
    y_pred = model(x, use_head1=(t == 0))
    loss = criterion(y_pred, y)
    print('Iter {}:'.format(t), loss.item())

    optimizer.zero_grad()
    loss.backward()

    print('Grad norms:')
    for name, module in zip(['backbone', 'head1', 'head2'],
                            [model.backbone, model.head1, model.head2]):
        if MAKE_ZERO_GRADS_NONE and name == 'head1' and t > 0:
            # Drop head1's (zeroed) grads entirely so the optimizer skips it.
            module.weight.grad = None
            module.bias.grad = None
        print(name, module.weight.grad.norm() if module.weight.grad is not None else 'None')

    # Snapshot the head weights to measure the size of each update.
    temp_weights_h1 = model.head1.weight.detach().clone()
    temp_weights_h2 = model.head2.weight.detach().clone()
    optimizer.step()
    print('Norm of the head1 update:', (model.head1.weight - temp_weights_h1).norm())
    print('Norm of the head2 update:', (model.head2.weight - temp_weights_h2).norm())
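Why the flag matters: optimizer.zero_grad() leaves head1's gradients as zero tensors, and SGD with momentum still applies a nonzero update to head1 from its momentum buffer even when the fresh gradient is zero. Setting .grad = None instead makes the optimizer skip head1 entirely, so its weights stay fixed. A simplified sketch of the relevant skip inside torch.optim.SGD's step (not the actual implementation):

# Simplified sketch: parameters whose .grad is None are skipped,
# so neither the gradient step nor the momentum-buffer update touches them.
# for p in group['params']:
#     if p.grad is None:
#         continue  # no update at all for this parameter
#     d_p = p.grad
#     # ... momentum buffer update and parameter update happen here ...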
Output with MAKE_ZERO_GRADS_NONE = True:

Output with MAKE_ZERO_GRADS_NONE = False: