How does GPT-4o's internal state tracking stack up?
import asyncio
import os
import random
import hashlib
from datetime import datetime
from typing import Dict, List, Type

from dotenv import load_dotenv
from loguru import logger
from pydantic import BaseModel
from diskcache import Cache
from openai import AsyncOpenAI
from anthropic import AsyncAnthropic
from together import AsyncTogether
from groq import AsyncGroq
import instructor

# Load environment variables (API keys) from .env
load_dotenv()

# Disk-backed cache so identical requests are served locally instead of re-billed
cache = Cache(directory=".cache")


def generate_cache_key(messages: List[Dict], model: str) -> str:
    """Hash the concatenated message contents plus the model name."""
    key = "".join(msg["content"] for msg in messages) + model
    return hashlib.sha256(key.encode()).hexdigest()


def generate_cache_key_with_response_model(messages: List[Dict], model: str, response_model: Type[BaseModel]) -> str:
    """As above, but fold the response model's schema into the key as well."""
    key = "".join(msg["content"] for msg in messages) + model + str(response_model.schema())
    return hashlib.sha256(key.encode()).hexdigest()
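

# Sanity sketch (illustrative, not part of the original gist): identical
# messages and model always hash to the same key, so repeat runs are
# served from the disk cache instead of hitting the API.
_demo_msgs = [{"role": "user", "content": "ping"}]
assert generate_cache_key(_demo_msgs, "gpt-4o") == generate_cache_key(_demo_msgs, "gpt-4o")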

# Client initialization; instructor.patch adds response_model support
openai_client = instructor.patch(AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY")))
anthropic_client = AsyncAnthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))
tgi_client = AsyncTogether(api_key=os.getenv("TOGETHER_AI_API_KEY"))
groq_client = instructor.patch(AsyncGroq(api_key=os.getenv("GROQ_API_KEY")), mode=instructor.Mode.MD_JSON)


async def chat_completion(client, messages: List[Dict], model: str, temperature: float, max_tokens: int, response_model: Type[BaseModel] = None):
    key = generate_cache_key_with_response_model(messages, model, response_model) if response_model else generate_cache_key(messages, model)
    if key in cache:
        return response_model.parse_raw(cache[key]) if response_model else cache[key]
    if response_model and 'claude' not in model:
        # instructor-patched clients return an already-parsed Pydantic object
        response = await client.chat.completions.create(
            model=model,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            response_model=response_model,
        )
        result = response.json()
        output = response_model.parse_raw(result)
    elif 'claude' in model:
        # Anthropic's Messages API takes the system prompt as a separate
        # argument; note that response_model is ignored for Claude models
        response = await client.messages.create(
            model=model,
            system=messages[0]["content"],
            messages=messages[1:],
            temperature=temperature,
            max_tokens=max_tokens,
        )
        result = output = response.content[0].text
    else:
        response = await client.chat.completions.create(
            model=model,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        result = output = response.choices[0].message.content
    cache[key] = result
    return output


def sync_chat_completion(client, messages: List[Dict], model: str, temperature: float = 0.0, max_tokens: int = 150, response_model: Type[BaseModel] = None):
    return asyncio.run(chat_completion(client, messages, model, temperature, max_tokens, response_model))


def build_messages(sys_msg: str, user_msg: str) -> List[Dict]:
    return [{"role": "system", "content": sys_msg}, {"role": "user", "content": user_msg}]

class LLM:
    def __init__(self, model_name: str, temperature: float = 0.0, max_tokens: int = 150, response_model: Type[BaseModel] = None):
        self.model_name = model_name
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.response_model = response_model
        self.client = self.determine_client()

    def determine_client(self):
        # Route to a provider client based on the model name
        if "gpt" in self.model_name:
            return openai_client
        elif "claude" in self.model_name:
            return anthropic_client
        elif "llama" in self.model_name:
            return groq_client
        else:
            return tgi_client

    async def respond(self, system_prompt: str, user_prompt: str):
        messages = build_messages(system_prompt, user_prompt)
        return await chat_completion(self.client, messages, self.model_name, self.temperature, self.max_tokens, self.response_model)
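

# Usage sketch (illustrative; "Needle" is a hypothetical response model,
# not something the eval below relies on). Uncomment to try structured
# output through the instructor-patched OpenAI client:
# class Needle(BaseModel):
#     counterparty_name: str
#     amount: int
# needle_llm = LLM("gpt-4o", response_model=Needle)
# parsed = asyncio.run(needle_llm.respond("Extract the entry.", "Google 1000 2023-01-01"))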

def create_synthetic_data(k=100, trial=0):
    """Build k offsetting entry pairs plus one unmatched needle entry."""
    random.seed(420 + trial)
    counterparty_names = ["Google", "Apple", "Microsoft", ...]  # truncated for brevity
    synthetic_data = []
    for _ in range(k):
        cnp = random.choice(counterparty_names)
        date = datetime(2023, random.randint(1, 12), random.randint(1, 28))
        amount = random.randint(1000, 10000)
        synthetic_data.append({"counterparty_name": cnp, "amount": amount, "date": date})
        synthetic_data.append({"counterparty_name": cnp, "amount": -amount, "date": date})
    needle = {"counterparty_name": random.choice(counterparty_names), "amount": random.randint(1000, 10000), "date": datetime(2023, random.randint(1, 12), random.randint(1, 28))}
    synthetic_data.append(needle)
    random.shuffle(synthetic_data)
    return "\n".join(f"{data['counterparty_name']} {data['amount']} {data['date']}" for data in synthetic_data), needle

async def check_correctness(stringified_haystack, needle, llm: LLM):
    completion = await llm.respond("""
# Premise
You will be provided with records of accounting entries. Some represent real-world transactions, and others represent offsetting entries.
Each real-world transaction ought to have an offsetting entry to balance the books.
Matching entries share the following characteristics:
- Same counterparty name
- Same date
- Amount with the same absolute value but opposite sign
## Examples of Matching Entries
### Matching Pair 1
Google 1000 2023-01-01
Google -1000 2023-01-01
### Matching Pair 2
Apple 2000 2023-02-01
Apple -2000 2023-02-01
### Matching Pair 3
Microsoft 3000 2023-03-01
Microsoft -3000 2023-03-01
# Objective
Identify the entry that does not have an offsetting entry. Respond only with its information, in the same format as it is presented.
""", "The entries you have to pick from:" + stringified_haystack)
    correctness = (
        (str(needle["counterparty_name"]) in completion)
        and (str(needle["amount"]) in completion)
        and (str(needle["date"]) in completion)
    )
    return correctness, completion
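

# Note: the check above is a substring match, so a verbose completion that
# echoes several entries (including the needle) would still count as correct.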

async def full_eval_for_model(model="gpt-4o", dataset_sizes=[10, 25, 50, 100]):
    llm = LLM(model_name=model)
    n_trials = 3
    results = {}
    last_viable_k = None
    for k in dataset_sizes:
        results[k] = {}
        results[k]["prcntg_trials_passed"] = 0
        for trial in range(n_trials):
            stringified, needle = create_synthetic_data(k=k, trial=trial)
            correctness_for_k, full_completion = await check_correctness(stringified, needle, llm)
            results[k]["prcntg_trials_passed"] += correctness_for_k
        results[k]["prcntg_trials_passed"] /= n_trials
        # Stop at the first k where every trial failed
        if results[k]["prcntg_trials_passed"] == 0:
            break
        else:
            last_viable_k = k
            print("Passed for k = ", k)
    return results, last_viable_k
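

# Sketch (an assumption, not part of the original run): sweep several models
# back to back. The model identifiers are illustrative and must match what
# each provider actually serves.
# async def eval_many(models=("gpt-4o", "claude-3-opus-20240229")):
#     return {m: await full_eval_for_model(model=m) for m in models}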

if __name__ == "__main__":
    dataset_sizes = [10, 15, 20, 25, 30, 50, 75, 85, 95, 100, 125, 150, 200, 500, 1000, 2000]
    model = "gpt-4o"
    results, last_viable_k = asyncio.run(full_eval_for_model(model=model, dataset_sizes=dataset_sizes))
    print(last_viable_k)
    print(results)

# Largest k (number of offsetting pairs) at which each model still passes,
# reported as "last k passed / first k failed"; a pass means at least 1 of
# 3 trials was correct.
# OpenAI models
# gpt-4-32k: 95/100 {10: {'prcntg_trials_passed': 1.0}, 15: {'prcntg_trials_passed': 1.0}, 20: {'prcntg_trials_passed': 1.0}, 25: {'prcntg_trials_passed': 1.0}, 30: {'prcntg_trials_passed': 1.0}, 50: {'prcntg_trials_passed': 0.6666666666666666}, 75: {'prcntg_trials_passed': 0.3333333333333333}, 85: {'prcntg_trials_passed': 0.6666666666666666}, 100: {'prcntg_trials_passed': 0.0}}
# gpt-4-turbo: 75/85 {10: {'prcntg_trials_passed': 1.0}, 15: {'prcntg_trials_passed': 1.0}, 20: {'prcntg_trials_passed': 1.0}, 25: {'prcntg_trials_passed': 0.6666666666666666}, 30: {'prcntg_trials_passed': 0.6666666666666666}, 50: {'prcntg_trials_passed': 0.6666666666666666}, 75: {'prcntg_trials_passed': 0.3333333333333333}, 85: {'prcntg_trials_passed': 0.0}}
# gpt-4o: 20/25 {10: {'prcntg_trials_passed': 0.6666666666666666}, 15: {'prcntg_trials_passed': 0.6666666666666666}, 20: {'prcntg_trials_passed': 1.0}, 25: {'prcntg_trials_passed': 0.0}}
# gpt-3.5-turbo: 20/25 {10: {'prcntg_trials_passed': 0.6666666666666666}, 15: {'prcntg_trials_passed': 0.6666666666666666}, 20: {'prcntg_trials_passed': 0.3333333333333333}, 25: {'prcntg_trials_passed': 0.0}}
# Meta models
# llama-3-70b: 20/25 {10: {'prcntg_trials_passed': 0.6666666666666666}, 15: {'prcntg_trials_passed': 0.6666666666666666}, 20: {'prcntg_trials_passed': 0.6666666666666666}, 25: {'prcntg_trials_passed': 0.0}}
# llama-3-8b: 10/15
# Anthropic models
# claude-3-opus: 95/100 {10: {'prcntg_trials_passed': 1.0}, 15: {'prcntg_trials_passed': 1.0}, 20: {'prcntg_trials_passed': 1.0}, 25: {'prcntg_trials_passed': 0.6666666666666666}, 30: {'prcntg_trials_passed': 1.0}, 50: {'prcntg_trials_passed': 0.6666666666666666}, 75: {'prcntg_trials_passed': 0.3333333333333333}, 85: {'prcntg_trials_passed': 0.6666666666666666}, 95: {'prcntg_trials_passed': 0.6666666666666666}, 100: {'prcntg_trials_passed': 0.0}}
# claude-3-sonnet: 100/125 {10: {'prcntg_trials_passed': 0.3333333333333333}, 15: {'prcntg_trials_passed': 0.6666666666666666}, 20: {'prcntg_trials_passed': 0.6666666666666666}, 25: {'prcntg_trials_passed': 0.6666666666666666}, 30: {'prcntg_trials_passed': 0.3333333333333333}, 50: {'prcntg_trials_passed': 0.6666666666666666}, 75: {'prcntg_trials_passed': 0.6666666666666666}, 85: {'prcntg_trials_passed': 0.6666666666666666}, 95: {'prcntg_trials_passed': 0.6666666666666666}, 100: {'prcntg_trials_passed': 0.3333333333333333}, 125: {'prcntg_trials_passed': 0.0}}
# claude-3-haiku: 20/25 {10: {'prcntg_trials_passed': 0.3333333333333333}, 15: {'prcntg_trials_passed': 0.3333333333333333}, 20: {'prcntg_trials_passed': 0.6666666666666666}, 25: {'prcntg_trials_passed': 0.0}}