Weight decay to model initialization
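Weight decay normally pulls parameters toward zero. This gist uses torch.nn.utils.parametrize to re-express every parameter as weight = init + delta, where init is a frozen copy of the initialization (kept as a buffer) and delta is the trainable tensor, starting at zero. AdamW's decoupled weight decay then shrinks delta toward zero on every step, so the effective weights are decayed toward their initialization rather than toward the origin. The script below trains two copies of the same MLP on random data, one with and one without the parametrization, and measures how far each parameter ends up from the shared initialization.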
import copy

import torch
import torch.nn as nn
import torch.nn.utils.parametrize  # explicit import so nn.utils.parametrize is always available

class DecayToInit(nn.Module):
    """Parametrization expressing a parameter as (frozen init) + (trainable delta)."""

    def __init__(self, param: torch.Tensor):
        super().__init__()
        # The initialization is stored as a buffer, so it is never updated by the optimizer.
        self.register_buffer("param", param)

    def forward(self, delta: torch.Tensor) -> torch.Tensor:
        return self.param + delta

def add_decay_to_init(model: nn.Module):
    """Re-parametrize every parameter so that weight decay pulls it toward its initialization."""
    for m in list(model.modules()):
        for name, param in list(m.named_parameters(recurse=False)):
            init_weights = param.data.clone().detach()
            # The remaining trainable tensor is the delta from the initialization, starting at zero.
            param.data.zero_()
            nn.utils.parametrize.register_parametrization(m, name, DecayToInit(init_weights))

def merge_decay_to_init(model: nn.Module):
    """Bake init + delta back into plain parameters and restore the original parameter names."""
    for m in list(model.modules()):
        if nn.utils.parametrize.is_parametrized(m):
            for name, _ in list(m.named_parameters(recurse=True)):
                # "parametrizations.weight.original" -> "weight"
                original_name = name.replace("parametrizations.", "").replace(".original", "")
                nn.utils.parametrize.remove_parametrizations(m, original_name, leave_parametrized=True)

def train(model, seed=0):
    torch.manual_seed(seed)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-1)

    # Random classification data: 4096 samples, 768 features, 10 classes.
    N = 4096
    xs = torch.randn(N, 768)
    ys = torch.randint(0, 10, (N,))
    dataset = torch.utils.data.TensorDataset(xs, ys)
    dl = torch.utils.data.DataLoader(dataset, batch_size=32)

    device = next(model.parameters()).device
    step = 0
    for epoch in range(10):
        for x, y in dl:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            logits = model(x)
            loss = nn.CrossEntropyLoss()(logits, y)
            loss.backward()
            optimizer.step()
            step += 1
        print(f"step {step:4d}: {loss.item():.3f}")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

torch.manual_seed(42)
model_init = nn.Sequential(
    nn.Linear(768, 128),
    nn.ReLU(),
    nn.Linear(128, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
).to(device)
# All weights and biases are initialized uniformly in [0, 1), i.e. away from zero.
for p in model_init.parameters():
    torch.nn.init.uniform_(p, 0, 1)

model_a = copy.deepcopy(model_init)
model_b = copy.deepcopy(model_init)
print("=== Training w/ decay to init ======") | |
add_decay_to_init(model_a) | |
train(model_a) | |
merge_decay_to_init(model_a) | |
print("=== Training w/o decay to init =====") | |
train(model_b) | |
print("====================================") | |
for init, a, b in zip(model_init.parameters(), model_a.parameters(), model_b.parameters()):
    # L2 distance of each trained parameter from the shared initialization.
    diff_a = torch.norm(init - a, p=2)
    diff_b = torch.norm(init - b, p=2)
    print(f"{diff_a.item():.3f}", f"{diff_b.item():.3f}", "a is closer" if diff_a < diff_b else "b is closer")