Skip to content

Instantly share code, notes, and snippets.

@daveebbelaar
Created August 17, 2024 13:22
Show Gist options
  • Save daveebbelaar/d24eafc6ace1c8f4a091062733b52437 to your computer and use it in GitHub Desktop.
Save daveebbelaar/d24eafc6ace1c8f4a091062733b52437 to your computer and use it in GitHub Desktop.
LLM Factory with Instructor
from typing import Any, Dict, List, Type
import instructor
from anthropic import Anthropic
from config.settings import get_settings
from openai import OpenAI
from pydantic import BaseModel, Field
class LLMFactory:
    """Build and use an Instructor-patched chat client for one LLM provider.

    The provider name selects both the settings section
    (``getattr(get_settings(), provider)``) and the client constructor.
    Supported providers are the keys of ``_CLIENT_INITIALIZERS``:
    ``"openai"``, ``"anthropic"`` and ``"llama"`` (an OpenAI-compatible
    endpoint reached through ``settings.base_url``).
    """

    # Provider name -> callable(settings) returning an Instructor-wrapped client.
    # The lambdas only touch the SDKs when called, so defining the table on the
    # class is cheap and lets __init__ validate the provider up front.
    _CLIENT_INITIALIZERS = {
        "openai": lambda s: instructor.from_openai(OpenAI(api_key=s.api_key)),
        "anthropic": lambda s: instructor.from_anthropic(
            Anthropic(api_key=s.api_key)
        ),
        # Uses the OpenAI client against a custom base_url and requests
        # Instructor's JSON mode instead of the default mode used for "openai".
        "llama": lambda s: instructor.from_openai(
            OpenAI(base_url=s.base_url, api_key=s.api_key),
            mode=instructor.Mode.JSON,
        ),
    }

    def __init__(self, provider: str):
        """Create a factory bound to ``provider``.

        Args:
            provider: Provider name; must be a key of ``_CLIENT_INITIALIZERS``
                and an attribute of the object returned by ``get_settings()``.

        Raises:
            ValueError: If ``provider`` is not supported.
        """
        # Validate before reading settings. Otherwise an unknown name fails
        # inside getattr() with AttributeError, and the intended ValueError
        # is never reached.
        if provider not in self._CLIENT_INITIALIZERS:
            raise ValueError(f"Unsupported LLM provider: {provider}")
        self.provider = provider
        self.settings = getattr(get_settings(), provider)
        self.client = self._initialize_client()

    def _initialize_client(self) -> Any:
        """Return the Instructor-wrapped client for ``self.provider``.

        Raises:
            ValueError: If ``self.provider`` has no registered initializer.
        """
        initializer = self._CLIENT_INITIALIZERS.get(self.provider)
        if initializer is None:
            raise ValueError(f"Unsupported LLM provider: {self.provider}")
        return initializer(self.settings)

    def create_completion(
        self, response_model: Type[BaseModel], messages: List[Dict[str, str]], **kwargs
    ) -> Any:
        """Request a completion parsed into ``response_model``.

        Args:
            response_model: Pydantic model class the reply is parsed into.
            messages: Chat messages as ``{"role": ..., "content": ...}`` dicts.
            **kwargs: Optional overrides for ``model``, ``temperature``,
                ``max_retries`` and ``max_tokens``. Each falls back to the
                provider settings. Any other keyword is ignored.

        Returns:
            Whatever ``self.client.chat.completions.create`` returns. With an
            Instructor client this is presumably an instance of
            ``response_model``.
        """
        completion_params = {
            "model": kwargs.get("model", self.settings.default_model),
            "temperature": kwargs.get("temperature", self.settings.temperature),
            "max_retries": kwargs.get("max_retries", self.settings.max_retries),
            "max_tokens": kwargs.get("max_tokens", self.settings.max_tokens),
            "response_model": response_model,
            "messages": messages,
        }
        return self.client.chat.completions.create(**completion_params)
if __name__ == "__main__":
    # Demo: send one sample question to OpenAI and print the structured reply.

    # Response schema. It is deliberately left without a docstring, because
    # pydantic would put one into the JSON schema sent to the model.
    class CompletionModel(BaseModel):
        response: str = Field(description="Your response to the user.")
        reasoning: str = Field(description="Explain your reasoning for the response.")

    user_question = "If it takes 2 hours to dry 1 shirt out in the sun, how long will it take to dry 5 shirts?"
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": user_question},
    ]

    factory = LLMFactory("openai")
    completion = factory.create_completion(response_model=CompletionModel, messages=messages)

    # Sanity check that the client handed back the requested model type.
    assert isinstance(completion, CompletionModel)

    print(f"Response: {completion.response}\n")
    print(f"Reasoning: {completion.reasoning}")
import os
from functools import lru_cache
from typing import Optional

from dotenv import load_dotenv
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict
# Populate os.environ from a .env file (python-dotenv) before any settings
# class below reads the environment.
load_dotenv()
class LLMProviderSettings(BaseSettings):
    """Generation settings shared by every LLM provider section.

    Provider subclasses add fields such as ``api_key`` and ``default_model``.
    """

    temperature: float = 0.0
    # None means no limit is configured at this level; AnthropicSettings overrides it.
    max_tokens: Optional[int] = None
    # Forwarded as ``max_retries`` to the chat client by LLMFactory.create_completion.
    max_retries: int = 3
class OpenAISettings(LLMProviderSettings):
    """Settings for the OpenAI provider, read from ``OPENAI_``-prefixed env vars.

    ``api_key`` comes from ``OPENAI_API_KEY``. ``OPENAI_DEFAULT_MODEL``,
    ``OPENAI_TEMPERATURE``, ``OPENAI_MAX_TOKENS`` and ``OPENAI_MAX_RETRIES``
    override the remaining fields.
    """

    # Without a prefix, pydantic-settings maps every field to an env var of the
    # same name. An unrelated API_KEY or DEFAULT_MODEL variable would then
    # silently override these values, and for every provider at once.
    model_config = SettingsConfigDict(env_prefix="OPENAI_")

    # Stays None when OPENAI_API_KEY is unset. The previous os.getenv() default
    # also produced None in that case, despite a plain ``str`` annotation.
    api_key: Optional[str] = None
    default_model: str = "gpt-4o"
class AnthropicSettings(LLMProviderSettings):
    """Settings for the Anthropic provider, read from ``ANTHROPIC_``-prefixed env vars.

    ``api_key`` comes from ``ANTHROPIC_API_KEY``. ``ANTHROPIC_DEFAULT_MODEL``,
    ``ANTHROPIC_MAX_TOKENS``, etc. override the remaining fields.
    """

    # Prefix the env lookup, so that generic names such as API_KEY or MAX_TOKENS
    # cannot silently override this provider's configuration.
    model_config = SettingsConfigDict(env_prefix="ANTHROPIC_")

    # Stays None when ANTHROPIC_API_KEY is unset, matching the old os.getenv() default.
    api_key: Optional[str] = None
    default_model: str = "claude-3-5-sonnet-20240620"
    # Overrides the base class's None with a concrete cap.
    max_tokens: int = 1024
class LlamaSettings(LLMProviderSettings):
    """Settings for the local Llama provider, read from ``LLAMA_``-prefixed env vars.

    For example, ``LLAMA_BASE_URL`` or ``LLAMA_DEFAULT_MODEL`` override the
    defaults below.
    """

    # Prefix the env lookup, so that generic names such as BASE_URL or API_KEY
    # cannot silently redirect this provider.
    model_config = SettingsConfigDict(env_prefix="LLAMA_")

    api_key: str = "key"  # required, but not used
    default_model: str = "llama3"
    base_url: str = "http://localhost:11434/v1"
class Settings(BaseSettings):
    """Top-level application settings with one section per LLM provider.

    Section attribute names match provider names, because LLMFactory looks
    them up with ``getattr(settings, provider)``.
    """

    app_name: str = "GenAI Project Template"
    # default_factory builds each provider section when Settings() is created,
    # i.e. on the first get_settings() call. Before, the sections were built
    # once at import time. They now read the environment as it is at that
    # point.
    openai: OpenAISettings = Field(default_factory=OpenAISettings)
    anthropic: AnthropicSettings = Field(default_factory=AnthropicSettings)
    llama: LlamaSettings = Field(default_factory=LlamaSettings)
@lru_cache()
def get_settings():
    """Return the shared Settings object, constructing it on the first call.

    ``lru_cache`` memoises the result, so every caller receives the same instance.
    """
    settings = Settings()
    return settings
@harshdy
Copy link

harshdy commented Dec 9, 2024

Hey @daveebbelaar, I'm also trying to create a template for GenAI projects. I was wondering if you could share your template with us; that would be very helpful, since we could customize it for our requirements.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment