Skip to content

Instantly share code, notes, and snippets.

@wiseodd
Created April 2, 2024 18:11
Show Gist options
  • Save wiseodd/426061afae24199446e60bfabc00e26e to your computer and use it in GitHub Desktop.
from laplace import Laplace
from laplace.curvature import CurvlinopsGGN
import torch
from torch import nn
import torch.utils.data as data_utils
from collections import UserDict
import collections.abc as cols_abc
class Model(nn.Module):
def __init__(self):
super().__init__()
self.net = nn.Sequential(
nn.Linear(10, 50),
nn.ReLU(),
nn.Linear(50, 2)
)
def forward(self, data: UserDict | torch.Tensor):
if isinstance(data, UserDict):
device = next(self.net.parameters()).device
X = data['input_ids'].to(device)
else:
X = data
return self.net(X)
# Synthetic regression data: 100 samples with 10 features and 2 targets each.
X = torch.randn(100, 10)
y = torch.randn(100, 2)
dataset = data_utils.TensorDataset(X, y)
def collate_fn(data_list):
    """Collate ``(x, y)`` samples into a ``UserDict`` batch.

    Args:
        data_list: non-empty sequence of ``(x, y)`` tensor pairs, e.g. items
            of a ``TensorDataset``. All ``x`` must share one shape, as must
            all ``y``.

    Returns:
        ``UserDict`` with ``'input_ids'`` (the ``x`` tensors stacked along a
        new leading batch dimension) and ``'labels'`` (the ``y`` tensors,
        stacked the same way).

    Raises:
        ValueError: if ``data_list`` is empty.
    """
    if not data_list:
        raise ValueError('collate_fn() received an empty batch')
    inputs, targets = zip(*data_list)
    # torch.stack is equivalent to torch.cat over unsqueeze(0)'d tensors.
    return UserDict({
        'input_ids': torch.stack(inputs, dim=0),
        'labels': torch.stack(targets, dim=0),
    })
# Batches of 50 samples, collated into UserDicts instead of (X, y) tuples.
dataloader = data_utils.DataLoader(
    dataset,
    batch_size=50,
    collate_fn=collate_fn,
)

model = Model()
la = Laplace(
    model,
    likelihood='regression',
    subset_of_weights='all',
    hessian_structure='kron',
    backend=CurvlinopsGGN,
)
# la.fit consumes the UserDict batches built above. Reaching the print below
# means fitting did not raise.
la.fit(dataloader)
print('Success with UserDict!')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment