from collections.abc import MutableMapping
from collections import UserDict

import numpy
import torch
from torch import nn
import torch.utils.data as data_utils

from laplace import Laplace
from laplace.curvature import CurvlinopsGGN, AsdlGGN

import logging
import warnings

logging.basicConfig(level='ERROR')
warnings.filterwarnings('ignore')

from transformers import (  # noqa: E402
    GPT2Config,
    GPT2ForSequenceClassification,
    GPT2Tokenizer,
    DataCollatorWithPadding,
    PreTrainedTokenizer,
)
from peft import LoraConfig, get_peft_model  # noqa: E402
from datasets import Dataset  # noqa: E402

# make deterministic
torch.manual_seed(0)
numpy.random.seed(0)
DEVICE = "cuda"
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token_id = tokenizer.eos_token_id

data = [
    {'text': 'Today is hot, but I will manage!!!!' * 75, 'label': 231.213},
    {'text': 'Tomorrow is cold' * 150, 'label': 3243.43},
    {'text': 'Carpe diem' * 150, 'label': 4343.43},
    {'text': 'Tempus fugit' * 150, 'label': 2133.3},
]
dataset = Dataset.from_list(data)


def tokenize(row):
    return tokenizer(row['text'])


dataset = dataset.map(tokenize, remove_columns=['text'])
dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])

dataloader = data_utils.DataLoader(
    dataset, batch_size=100, collate_fn=DataCollatorWithPadding(tokenizer)
)

data = next(iter(dataloader))
print(
    f'Huggingface data defaults to UserDict, which is a MutableMapping? {isinstance(data, UserDict)}'
)
for k, v in data.items():
    print(k, v.shape)
class MyGPT2(nn.Module):
    """
    Huggingface LLM wrapper.

    Args:
        tokenizer: The tokenizer used for preprocessing the text data. Needed
            since the model needs to know the padding token id.
    """

    def __init__(self, tokenizer: PreTrainedTokenizer) -> None:
        super().__init__()
        config = GPT2Config.from_pretrained('gpt2')
        config.pad_token_id = tokenizer.pad_token_id
        config.num_labels = 1
        self.hf_model = GPT2ForSequenceClassification.from_pretrained(
            'gpt2', config=config
        )
        self.hf_model.to(DEVICE)

    def forward(self, data: MutableMapping) -> torch.Tensor:
        """
        Custom forward function. Handles moving the input tensors to the
        correct device internally.

        Args:
            data: A dict-like data structure with `input_ids` inside.
                This is the default data structure assumed by Huggingface
                dataloaders.

        Returns:
            logits: A `(batch_size, n_classes)`-sized tensor of logits.
        """
        device = next(self.parameters()).device
        input_ids = data['input_ids'].to(device)
        attn_mask = data['attention_mask'].to(device)
        output_dict = self.hf_model(input_ids=input_ids, attention_mask=attn_mask)
        return output_dict.logits
model = MyGPT2(tokenizer).to(DEVICE)
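# Quick sanity check (an addition, not part of the original gist): the wrapper
# accepts the dict-like Huggingface batch directly and returns a
# `(batch_size, 1)` logit tensor, since `num_labels = 1` above.
with torch.no_grad():
    print(f'Sanity-check logits shape: {model(data).shape}')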
# Laplace on the LoRA-attached LLM
# --------------------------------
def get_lora_model():
    model = MyGPT2(tokenizer)  # Note we don't disable grad
    config = LoraConfig(
        r=4,
        lora_alpha=16,
        target_modules=['c_attn'],  # LoRA on the attention weights
        lora_dropout=0.1,
        bias='none',
    )
    lora_model = get_peft_model(model, config)
    return lora_model


lora_model = get_lora_model()
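# Optional check (an addition, not in the original gist): peft models expose
# `print_trainable_parameters()`, which should confirm that only the LoRA
# matrices attached to `c_attn` are trainable.
lora_model.print_trainable_parameters()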
# Train it as usual
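# The training step is elided in the original gist ("as usual"). The loop
# below is only a minimal illustrative sketch: a plain MSE fine-tune of the
# LoRA parameters; the optimizer, learning rate, and epoch count are
# assumptions, not part of the original.
optimizer = torch.optim.AdamW(
    (p for p in lora_model.parameters() if p.requires_grad), lr=1e-4
)
lora_model.train()
for _ in range(10):
    for batch in dataloader:
        optimizer.zero_grad()
        pred = lora_model(batch).squeeze(-1)  # (batch_size,)
        # DataCollatorWithPadding renames the 'label' column to 'labels'
        target = batch['labels'].to(pred.device).float()
        loss = nn.functional.mse_loss(pred, target)
        loss.backward()
        optimizer.step()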
lora_model.eval()

lora_la = Laplace(
    lora_model,
    likelihood='regression',
    subset_of_weights='all',
    hessian_structure='diag',
    backend=AsdlGGN,
)
# for data in dataloader:
#     print(data)
#     input()
lora_la.fit(dataloader)

X_test = next(iter(dataloader))
f_mean, f_var = lora_la(X_test, pred_type='glm')
print(f'[LoRA-LLM] The predictive tensors are of shapes: {f_mean.shape}, {f_var.shape}.')
# Should be OOM
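# Downstream use (a sketch, not in the original gist): for regression, a rough
# per-input epistemic uncertainty is the square root of the diagonal of the
# GLM predictive variance returned above.
f_std = f_var.squeeze(-1).sqrt()
print(f'[LoRA-LLM] Predictive std per input: {f_std.flatten()}')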