Created
December 10, 2024 06:37
-
-
Save ColeMurray/c9b257a79e8dca29ae7b80372c40bf97 to your computer and use it in GitHub Desktop.
A basic OpenAI completion decorator with structured outputs
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import functools | |
from typing import Type, TypeVar, Callable, Any | |
from pydantic import BaseModel | |
from openai import OpenAI | |
# Type variable for the Pydantic model class the caller wants back.
T = TypeVar('T', bound=BaseModel)


def openai_call(
    _model: str, response_model: Type[T]
) -> Callable[[Callable[..., str]], Callable[..., T]]:
    """
    Decorator that handles OpenAI API calls using structured output parsing.

    The decorated function must return the user prompt as a string. Calling
    the decorated function sends that prompt to the chat completions parse
    endpoint and returns the response parsed into ``response_model``.

    Args:
        _model: The OpenAI model to use (e.g., "gpt-4", "gpt-3.5-turbo")
        response_model: A Pydantic model class that defines the expected
            response structure

    Returns:
        A decorator that turns a prompt-building function into a function
        returning an instance of ``response_model``.

    Raises:
        RuntimeError: Raised by the decorated function when the API call
            fails (the original exception is chained as ``__cause__``) or
            when the response carries no parsed content.
    """
    # One client per decorated function, created at decoration time.
    client = OpenAI()

    def decorator(func: Callable[..., str]) -> Callable[..., T]:
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> T:
            # Get the prompt from the decorated function. Errors raised while
            # building the prompt propagate unchanged.
            prompt = func(*args, **kwargs)

            try:
                # Make the API call using parse endpoint
                completion = client.beta.chat.completions.parse(
                    model=_model,
                    messages=[
                        {"role": "system", "content": "Extract the requested information into structured format."},
                        {"role": "user", "content": prompt},
                    ],
                    response_format=response_model,
                )
                message = completion.choices[0].message
                parsed = message.parsed
            except Exception as e:
                # Chain the cause so the original traceback is not lost.
                raise RuntimeError(f"OpenAI API call failed: {e}") from e

            if parsed is None:
                # No structured payload came back (e.g. the model refused).
                # Fail loudly instead of returning None from a function that
                # is typed as returning T.
                refusal = getattr(message, "refusal", None)
                detail = (
                    f"model refused: {refusal}"
                    if refusal
                    else "no parsed content returned"
                )
                raise RuntimeError(f"OpenAI API call failed: {detail}")

            # Return the parsed response
            return parsed

        return wrapper

    return decorator
# Example usage: | |
from pydantic import BaseModel | |
class Capital(BaseModel):
    """Structured result holding a capital city and the country it belongs to."""

    # Name of the capital city.
    city: str
    # Name of the country.
    country: str
@openai_call("gpt-4o", response_model=Capital)
def extract_capital(query: str) -> str:
    """Build the prompt asking for the capital city and country in *query*.

    The ``openai_call`` decorator sends the returned prompt to the model, so
    callers receive a ``Capital`` instance rather than this string.
    """
    prompt = f"Extract the capital city and country from this text: {query}"
    return prompt
# Example usage
def _run_example() -> None:
    """Run one sample extraction and print the result, or the error if it fails."""
    try:
        result = extract_capital("The capital of France is Paris")
        print(f"City: {result.city}")
        print(f"Country: {result.country}")
    except Exception as exc:
        print(f"Error: {exc}")


if __name__ == "__main__":
    _run_example()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment