Never Forget
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import math

#torch.autograd.set_detect_anomaly(True)


class FeedForward(torch.nn.Module):
    def __init__(self, input_features, output_features):
        super().__init__()
        self.input_features = input_features
        self.output_features = output_features

        hidden_features = input_features * 4
        self.proj_in = nn.Linear(input_features, hidden_features)
        self.act = nn.GELU()
        self.proj_out = nn.Linear(hidden_features, output_features)

    def forward(self, x):
        x = self.proj_in(x)
        x = self.act(x)
        x = self.proj_out(x)
        return x
if __name__ == "__main__":
    torch.manual_seed(5)

    input_features, output_features = 3, 2
    batch_size = 4

    model = FeedForward(input_features, output_features)
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.01)
    criterion = torch.nn.L1Loss()

    # First task: a small random regression dataset (dataset0)
    x0 = torch.randn(batch_size, input_features)
    y0 = torch.randn(batch_size, output_features)

    model.train()
    for epoch in range(1000):
        optimizer.zero_grad()
        output = model(x0)
        loss = criterion(output, y0)
        loss.backward()
        optimizer.step()
    # Store model weights from the first training run, and compute normalized
    # diagonal Fisher information (FIM) coefficients from the squared gradients
    # left over from the final backward pass on dataset0
    original_params = {}
    fims = {}
    for name, param in model.named_parameters():
        if param.requires_grad:
            original_params[name] = param.data.clone().detach()
            fim = param.grad ** 2
            fim_norm = torch.norm(fim, p=1)
            if fim_norm > 0:  # Avoid division by zero
                normalized_fim = fim / fim_norm
            else:
                normalized_fim = fim
            fims[name] = normalized_fim

    loss = criterion(model(x0), y0)
    print(f"dataset0: epoch {epoch} loss = {loss.item()}")
    # Second task: new random data (dataset1), trained with an EWC-style penalty
    # that anchors parameters the FIM marks as important for dataset0
    x1 = torch.randn(batch_size, input_features)
    y1 = torch.randn(batch_size, output_features)

    model.train()
    for epoch in range(1000):
        optimizer.zero_grad()
        output = model(x1)
        loss = criterion(output, y1)
        preloss = loss.clone()
        for name, param in model.named_parameters():
            # Compute penalty as the sum of FIM values times squared parameter changes.
            # Use `param` (not `param.data`) so the penalty contributes to the gradients.
            fim_contribution = fims[name] * (param - original_params[name]).pow(2)
            #fim_contribution = (param - original_params[name]).pow(2)
            loss = loss + 1.0 * fim_contribution.sum()
        #print(f"loss delta={(loss-preloss)*100/preloss}%")
        loss.backward()
        if epoch == 10:
            for name, param in model.named_parameters():
                if param.requires_grad:
                    print(f"param {name}: param.grad={param.grad}")
        optimizer.step()

    loss = criterion(model(x1), y1)
    print(f"dataset1: epoch {epoch} loss = {loss.item()}")
    loss = criterion(model(x0), y0)
    print(f"dataset0: epoch {epoch} loss = {loss.item()} (forgotten)")