Skip to content

Instantly share code, notes, and snippets.

@AndBondStyle
Last active October 2, 2024 22:28
Show Gist options
  • Save AndBondStyle/3e0952291a55720f8c6bf57e5e51cc44 to your computer and use it in GitHub Desktop.
Save AndBondStyle/3e0952291a55720f8c6bf57e5e51cc44 to your computer and use it in GitHub Desktop.
FastAPI + Dependency Injector + SQLAlchemy
import asyncio
import os
from contextlib import asynccontextmanager
from typing import Any
import sqlalchemy as sa
from dependency_injector import providers
from dependency_injector.containers import DeclarativeContainer
from dependency_injector.wiring import Provide, inject
from fastapi import APIRouter, Depends, FastAPI
from sqlalchemy.ext.asyncio import (
AsyncSession,
async_scoped_session,
async_sessionmaker,
create_async_engine,
)
async def init_db_engine():
    """Create the async SQLAlchemy engine and dispose of it on shutdown.

    Used as the initializer of the `db_engine` Resource provider: the code before
    ``yield`` runs when the resource is initialized, the code after it when the
    resource is shut down.

    The connection string is read from the ``POSTGRES_DSN`` environment variable
    (``KeyError`` if it is unset). ``echo=True`` makes the engine log the SQL it
    emits.

    Disposal sits in ``finally`` so the engine's connection pool is released even
    when the generator is closed instead of being resumed normally (e.g. when it
    is finalized rather than shut down explicitly).
    """
    dsn = os.environ["POSTGRES_DSN"]
    engine = create_async_engine(dsn, echo=True)
    print("engine start")
    try:
        yield engine
    finally:
        print("engine stop")
        await engine.dispose()
class Container(DeclarativeContainer):
    """Dependency-injection container holding the database providers.

    An instance is created and wired to this module in `lifespan`, which also
    initializes its resources on startup and shuts them down on exit.
    """

    # Async engine produced by the `init_db_engine` generator (a Resource, so
    # its post-`yield` cleanup runs on `shutdown_resources()`).
    db_engine = providers.Resource(init_db_engine)
    # `async_sessionmaker` built once from the engine above.
    db_session_factory = providers.Resource(async_sessionmaker, db_engine)
    # Scoped-session registry keyed by `asyncio.current_task`, i.e. one session
    # per asyncio task (typically one per request for async endpoints).
    db_scoped_session = providers.ThreadSafeSingleton(
        async_scoped_session,
        session_factory=db_session_factory,
        scopefunc=asyncio.current_task,
    )
    # Placeholder only: `wrap_dependency` recognizes this provider and swaps in
    # `Depends(init_session)`, so its `None` value is not what views receive
    # when they request it through `Dep(...)`.
    db_session = providers.Object(None) # dummy provider
    # Example of an ordinary dependency; every call produces 123.
    something = providers.Factory(lambda: 123) # example of regular dependency
# Marker resolved when this module is wired by the container (see `lifespan`).
# `init_session` awaits it, so the wired value is expected to be an awaitable
# resolving to the `async_scoped_session` registry.
session_factory = Provide["db_scoped_session"]


# Async generator to use directly with fastapi's `Depends(...)`
async def init_session():
    """Yield an `AsyncSession` for the current request.

    The session comes from the task-scoped `async_scoped_session` registry, so
    code running in the same asyncio task shares it. It is closed when the
    request finishes.

    The registry entry for the current task is also removed afterwards. Without
    `remove()` the registry keeps one entry per request, keyed by the task
    object, and grows without bound.
    """
    scoped_session = await session_factory
    session = scoped_session()
    try:
        async with session:
            print("session before")
            yield session
            print("session after")
    finally:
        # Closes the task's session (already closed above; closing again is
        # harmless) and drops it from the registry.
        await scoped_session.remove()
# `Dep(...)` is shorthand for `Depends(Provide[...])`.
# The `db_session` dependency is special-cased: it maps straight onto the
# `init_session` generator instead of going through PDI, so FastAPI drives the
# per-request session lifecycle directly at no extra overhead.
def wrap_dependency(dependency: Any) -> Any:
    """Translate a container provider (or its name) into a FastAPI `Depends`.

    Args:
        dependency: a provider attribute of `Container`, or a provider name.

    Returns:
        `Depends(init_session)` for the `db_session` provider (by identity or by
        name), otherwise `Depends(Provide[dependency])`.
    """
    is_db_session = dependency is Container.db_session or dependency == "db_session"
    target = init_session if is_db_session else Provide[dependency]
    return Depends(target)


Dep = wrap_dependency  # short alias used in view signatures
# This function patches `APIRouter.api_route` so that PDI's `@inject` decorator
# is applied to every view. It must be called before any views are defined.
def fasatpi_auto_inject():
    """Patch `APIRouter.api_route` so every registered view is wrapped in `inject`.

    Route decorators such as `@app.get(...)` go through `api_route`. After
    patching, the view function is passed through `inject` before being
    registered.

    Calling this more than once is a no-op. Without the guard, a second call
    would wrap the already patched method, and every view would get `inject`
    applied repeatedly.
    """
    original = APIRouter.api_route
    if getattr(original, "_auto_inject_patched", False):
        return

    def api_route_patched(self, *args, **kwargs):
        print("api route patched:", kwargs.get("path"))
        decorator = original(self, *args, **kwargs)

        def inject_then_register(func):
            # Composition of two decorators: `inject` first, then FastAPI's.
            return decorator(inject(func))

        return inject_then_register

    api_route_patched._auto_inject_patched = True  # type: ignore[attr-defined]
    APIRouter.api_route = api_route_patched  # type: ignore


# Correctly spelled alias; the original name is kept for existing callers.
fastapi_auto_inject = fasatpi_auto_inject

# Call before any view definitions
fasatpi_auto_inject()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Own the DI container for the application's lifetime.

    Wires this module so its `Provide[...]` markers resolve, and initializes the
    container's resources before the app serves requests. After the app stops,
    it shuts those resources down and unwires the module.

    Teardown runs in ``finally`` so the resources (e.g. the database engine) are
    released even when startup fails partway or an exception is thrown into the
    lifespan.
    """
    container = Container()
    container.wire(modules=[__name__])
    try:
        await container.init_resources()  # type: ignore
        yield
    finally:
        # `shutdown_resources()` may return None when no async resource needs
        # shutting down (e.g. startup failed before any was initialized).
        shutdown = container.shutdown_resources()
        if shutdown is not None:
            await shutdown
        container.unwire()
app = FastAPI(lifespan=lifespan)


# No explicit `@inject` needed: the patched `api_route` applies it.
@app.get("/test")
async def test(
    db: AsyncSession = Dep(Container.db_session),  # -> Depends(init_session)
    something: int = Dep(Container.something),  # -> Depends(Provide[...])
):
    """Return the database server version together with the injected example value."""
    result = await db.execute(sa.text("select version()"))
    version = result.scalar()
    payload = {"version": version, "something": something}
    return payload
@AndBondStyle
Copy link
Author

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment