Skip to content

Instantly share code, notes, and snippets.

@VTSTech
Last active May 3, 2024 01:17
Show Gist options
  • Save VTSTech/924a94abb5e315e8c8aba200d5b9131b to your computer and use it in GitHub Desktop.
Save VTSTech/924a94abb5e315e8c8aba200d5b9131b to your computer and use it in GitHub Desktop.
VTSTech-PERP - Python script that computes perplexity on GPT Models
# Program: VTSTech-PERP.py 2023-04-17 6:14:21PM
# Description: Python script that computes perplexity on GPT Models
# Author: Written by Veritas//VTSTech ([email protected])
# GitHub: https://github.com/VTSTech
# Homepage: www.VTS-Tech.org
# Provide a 'train.txt' file for the model to predict on. Any large English text will do.
# pip install torch argparse transformers colorama
import torch
import argparse
import time
import warnings
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
from typing import List
from colorama import Fore, Back, Style, init
init(autoreset=True)
# Version string shown in the startup banner.
build = "v0.1-r01"
# Upper bound on tokens per example; longer inputs are truncated.
DEFAULT_MAX_LENGTH = 512
# Effective truncation length read by the dataset class at tokenization time.
max_length = DEFAULT_MAX_LENGTH
def banner():
    """Print the program banner followed by the name of the selected model.

    Reads the module-level ``args``, ``build`` and ``model_name``. Nothing is
    printed when the ``--clean`` flag is set.
    """
    if args.clean:
        return
    print(Style.BRIGHT + f"VTSTech-PERP {build} - www: VTS-Tech.org git: VTSTech discord.gg/P4RDD76U")
    print("Using Model : " + Fore.RED + f"{model_name}")
# Command-line interface.
parser = argparse.ArgumentParser(description='Evaluate Perplexity with GPT models')
parser.add_argument(
    '-m', '--model',
    help='Choose the model to use (default: VTSTech/Desktop-GPT-111m)',
    type=str,
    default="VTSTech/Desktop-GPT-111m",
)
parser.add_argument('-t', '--time', action='store_true', help='Print execution time')
parser.add_argument('-cl', '--clean', action='store_true', help='Clean output')
parser.add_argument('-nw', '--nowarn', action='store_true', help='Suppress warnings')
args = parser.parse_args()

# An empty --model value falls back to the built-in default model.
model_name = args.model if args.model else "VTSTech/Desktop-GPT-111m"

# Both --clean and --nowarn silence Python warnings.
if args.clean or args.nowarn:
    warnings.simplefilter("ignore")

banner()

# Load the tokenizer and model.
if not args.clean:
    print("Loading tokenizer and model...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Hidden states are requested because encode_text() reads outputs.hidden_states.
model = AutoModelForCausalLM.from_pretrained(model_name, output_hidden_states=True)
model.eval()

# Running totals used to accumulate the token-weighted loss over all batches.
total_loss = torch.tensor(0.0)
total_count = torch.tensor(0.0)
class WikitextDataset(Dataset):
    """Dataset of tokenized text examples for perplexity evaluation.

    Every example is tokenized once in the constructor, truncated to the
    module-level ``max_length`` and padded so all rows share one length.
    Items are dicts with ``input_ids`` and ``attention_mask`` tensors.
    """

    def __init__(self, data: List[str], tokenizer: AutoTokenizer):
        """Tokenize ``data`` and keep the resulting tensors.

        Args:
            data: Text examples, one string per example.
            tokenizer: Tokenizer used to encode the examples. If it has no
                padding token, its ``pad_token`` is set to its EOS token
                (this mutates the tokenizer that was passed in).
        """
        # Many GPT-style tokenizers define no padding token, and padding=True
        # raises in that case. Reusing EOS is safe here because padded
        # positions are still marked with 0 in the returned attention_mask.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        self.examples = tokenizer(
            data,
            padding=True,
            truncation=True,
            return_tensors='pt',
            max_length=max_length,
        )

    def __len__(self):
        """Return the number of tokenized examples."""
        return len(self.examples['input_ids'])

    def __getitem__(self, idx):
        """Return the ``input_ids`` and ``attention_mask`` rows at ``idx``."""
        return {
            'input_ids': self.examples['input_ids'][idx],
            'attention_mask': self.examples['attention_mask'][idx],
        }
def encode_text(text, max_length=DEFAULT_MAX_LENGTH):
    """Return the model's final hidden state for ``text``, averaged over dim 1.

    Uses the module-level ``tokenizer`` and ``model``.

    Args:
        text: Input string to encode.
        max_length: Maximum number of tokens; longer input is truncated.

    Returns:
        Tensor holding the mean over dimension 1 of the last entry of
        ``hidden_states`` (dim 1 is presumably the token axis — confirm
        against the model's output layout).
    """
    encoded = tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        truncation=True,
        max_length=max_length,
        return_tensors='pt',
    )
    # Inference only: skip gradient bookkeeping for the forward pass.
    with torch.no_grad():
        final_hidden = model(**encoded).hidden_states[-1]
    return final_hidden.mean(dim=1)
start_time = time.time()

# Load the data: every line of train.txt becomes one example.
if not args.clean:
    print(Style.BRIGHT + "Loading and encoding data...")
# Explicit encoding so the text does not depend on the platform default;
# undecodable bytes are replaced instead of aborting the whole run.
with open('train.txt', 'r', encoding='utf-8', errors='replace') as f:
    data = f.readlines()

# Create the dataset and dataloader
if not args.clean:
    print(Style.BRIGHT + "Creating dataset and dataloader...")
dataset = WikitextDataset(data, tokenizer)
dataloader = DataLoader(dataset, batch_size=2)

# Compute the perplexity on the dataset
if not args.clean:
    print(Style.BRIGHT + "Computing perplexity...")
with torch.no_grad():
    for batch in dataloader:
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        # Padding must not be scored: -100 is the label value that the
        # Hugging Face causal-LM loss ignores.
        labels = input_ids.masked_fill(attention_mask == 0, -100)
        # The loss compares position t's prediction with token t+1, so the
        # first column is never a target. Count only the scored targets so the
        # weight matches what the model's mean loss was averaged over.
        count = (labels[:, 1:] != -100).sum().item()
        if count == 0:
            # Nothing to score in this batch (e.g. only single-token lines);
            # the mean loss would be NaN and poison the running total.
            continue
        try:
            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        except Exception as exc:
            # Best effort: report the failing batch and keep going, without
            # swallowing KeyboardInterrupt/SystemExit as a bare except would.
            print(f"model() error: {exc}")
            continue
        total_loss += outputs.loss.item() * count
        total_count += count

# Perplexity is exp of the token-weighted mean loss (NaN if nothing was scored).
perplexity = torch.exp(torch.true_divide(total_loss, total_count))
end_time = time.time()
if not args.clean:
    print(Style.BRIGHT + 'Perplexity:', perplexity)
if args.time and not args.clean:
    print(Style.BRIGHT + Fore.RED + f"Script finished. Execution time: {end_time - start_time:.2f} seconds")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment