# https://huggingface.co/docs/transformers/perplexity
from typing import Any

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class PPL:
    """Compute perplexity of a text with a causal LM, using a sliding window."""

    def __init__(self, model_id="gpt2") -> None:
        self.model = AutoModelForCausalLM.from_pretrained(model_id)
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.device = "cpu"

    def __call__(self, text) -> Any:
        encodings = self.tokenizer(text, return_tensors="pt")
        # Maximum context window of the model (n_positions for GPT-2).
        max_length = self.model.config.n_positions
        stride = 512
        seq_len = encodings.input_ids.size(1)

        nlls = []
        prev_end_loc = 0
        for begin_loc in range(0, seq_len, stride):
            end_loc = min(begin_loc + max_length, seq_len)
            trg_len = end_loc - prev_end_loc  # may be different from stride on the last loop
            input_ids = encodings.input_ids[:, begin_loc:end_loc].to(self.device)
            target_ids = input_ids.clone()
            # Mask out tokens already scored in a previous window so they only serve as context.
            target_ids[:, :-trg_len] = -100

            with torch.no_grad():
                outputs = self.model(input_ids, labels=target_ids)

                # loss is calculated using CrossEntropyLoss which averages over valid labels
                # N.B. the model only calculates loss over trg_len - 1 labels, because it
                # internally shifts the labels to the left by 1.
                neg_log_likelihood = outputs.loss

            nlls.append(neg_log_likelihood)

            prev_end_loc = end_loc
            if end_loc == seq_len:
                break

        # Perplexity = exp(mean negative log-likelihood over all windows).
        return torch.exp(torch.stack(nlls).mean()).item()
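

# --- Not part of the original gist: a minimal sketch of a GPU variant. ---
# The `device` attribute is already used inside __call__, so only the constructor
# needs to change. `PPLOnGPU` is a hypothetical name introduced here for illustration;
# it assumes a CUDA-capable machine and falls back to CPU otherwise.
class PPLOnGPU(PPL):
    def __init__(self, model_id="gpt2") -> None:
        super().__init__(model_id)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)  # move the weights once, at construction time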


if __name__ == "__main__":
    ppl = PPL()
    print(ppl("ABC is a startup based in New York City and Paris"))