import os
import json
import argparse
from datetime import datetime
from dataclasses import dataclass
from typing import Dict, List, Any, Optional, Union

from openai import OpenAI

@dataclass
class TestCase:
    name: str
    input: str
    expected: Dict[str, Any]


@dataclass
class CaseResult:
    name: str
    input: str
    expected_output: Dict[str, Any]
    actual_output: Dict[str, Any]
    score: float


@dataclass
class EvaluationResults:
    timestamp: str
    prompt: str
    prompt_path: str
    total_score: float
    cases: List[CaseResult]

def load_prompt(prompt_path: str) -> str:
    """Load the prompt from a file."""
    with open(prompt_path, 'r', encoding='utf-8') as f:
        return f.read().strip()

def load_test_cases(test_folder: str) -> List[TestCase]:
    """Load test cases from a folder. Each case consists of an input .txt file and a matching expected .json file."""
    test_cases: List[TestCase] = []

    # Collect all .txt files that have a matching .json file
    for filename in os.listdir(test_folder):
        if filename.endswith('.txt'):
            base_name = filename[:-4]  # Remove the .txt extension
            json_file = f"{base_name}.json"
            if os.path.exists(os.path.join(test_folder, json_file)):
                # Found a matching pair
                with open(os.path.join(test_folder, filename), 'r', encoding='utf-8') as f:
                    input_text = f.read().strip()
                with open(os.path.join(test_folder, json_file), 'r', encoding='utf-8') as f:
                    expected_output = json.load(f)

                test_cases.append(TestCase(
                    name=base_name,
                    input=input_text,
                    expected=expected_output
                ))

    return test_cases
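
# Illustrative test folder layout expected by load_test_cases. The file names
# below are hypothetical; only the paired .txt/.json convention matters:
#
#   tests/
#     extract_name.txt    <- raw input passed to the model
#     extract_name.json   <- expected JSON output for that input
#     summarize.txt
#     summarize.json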

def run_llm(client: OpenAI, prompt: str, input_text: str) -> str:
    """Run the LLM with the given prompt and input."""
    # Combine the prompt and the test case input into a single user message
    full_prompt = f"{prompt}\n\n{input_text}"

    response = client.chat.completions.create(
        model="gpt-3.5-turbo",  # Can be customized
        messages=[
            {"role": "user", "content": full_prompt}
        ]
    )

    # message.content is Optional[str]; fall back to an empty string
    return response.choices[0].message.content or ""

def score(actual_output: Union[str, Dict[str, Any]], expected_output: Dict[str, Any]) -> float:
    """Score the actual output against the expected output.

    This is a simple implementation that should be customized based on your specific needs.
    """
    # Convert string output to JSON if needed
    actual_json: Dict[str, Any] = {}
    if isinstance(actual_output, str):
        try:
            actual_json = json.loads(actual_output)
        except json.JSONDecodeError:
            # If the output isn't valid JSON, the score is 0
            return 0.0
    else:
        actual_json = actual_output

    # Simple scoring: check if keys match
    expected_keys = set(expected_output.keys())
    actual_keys = set(actual_json.keys())

    if not expected_keys:
        return 1.0  # If there are no expected keys, perfect score

    # Fraction of expected keys that are present
    common_keys = expected_keys.intersection(actual_keys)
    key_score = len(common_keys) / len(expected_keys)

    # For the common keys, check whether the values match exactly
    value_scores: List[float] = []
    for key in common_keys:
        expected_val = expected_output[key]
        actual_val = actual_json.get(key)
        value_scores.append(1.0 if expected_val == actual_val else 0.0)

    # Average value score
    avg_value_score = sum(value_scores) / len(value_scores) if value_scores else 1.0

    # Final score is the average of the key match and value match scores
    return (key_score + avg_value_score) / 2
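
# Worked example of the scoring above (values are illustrative):
#   expected = {"name": "Alice", "age": 30}
#   actual   = {"name": "Alice", "age": 31}
#   key_score       = 2/2 = 1.0   (both expected keys are present)
#   value scores    = 1.0 for "name", 0.0 for "age" -> average 0.5
#   final score     = (1.0 + 0.5) / 2 = 0.75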

def evaluate_prompt(prompt_path: str, test_folder: str, api_key: Optional[str] = None) -> EvaluationResults:
    """Evaluate a prompt on the test cases and return the results."""
    # Initialize the OpenAI client
    client = OpenAI(api_key=api_key)

    # Load the prompt and test cases
    prompt = load_prompt(prompt_path)
    test_cases = load_test_cases(test_folder)

    results = EvaluationResults(
        timestamp=datetime.now().isoformat(),
        prompt=prompt,
        prompt_path=prompt_path,
        total_score=0.0,
        cases=[]
    )

    # Process each test case
    for case in test_cases:
        actual_output = run_llm(client, prompt, case.input)

        # Try to parse the output as JSON for reporting purposes
        try:
            actual_json = json.loads(actual_output)
        except json.JSONDecodeError:
            actual_json = {"error": "Invalid JSON output"}

        case_score = score(actual_output, case.expected)

        results.cases.append(CaseResult(
            name=case.name,
            input=case.input,
            expected_output=case.expected,
            actual_output=actual_json,
            score=case_score
        ))
        results.total_score += case_score

    # Convert the accumulated total into an average score
    if test_cases:
        results.total_score /= len(test_cases)

    return results

def save_results(results: EvaluationResults, output_path: str) -> None:
    """Save evaluation results to a JSON file."""
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(results.__dict__, f, indent=2,
                  default=lambda o: o.__dict__ if hasattr(o, '__dict__') else o)
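
# Rough shape of the JSON written by save_results (field names follow the
# dataclasses above; the values shown here are illustrative):
# {
#   "timestamp": "2025-03-26T15:46:00",
#   "prompt": "...",
#   "prompt_path": "prompt.txt",
#   "total_score": 0.75,
#   "cases": [
#     {
#       "name": "extract_name",
#       "input": "...",
#       "expected_output": {"name": "Alice", "age": 30},
#       "actual_output": {"name": "Alice", "age": 31},
#       "score": 0.75
#     }
#   ]
# }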

def main() -> None:
    parser = argparse.ArgumentParser(description='Evaluate prompts using test cases')
    parser.add_argument('--prompt', required=True, help='Path to the prompt file')
    parser.add_argument('--tests', required=True, help='Path to the folder containing test cases')
    parser.add_argument('--output', default='results.json', help='Path to save results (default: results.json)')
    parser.add_argument('--api-key', help='OpenAI API key (or set the OPENAI_API_KEY environment variable)')

    args = parser.parse_args()

    # Use the API key from the arguments or the environment variable
    api_key = args.api_key or os.environ.get('OPENAI_API_KEY')
    if not api_key:
        print("Error: OpenAI API key not provided. Use --api-key or set the OPENAI_API_KEY environment variable.")
        return

    print(f"Evaluating prompt: {args.prompt}")
    print(f"Using test cases from: {args.tests}")

    # Run the evaluation
    results = evaluate_prompt(args.prompt, args.tests, api_key)

    # Save the results
    save_results(results, args.output)

    print(f"Evaluation complete! Total score: {results.total_score:.2f}")
    print(f"Results saved to: {args.output}")


if __name__ == "__main__":
    main()
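
# Example invocation (the script name, paths, and key below are placeholders):
#   export OPENAI_API_KEY=sk-...
#   python evaluate_prompt.py --prompt prompt.txt --tests ./tests --output results.json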