Last active
January 21, 2025 15:38
-
-
Save archydeberker/655f8520080a39dabde43785299a39e9 to your computer and use it in GitHub Desktop.
Demonstrate the repository pattern for session management in FastAPI.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import functools
import os
from typing import List, Optional
from uuid import UUID

from fastapi import Depends, FastAPI, HTTPException, Request
from pydantic import BaseModel
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
class User(BaseModel):
    # Pydantic schema describing a user as exposed by the API.
    # With SQLModel a single class can be both the table model and this schema,
    # because SQLModel models are Pydantic models. With plain SQLAlchemy you need a
    # separate mapped ORM class, and its instances must be serialized into this shape.
    id: UUID
    name: str
class DbClient:
    """Own the SQLAlchemy engine and the session factory bound to it.

    Each instance calls ``create_engine`` once, so each instance has its own engine.
    """

    def __init__(self, db_url: str):
        # Engine creation reference: https://docs.sqlalchemy.org/en/20/core/engines.html
        engine = create_engine(db_url)
        self.db_engine = engine
        self.session_maker = sessionmaker(bind=engine)

    def get_session(self) -> Session:
        """Create and return a new Session from this client's factory.

        The session is not closed here; callers should use it as a context manager.
        """
        # Session API reference: https://docs.sqlalchemy.org/en/20/orm/session_api.html#session-api
        new_session = self.session_maker()
        return new_session
class DbRepo:
    """Repository for user queries.

    Every method opens its own short-lived session from ``db_client``. The session is
    closed when the method returns, so returned ORM objects are detached from any session.
    """

    def __init__(self, db_client: DbClient):
        self.db = db_client

    def get_user_by_username(self, user_name: str) -> Optional[User]:
        """Return the first user whose ``name`` equals ``user_name``, or None if none match."""
        # NOTE(review): `User` must be a SQLAlchemy-mapped class for query()/filter() to
        # work; a plain Pydantic model cannot be queried. Confirm the ORM model used here.
        with self.db.get_session() as session:
            return session.query(User).filter(User.name == user_name).first()

    def get_users(self) -> List[User]:
        """Return every user row (an empty list when there are none)."""
        with self.db.get_session() as session:
            return session.query(User).all()
@functools.lru_cache(maxsize=None)
def _get_db_client(db_url: str) -> DbClient:
    """Return the process-wide DbClient for ``db_url``, creating it on first use."""
    # create_engine() builds a connection pool. Creating one per request leaks pools and
    # defeats pooling, so cache the client by URL. The key space is just the configured
    # connection string(s), so an unbounded cache stays tiny.
    return DbClient(db_url=db_url)


def get_db_repo() -> DbRepo:
    """Return a DbRepo backed by the shared DbClient for the current environment.

    The repo wrapper is created per call; FastAPI caches it within a single request.
    The underlying DbClient (and its engine) is reused across requests.

    Raises:
        RuntimeError: if the DB_CONNECTION_STR environment variable is unset or empty.
    """
    db_url = os.environ.get("DB_CONNECTION_STR")
    if not db_url:
        raise RuntimeError("DB_CONNECTION_STR environment variable is not set")
    return DbRepo(db_client=_get_db_client(db_url))
def get_current_user(request: Request, repo: DbRepo = Depends(get_db_repo)) -> User:
    """Resolve the user making this request.

    The ``Request`` annotation makes FastAPI inject the incoming request. Without it,
    ``request`` would be treated as a required query parameter.

    Raises:
        HTTPException: 401 if no username is available or no matching user exists.
    """
    # The repo is a dependency. FastAPI caches dependencies within a request, so a handler
    # that also depends on get_db_repo receives exactly the same repo instance as this one.
    # See https://fastapi.tiangolo.com/tutorial/dependencies/sub-dependencies/#using-the-same-dependency-multiple-times
    # NOTE(review): placeholder identity lookup. This assumes upstream auth middleware stored
    # the username on request.state; replace with however you actually identify the user.
    username = getattr(request.state, "username", None)
    if not username:
        raise HTTPException(status_code=401)
    user = repo.get_user_by_username(username)
    if user is None:
        # Don't let a None "current user" flow into handlers.
        raise HTTPException(status_code=401)
    return user
def current_user_has_permission(user: User) -> bool:
    """Decide whether ``user`` may access the protected endpoint.

    Placeholder: every user is currently allowed. Replace with a real authorization rule.
    """
    allowed = True  # TODO: implement the actual authorization check
    return allowed
app = FastAPI()


# List all users, guarded by the permission check.
# get_current_user and this handler both depend on get_db_repo. FastAPI caches that
# dependency per request, so both share one repo instance.
# Responds 403 when the current user lacks permission.
@app.get("/users")
def get_users(current_user: User = Depends(get_current_user), repo: DbRepo = Depends(get_db_repo)) -> List[User]:
    permitted = current_user_has_permission(current_user)
    if permitted:
        return repo.get_users()
    raise HTTPException(status_code=403)
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment