VTSTech-PERP - Python script that computes perplexity on GPT Models
# Program: VTSTech-PERP.py 2023-04-17 6:14:21PM
# Description: Python script that computes perplexity on GPT Models
# Author: Written by Veritas//VTSTech ([email protected])
# GitHub: https://github.com/VTSTech
# Homepage: www.VTS-Tech.org
# Use a 'train.txt' for it to evaluate against. Any large English text will do.
# pip install torch transformers colorama  (argparse ships with the standard library)
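#
# Example invocations (illustrative; the flags are defined by the argparse setup below):
#   python VTSTech-PERP.py                 # score the default model VTSTech/Desktop-GPT-111m
#   python VTSTech-PERP.py -m gpt2 -t      # score gpt2 and print execution time
#   python VTSTech-PERP.py -m gpt2 -nw     # score gpt2 with warnings suppressed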
import torch
import argparse
import time
import warnings
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
from typing import List
from colorama import Fore, Back, Style, init

init(autoreset=True)
build = "v0.1-r01"
DEFAULT_MAX_LENGTH = 512
max_length = DEFAULT_MAX_LENGTH

def banner():
    global model_name
    if not args.clean:
        print(Style.BRIGHT + f"VTSTech-PERP {build} - www: VTS-Tech.org git: VTSTech discord.gg/P4RDD76U")
        print("Using Model : " + Fore.RED + f"{model_name}")

# Parse command-line arguments
parser = argparse.ArgumentParser(description='Evaluate Perplexity with GPT models')
parser.add_argument('-m', '--model', help='Choose the model to use (default: VTSTech/Desktop-GPT-111m)', type=str, default="VTSTech/Desktop-GPT-111m")
parser.add_argument('-t', '--time', action='store_true', help='Print execution time')
parser.add_argument('-cl', '--clean', action='store_true', help='Clean output')
parser.add_argument('-nw', '--nowarn', action='store_true', help='Suppress warnings')
args = parser.parse_args()

model_name = "VTSTech/Desktop-GPT-111m"
if args.clean or args.nowarn:
    warnings.simplefilter("ignore")
if args.model:
    model_name = args.model

banner()

# Load the tokenizer and model
if not args.clean:
    print("Loading tokenizer and model...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, output_hidden_states=True)
model.eval()

total_loss = torch.tensor(0.0)
total_count = torch.tensor(0.0)

class WikitextDataset(Dataset):
    def __init__(self, data: List[str], tokenizer: AutoTokenizer):
        # tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        self.examples = tokenizer(data, padding=True, truncation=True, return_tensors='pt', max_length=max_length)

    def __len__(self):
        return len(self.examples['input_ids'])

    def __getitem__(self, idx):
        return {'input_ids': self.examples['input_ids'][idx], 'attention_mask': self.examples['attention_mask'][idx]}
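
# Note: padding=True in WikitextDataset requires the tokenizer to define a pad token.
# GPT-2-style tokenizers usually do not ship one, which is what the commented-out
# add_special_tokens line above works around. An alternative (an assumption, not part
# of the original script) is to reuse the EOS token before building the dataset:
#   if tokenizer.pad_token is None:
#       tokenizer.pad_token = tokenizer.eos_token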

def encode_text(text, max_length=DEFAULT_MAX_LENGTH):
    # Tokenize the text and truncate the input sequence to max_length
    inputs = tokenizer.encode_plus(text, add_special_tokens=True, truncation=True, max_length=max_length, return_tensors='pt')
    # Forward pass through the model
    with torch.no_grad():
        outputs = model(**inputs)
    # Extract the output embeddings from the last hidden state
    embeddings = outputs.hidden_states[-1].mean(dim=1)
    return embeddings

start_time = time.time()

# Load the data and encode it
if not args.clean:
    print(Style.BRIGHT + "Loading and encoding data...")
with open('train.txt', 'r') as f:
    data = f.readlines()
encoded_text = encode_text(''.join(data))
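# Note: encoded_text holds mean-pooled last-hidden-state embeddings of the full text;
# the perplexity computation below does not use it.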

# Create the dataset and dataloader
if not args.clean:
    print(Style.BRIGHT + "Creating dataset and dataloader...")
dataset = WikitextDataset(data, tokenizer)
dataloader = DataLoader(dataset, batch_size=2)

# Compute the perplexity on the dataset
if not args.clean:
    print(Style.BRIGHT + "Computing perplexity...")
with torch.no_grad():
    for batch in dataloader:
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        outputs = None
        try:
            # Passing labels=input_ids makes the model return its language-modeling loss
            outputs = model(input_ids, attention_mask=attention_mask, labels=input_ids)
        except Exception:
            # print(f"input_ids: {input_ids}")
            print("model() error")
        if outputs is not None:
            loss = outputs.loss
            count = torch.sum(attention_mask)
            total_loss += loss.item() * count.item()
            total_count += count.item()

perplexity = torch.exp(torch.true_divide(total_loss, total_count))
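# Perplexity here is exp(weighted-average cross-entropy): each batch's outputs.loss is the
# mean token-level negative log-likelihood, weighted by that batch's attended-token count
# (an approximation of the exact number of tokens scored in the loss) and divided by the
# running total count before exponentiating.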
end_time = time.time()

if not args.clean:
    print(Style.BRIGHT + 'Perplexity:', perplexity)
if args.time and not args.clean:
    print(Style.BRIGHT + Fore.RED + f"Script finished. Execution time: {end_time - start_time:.2f} seconds")