Last active
December 5, 2024 08:42
-
-
Save kklemon/da3716b1a2f40ff3bd5a2955db5625d0 to your computer and use it in GitHub Desktop.
FastAPI + SQLModel + LangChain + Structured Generation Demo
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import functools
import os
from contextlib import asynccontextmanager

from fastapi import FastAPI, Query
from langchain_openai import ChatOpenAI
from pydantic.json_schema import SkipJsonSchema
from sqlmodel import Field, Session, SQLModel, create_engine, select
# Alternative design without the SkipJsonSchema trick: keep the LLM-facing
# fields on a plain base model and add the primary key only on the table
# model, e.g.
#
#     class BaseJoke(SQLModel):
#         joke: str
#
#     class Joke(BaseJoke, table=True):
#         id: int | None = Field(default=None, primary_key=True)
#
# Perhaps a bit hacky, but instead we annotate the fields to exclude from
# structured generation with SkipJsonSchema, which avoids defining a separate
# reduced base class.
#
# NOTE(review): no class docstring on purpose. Pydantic presumably copies a
# model docstring into the JSON schema's "description", and that schema is
# what with_structured_output sends to the LLM. Verify before adding one.
class Joke(SQLModel, table=True):
    # Primary key assigned by the database; defaults to None until the row is
    # committed. SkipJsonSchema hides it from the JSON schema, so the LLM is
    # never asked to produce it.
    id: SkipJsonSchema[int | None] = Field(default=None, primary_key=True)
    # The joke text: the only field the LLM fills in.
    joke: str
# SQLite database file; a relative path, so it resolves against the process's
# working directory.
sqlite_file_name = "database.sqlite"
sqlite_url = "sqlite:///" + sqlite_file_name

# Read eagerly: importing this module raises KeyError if the key is unset.
openai_api_key = os.environ["OPENAI_API_KEY"]

# The sync endpoints may run on worker threads, so let the SQLite connection
# be used outside the thread that opened it.
connect_args = dict(check_same_thread=False)
# echo=True logs every emitted SQL statement (useful for this demo).
engine = create_engine(sqlite_url, connect_args=connect_args, echo=True)

# Tables are created here at import time and again in the app's lifespan
# hook; create_all skips tables that already exist, so repeating it is harmless.
SQLModel.metadata.create_all(engine)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: ensure the schema exists, then serve.

    Everything before ``yield`` runs once at startup. Nothing runs after it,
    so there is no shutdown work.
    """
    metadata = SQLModel.metadata
    # Only missing tables are created; existing ones are left untouched.
    metadata.create_all(bind=engine)
    yield


app = FastAPI(lifespan=lifespan)
@functools.lru_cache(maxsize=1)
def _joke_generator():
    """Return the structured-output LLM runnable that produces ``Joke`` objects.

    Built on first use and cached, so all requests share one ChatOpenAI client
    instead of constructing a new client on every call. Presumably this also
    lets the client reuse its connections.
    """
    llm = ChatOpenAI(model="gpt-4o-mini", api_key=openai_api_key, temperature=1.5)
    # method="json_schema" constrains the reply to Joke's JSON schema. `id` is
    # hidden from that schema by SkipJsonSchema, so only `joke` is generated.
    return llm.with_structured_output(Joke, method="json_schema")


@app.post("/jokes/generate")
def generate_joke():
    """Ask the LLM for a new joke, store it, and return the persisted row.

    Both the LLM call and the database session are synchronous (blocking).

    Returns:
        The ``Joke`` instance after commit, with its database-assigned ``id``.
    """
    joke = _joke_generator().invoke(
        "Tell me a really funny joke. Respond in JSON format."
    )
    with Session(engine) as session:
        session.add(joke)
        session.commit()
        # Reload from the database so the generated primary key is populated
        # on the object we return.
        session.refresh(joke)
        return joke
@app.get("/jokes/")
def read_heroes(
    offset: int = Query(default=0, ge=0),
    limit: int | None = Query(default=None, ge=0),
):
    """Return stored jokes ordered by id, optionally paginated.

    The function name is a leftover from the SQLModel "heroes" tutorial. It is
    kept unchanged because FastAPI derives the route's OpenAPI operationId
    from the endpoint function's name.

    Args:
        offset: Number of leading rows to skip (default 0).
        limit: Maximum number of rows to return. ``None`` (the default)
            returns every remaining row, which matches the previous
            unpaginated behavior.

    Returns:
        A list of ``Joke`` rows.
    """
    # An explicit ORDER BY makes pagination deterministic.
    statement = select(Joke).order_by(Joke.id)
    if offset:
        statement = statement.offset(offset)
    if limit is not None:
        statement = statement.limit(limit)
    with Session(engine) as session:
        jokes = session.exec(statement).all()
        return jokes
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment