Created
April 2, 2024 18:11
-
-
Save wiseodd/426061afae24199446e60bfabc00e26e to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from laplace import Laplace | |
from laplace.curvature import CurvlinopsGGN | |
import torch | |
from torch import nn | |
import torch.utils.data as data_utils | |
from collections import UserDict | |
import collections.abc as cols_abc | |
class Model(nn.Module):
    """Two-layer MLP mapping 10 input features to 2 outputs (10 -> 50 -> ReLU -> 2).

    ``forward`` accepts either a raw input tensor or a mapping-style batch (such
    as the ``UserDict`` produced by a collate function) that stores the inputs
    under the ``'input_ids'`` key.
    """

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(10, 50),
            nn.ReLU(),
            nn.Linear(50, 2),
        )

    # The annotation is a string so it is never evaluated at definition time;
    # the ``X | Y`` union operator on classes only exists from Python 3.10 on.
    def forward(self, data: 'cols_abc.Mapping | torch.Tensor') -> torch.Tensor:
        """Run the network on ``data``.

        Args:
            data: Either an input tensor whose last dimension is 10, or any
                mapping (``dict``, ``UserDict``, ...) holding such a tensor under
                the ``'input_ids'`` key. Other keys (e.g. ``'labels'``) are ignored.

        Returns:
            The network output, whose last dimension is 2.
        """
        # Move inputs to wherever the parameters live, for both input forms, so
        # CPU-side batches also work with a model placed on another device.
        # ``Tensor.to`` is a no-op when the tensor is already on that device.
        device = next(self.net.parameters()).device
        X = data['input_ids'] if isinstance(data, cols_abc.Mapping) else data
        return self.net(X.to(device))
# Synthetic regression data: 100 samples with 10 input features and 2 targets.
X = torch.randn(100, 10)
y = torch.randn(100, 2)
dataset = data_utils.TensorDataset(X, y)
def collate_fn(data_list):
    """Collate ``(x, y)`` sample pairs into a mapping-style batch.

    Args:
        data_list: Iterable of ``(x, y)`` tensor pairs, e.g. the samples a
            ``DataLoader`` draws from a ``TensorDataset``. All ``x`` must share
            one shape and all ``y`` another.

    Returns:
        A ``UserDict`` with ``'input_ids'`` holding the ``x`` tensors and
        ``'labels'`` holding the ``y`` tensors, each stacked along a new leading
        batch dimension.

    Raises:
        RuntimeError: If ``data_list`` is empty (raised by ``torch.stack``).
    """
    # Materialise once so any iterable (not only a list) can be read twice.
    pairs = list(data_list)
    # torch.stack(..., dim=0) is equivalent to calling unsqueeze(0) on every
    # element and then torch.cat(..., dim=0).
    input_ids = torch.stack([x for x, _ in pairs], dim=0)
    labels = torch.stack([y for _, y in pairs], dim=0)
    return UserDict({'input_ids': input_ids, 'labels': labels})
# Batches are UserDicts built by collate_fn rather than (X, y) tuples; this
# script checks that Laplace can be fitted on such mapping-style batches.
dataloader = data_utils.DataLoader(
    dataset,
    batch_size=50,
    collate_fn=collate_fn,
)
model = Model()

# Laplace approximation over all weights with a regression likelihood and a
# Kronecker-factored ('kron') Hessian structure, using the CurvlinopsGGN backend.
# NOTE(review): Laplace presumably reads the targets from the batch's 'labels'
# key when given a mapping batch — confirm against the laplace-torch docs.
la = Laplace(
    model,
    likelihood='regression',
    subset_of_weights='all',
    hessian_structure='kron',
    backend=CurvlinopsGGN,
)
la.fit(dataloader)
print('Success with UserDict!')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment