@p208p2002 · Created August 8, 2023 07:59
# https://huggingface.co/docs/transformers/perplexity
from typing import Any

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class PPL:
    def __init__(self, model_id="gpt2") -> None:
        self.model = AutoModelForCausalLM.from_pretrained(model_id)
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.device = "cpu"
        self.model.to(self.device)  # keep the model on the same device as the inputs
    def __call__(self, text) -> Any:
        encodings = self.tokenizer(text, return_tensors="pt")
        max_length = self.model.config.n_positions  # model context window (1024 for GPT-2)
        stride = 512
        seq_len = encodings.input_ids.size(1)

        nlls = []
        prev_end_loc = 0
        for begin_loc in range(0, seq_len, stride):
            end_loc = min(begin_loc + max_length, seq_len)
            trg_len = end_loc - prev_end_loc  # may differ from stride on the last loop
            input_ids = encodings.input_ids[:, begin_loc:end_loc].to(self.device)
            target_ids = input_ids.clone()
            # Mask (-100) the overlap already scored in a previous window so each
            # token contributes to the loss exactly once.
            target_ids[:, :-trg_len] = -100

            with torch.no_grad():
                outputs = self.model(input_ids, labels=target_ids)
                # The loss is CrossEntropyLoss averaged over valid (non -100) labels.
                # N.B. the model only computes loss over trg_len - 1 labels, because
                # it internally shifts the labels to the left by 1.
                neg_log_likelihood = outputs.loss

            nlls.append(neg_log_likelihood)
            prev_end_loc = end_loc
            if end_loc == seq_len:
                break

        # Perplexity = exp(mean negative log-likelihood across windows).
        return torch.exp(torch.stack(nlls).mean()).item()
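
# --- Added sketch: not part of the original gist ---
# The comment in __call__ says the model shifts labels left by one internally.
# A minimal check of that claim, assuming the default GPT-2 setup: outputs.loss
# should match a manual cross-entropy against left-shifted labels.
def check_label_shift(ppl, text):
    import torch.nn.functional as F
    ids = ppl.tokenizer(text, return_tensors="pt").input_ids.to(ppl.device)
    with torch.no_grad():
        out = ppl.model(ids, labels=ids)
    logits = out.logits[:, :-1, :]  # predictions for positions 1..n-1
    targets = ids[:, 1:]            # the tokens actually observed there
    manual = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
    return bool(torch.isclose(out.loss, manual))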
if __name__ == "__main__":
    ppl = PPL()
    print(ppl("ABC is a startup based in New York City and Paris"))
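    # --- Added sketch: not part of the original gist ---
    # Optional GPU run, assuming a CUDA device is available; __call__ already
    # moves input_ids to self.device, so only the model needs to follow.
    if torch.cuda.is_available():
        ppl.device = "cuda"
        ppl.model.to(ppl.device)
        print(ppl("ABC is a startup based in New York City and Paris"))

    # Sanity-check the internal label shift described in __call__.
    print(check_label_shift(ppl, "ABC is a startup based in New York City and Paris"))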