Last active
December 10, 2023 13:43
-
-
Save Compro-Prasad/87dc9942e94296e4d98a11403c915135 to your computer and use it in GitHub Desktop.
SQLAlchemy base template for FastAPI projects
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os
import re
from datetime import datetime
from typing import Any, AsyncGenerator, Generator

from alembic_utils.pg_function import PGFunction
from alembic_utils.pg_trigger import PGTrigger
from sqlalchemy import create_engine
from sqlalchemy import func
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.ext.declarative import as_declarative, declared_attr
from sqlalchemy.orm import Mapped as T
from sqlalchemy.orm import mapped_column as column
from sqlalchemy.orm import sessionmaker
def updated_at_trigger(tablename):
    """Build an alembic_utils ``PGTrigger`` that maintains ``updated_at``.

    The trigger is created in the ``public`` schema under the name
    ``<tablename>_set_updated_at_on_update``. It fires BEFORE UPDATE, once
    per row, and calls the ``set_updated_at()`` trigger function.

    Args:
        tablename: Name of the table the trigger is attached to. It is
            interpolated verbatim into the trigger SQL, so it must be a
            trusted identifier (e.g. a model's ``__tablename__``).

    Returns:
        The ``PGTrigger`` entity, ready to be registered with alembic_utils.
    """
    trigger_name = f"{tablename}_set_updated_at_on_update"
    # NOTE(review): ``on_entity`` is passed unqualified; alembic_utils examples
    # use a schema-qualified name such as "public.account" — confirm that
    # autogenerate comparisons are stable with the bare table name.
    trigger_sql = f"""
BEFORE UPDATE ON {tablename}
FOR EACH ROW
EXECUTE PROCEDURE set_updated_at();
"""
    return PGTrigger(
        on_entity=tablename,
        signature=trigger_name,
        definition=trigger_sql,
        schema="public",
    )
# The single PL/pgSQL trigger function invoked by every trigger that
# ``updated_at_trigger`` builds: it overwrites NEW.updated_at with now() and
# returns the modified row. It is attached as an attribute of the factory so the
# function and the per-table triggers can be registered together (presumably in
# the alembic_utils entity list — confirm against the migration env).
updated_at_trigger.function = PGFunction(
    signature="set_updated_at()",  # shared by any table with an updated_at column
    schema="public",
    definition="""
RETURNS TRIGGER AS $$
BEGIN
NEW.updated_at := now();
return NEW;
END;
$$ language 'plpgsql'
""",
)
# Compiled once at import time and used by ``Base.__tablename__`` below.
# ``pat1`` flags runs of two or more consecutive capitals (e.g. "HTTPLog"),
# which would not snake-case cleanly.
pat1 = re.compile(r"[A-Z]{2,}")
# ``pat2`` matches the empty position before every capital letter except at the
# very start of the string; an underscore is inserted there.
pat2 = re.compile(r"(?<!^)(?=[A-Z])")


@as_declarative()
class Base:
    """Declarative base shared by all ORM models.

    Every subclass inherits an auto-incrementing integer ``id`` primary key,
    ``created_at``/``updated_at`` columns defaulted to ``now()`` by the
    database, and a table name derived from the class name
    (``UserProfile`` -> ``user_profile``).
    """

    id: T[int] = column(primary_key=True, autoincrement=True)
    created_at: T[datetime] = column(server_default=func.now())
    # Only a server default is declared. Refreshing this column on UPDATE is
    # presumably handled by the ``set_updated_at`` trigger (``updated_at_trigger``),
    # which has to be registered for each table.
    # NOTE(review): the ORM is not told the value changes server-side, so an
    # in-memory instance keeps its old ``updated_at`` after an update.
    updated_at: T[datetime] = column(server_default=func.now())

    @declared_attr
    def __tablename__(cls) -> str:
        """Derive the table name from the class name, in snake_case.

        Raises:
            ValueError: If the class name contains two or more consecutive
                capital letters, which cannot be split unambiguously.
        """
        name = cls.__name__
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if pat1.search(name):
            raise ValueError("Use proper camel case for model names")
        return pat2.sub("_", name).lower()
# Read the connection URL once and fail fast with a clear message. Otherwise a
# missing variable surfaces as an opaque error inside create_engine(None) or
# None.replace(...).
_database_url = os.getenv("DATABASE_URL")
if not _database_url:
    raise RuntimeError("DATABASE_URL environment variable is not set")

# Synchronous engine and session factory (used by ``get_session``).
# pool_pre_ping tests pooled connections before handing them out.
engine = create_engine(_database_url, pool_pre_ping=True)
session_maker = sessionmaker(bind=engine)

# The async engine needs the asyncpg driver. Rewrite only the URL scheme
# prefix, rather than every "postgresql://" occurrence anywhere in the string.
_sync_prefix = "postgresql://"
if _database_url.startswith(_sync_prefix):
    _async_database_url = "postgresql+asyncpg://" + _database_url[len(_sync_prefix):]
else:
    _async_database_url = _database_url

aio_engine = create_async_engine(
    _async_database_url,
    pool_pre_ping=True,
)

# Must be bound to the *async* engine: AsyncSession cannot run on the sync
# ``engine``. With expire_on_commit=False, accessing attributes after commit
# does not trigger implicit lazy IO, which fails in async code.
aio_session_maker = async_sessionmaker(aio_engine, expire_on_commit=False)
def get_session() -> Generator[Any, None, None]:
    """Yield a new synchronous session from ``session_maker``, then close it.

    Written as a generator so it can serve as a FastAPI dependency
    (presumably used via ``Depends(get_session)`` — confirm at call sites).
    The session is closed when the generator is finalized, whether or not
    the consumer raised.
    """
    # Create the session before entering ``try``. If construction fails, that
    # error propagates unchanged instead of ``finally`` raising
    # UnboundLocalError on ``db``.
    db = session_maker()
    try:
        yield db
    finally:
        db.close()
async def get_aio_session() -> AsyncGenerator[AsyncSession, None]:
    """Yield a new ``AsyncSession`` from ``aio_session_maker``, then close it.

    The async counterpart of ``get_session``, suitable as a FastAPI
    dependency. The session is awaited closed when the generator is
    finalized, whether or not the consumer raised.
    """
    # Create the session before entering ``try`` so that a construction failure
    # is not masked by an UnboundLocalError from ``finally``.
    db = aio_session_maker()
    try:
        yield db
    finally:
        await db.close()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment