@daveebbelaar
Created August 17, 2024 13:22
LLM Factory with Instructor
from typing import Any, Dict, List, Type

import instructor
from anthropic import Anthropic
from config.settings import get_settings
from openai import OpenAI
from pydantic import BaseModel, Field


class LLMFactory:
    def __init__(self, provider: str):
        self.provider = provider
        # Pull the provider-specific settings block (openai/anthropic/llama).
        self.settings = getattr(get_settings(), provider)
        self.client = self._initialize_client()

    def _initialize_client(self) -> Any:
        # Map each provider to a function that wraps its SDK client with instructor.
        client_initializers = {
            "openai": lambda s: instructor.from_openai(OpenAI(api_key=s.api_key)),
            "anthropic": lambda s: instructor.from_anthropic(
                Anthropic(api_key=s.api_key)
            ),
            "llama": lambda s: instructor.from_openai(
                OpenAI(base_url=s.base_url, api_key=s.api_key),
                mode=instructor.Mode.JSON,
            ),
        }
        initializer = client_initializers.get(self.provider)
        if initializer:
            return initializer(self.settings)
        raise ValueError(f"Unsupported LLM provider: {self.provider}")

    def create_completion(
        self, response_model: Type[BaseModel], messages: List[Dict[str, str]], **kwargs
    ) -> Any:
        # Per-call kwargs override the provider defaults from settings.
        completion_params = {
            "model": kwargs.get("model", self.settings.default_model),
            "temperature": kwargs.get("temperature", self.settings.temperature),
            "max_retries": kwargs.get("max_retries", self.settings.max_retries),
            "max_tokens": kwargs.get("max_tokens", self.settings.max_tokens),
            "response_model": response_model,
            "messages": messages,
        }
        return self.client.chat.completions.create(**completion_params)
if __name__ == "__main__":

    class CompletionModel(BaseModel):
        response: str = Field(description="Your response to the user.")
        reasoning: str = Field(description="Explain your reasoning for the response.")

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {
            "role": "user",
            "content": "If it takes 2 hours to dry 1 shirt out in the sun, how long will it take to dry 5 shirts?",
        },
    ]

    llm = LLMFactory("openai")
    completion = llm.create_completion(
        response_model=CompletionModel,
        messages=messages,
    )
    assert isinstance(completion, CompletionModel)
    print(f"Response: {completion.response}\n")
    print(f"Reasoning: {completion.reasoning}")
from typing import Optional
from pydantic_settings import BaseSettings
from functools import lru_cache
from dotenv import load_dotenv
import os

load_dotenv()


class LLMProviderSettings(BaseSettings):
    temperature: float = 0.0
    max_tokens: Optional[int] = None
    max_retries: int = 3


class OpenAISettings(LLMProviderSettings):
    api_key: str = os.getenv("OPENAI_API_KEY")
    default_model: str = "gpt-4o"


class AnthropicSettings(LLMProviderSettings):
    api_key: str = os.getenv("ANTHROPIC_API_KEY")
    default_model: str = "claude-3-5-sonnet-20240620"
    max_tokens: int = 1024


class LlamaSettings(LLMProviderSettings):
    api_key: str = "key"  # required, but not used
    default_model: str = "llama3"
    base_url: str = "http://localhost:11434/v1"


class Settings(BaseSettings):
    app_name: str = "GenAI Project Template"
    openai: OpenAISettings = OpenAISettings()
    anthropic: AnthropicSettings = AnthropicSettings()
    llama: LlamaSettings = LlamaSettings()


@lru_cache
def get_settings():
    return Settings()
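These settings read API keys from the environment via python-dotenv, so a .env file in the project root along these lines is assumed (the values are placeholders):

OPENAI_API_KEY=sk-...
ANTHROPIC_API_KEY=sk-ant-...

The llama provider needs no real key; it points at a local Ollama server's OpenAI-compatible endpoint.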
@jd-solanki commented Aug 22, 2024

Here are the snippets with improved types. They still require some effort to be made 100% type safe.

from collections.abc import Generator
from typing import Any, Literal, Protocol, overload

import instructor
from anthropic import Anthropic
from openai import OpenAI
from pydantic import BaseModel

from api.custom_types import LiteralFalse, LiteralTrue
from api.settings import AnthropicSettings, OllamaSettings, OpenAISettings, get_settings

type LLMProviders = Literal["ollama", "openai", "anthropic"]
type LLMSettings = OpenAISettings | AnthropicSettings | OllamaSettings


class ClientInitializerCallback(Protocol):
    def __call__(self, settings: LLMSettings) -> instructor.Instructor: ...


type ClientInitializer = dict[LLMProviders, ClientInitializerCallback]


class LLMFactory:
    def __init__(self, provider: LLMProviders) -> None:
        self.provider: LLMProviders = provider
        self.settings: LLMSettings = getattr(get_settings(), provider)
        self.client: instructor.Instructor = self._initialize_client()

    def _initialize_client(self) -> instructor.Instructor:
        client_initializers: ClientInitializer = {
            "openai": lambda settings: instructor.from_openai(OpenAI(api_key=settings.api_key)),
            "anthropic": lambda settings: instructor.from_anthropic(Anthropic(api_key=settings.api_key)),
            "ollama": lambda settings: instructor.from_openai(
                OpenAI(base_url=settings.base_url, api_key=settings.api_key),  # type: ignore Ollama setting will have `settings.base_url`
                mode=instructor.Mode.JSON,
            ),
        }

        initializer = client_initializers.get(self.provider)
        if initializer:
            return initializer(self.settings)

        err_msg = f"Unsupported LLM provider: {self.provider}"
        raise ValueError(err_msg)

    def create_completion[T: BaseModel](
        self,
        response_model: type[T],  # the model class goes in; instances come out
        messages: list[dict[str, str]],
        **kwargs: Any,
    ) -> T | Generator[T, None, None]:
        completion_params = {
            "model": kwargs.get("model") or self.settings.default_model,
            "temperature": kwargs.get("temperature", self.settings.temperature),
            "max_retries": kwargs.get("max_retries", self.settings.max_retries),
            "max_tokens": kwargs.get("max_tokens", self.settings.max_tokens),
            "response_model": response_model,
            "messages": messages,
        }

        return self.client.chat.completions.create(**completion_params)  # type: ignore Instructor needs to improve type hints
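The snippet above imports LiteralFalse, LiteralTrue, and overload but never uses them; they are presumably reserved for typed @overload signatures on a streaming flag. Since api/custom_types.py is not included in the gist, a plausible definition (an assumption, not the commenter's actual file) would be:

# api/custom_types.py -- assumed contents, not part of the original snippet
from typing import Literal

type LiteralTrue = Literal[True]
type LiteralFalse = Literal[False]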
from functools import lru_cache
from typing import Literal

from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict

from api.paths import root_dir

env_file_path = root_dir / ".env"


class LLMProviderSettings(BaseSettings):
    model_config = SettingsConfigDict(env_file=env_file_path, env_file_encoding="utf-8", extra="ignore")

    temperature: float = 0.0
    max_tokens: int | None = None
    max_retries: int = 3


class OpenAISettings(LLMProviderSettings):
    api_key: str | None = Field(alias="OPENAI_API_KEY", default=None)
    default_model: str = "gpt-4o"


class AnthropicSettings(LLMProviderSettings):
    api_key: str | None = Field(alias="ANTHROPIC_API_KEY", default=None)
    default_model: str = "claude-3-5-sonnet-20240620"
    max_tokens: int | None = 1024


class OllamaSettings(LLMProviderSettings):
    api_key: str = "key"  # required, but not used
    default_model: str = "llama3.1"
    base_url: str = "http://localhost:11434/v1"


class Settings(BaseSettings):
    model_config = SettingsConfigDict(env_file=env_file_path, env_file_encoding="utf-8")

    app_name: str = "GenAI Project Template"
    openai: OpenAISettings = OpenAISettings()
    anthropic: AnthropicSettings = AnthropicSettings()
    ollama: OllamaSettings = OllamaSettings()


@lru_cache
def get_settings():
    return Settings()
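api/paths.py is not shown either; a minimal sketch, under the assumption that the api package sits one level below the project root:

# api/paths.py -- assumed contents, not part of the original snippet
from pathlib import Path

# api/paths.py -> api/ -> project root
root_dir = Path(__file__).resolve().parent.parent

With that in place, the settings classes resolve .env relative to the project root rather than the current working directory.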

@harshdy commented Dec 9, 2024

hey @daveebbelaar, I'm also trying to create a template for a GenAI project. I was wondering if you could share your template with us; that would be very helpful for customizing it to our requirements.