from __future__ import annotations import inspect from functools import wraps from typing import List, Optional, Type from sqlalchemy.ext.asyncio import create_async_engine from sqlmodel import Session, SQLModel, create_engine, select from sqlmodel.ext.asyncio.session import AsyncSession class BaseClient: def __init__(self, connection_string: str): self.engine = create_engine(connection_string) self.models: list[Type[SQLModel]] = [] class Client(BaseClient): def create_db_and_tables(self): """Creates the db and all the required tables based on the sqlmodels""" SQLModel.metadata.create_all(self.engine) class AsyncClient(BaseClient): def __init__(self, connection_string: str): super().__init__(connection_string) self.async_engine = create_async_engine(connection_string) async def create_db_and_tables(self): """Creates the db and all the required tables based on the sqlmodels""" async with self.async_engine.begin() as conn: await conn.run_sync(SQLModel.metadata.create_all) def create_method_with_model_signature(func, model: Type[SQLModel]): @wraps(func) def wrapper(self, *args, **kwargs): return func(self, *args, **kwargs) # Get the original signature sig = inspect.signature(func) # Get the model's fields model_fields = model.__fields__ # Create new parameters based on the model's fields new_params = [ inspect.Parameter( "self", inspect.Parameter.POSITIONAL_ONLY, ) ] for field_name, field in model_fields.items(): if field.is_required(): param = inspect.Parameter( field_name, inspect.Parameter.KEYWORD_ONLY, annotation=field.annotation, ) else: param = inspect.Parameter( field_name, inspect.Parameter.KEYWORD_ONLY, annotation=field.annotation, default=field.default, ) new_params.append(param) wrapper.__signature__ = sig.replace(parameters=new_params) return wrapper def generate_async_client_class(db_models: List[Type[SQLModel]]) -> Type[AsyncClient]: for model in db_models: create_async_methods(AsyncClient, model) return AsyncClient def generate_sync_client_class(db_models: List[Type[SQLModel]]) -> Type[Client]: for model in db_models: create_sync_methods(Client, model) return Client def create_expression_from_kwargs(model, **kwargs): for key, value in kwargs.items(): if isinstance(value, tuple): yield value[1](model.__dict__[key], value[0]) elif isinstance(value, list): yield model.__dict__[key].in_(value) else: yield model.__dict__[key] == value def create_async_methods(AsyncClient: Type[AsyncClient], model: Type[SQLModel]): model_name = model.__name__ lower_model_name = model_name.lower() # Generate create method def create_create_method(model): async def create_method(self: AsyncClient, *args, **kwargs) -> model: async with AsyncSession(self.async_engine) as session: db_item = model(*args, **kwargs) session.add(db_item) await session.commit() await session.refresh(db_item) return db_item return create_method_with_model_signature(create_method, model) setattr(AsyncClient, f"create_{lower_model_name}", create_create_method(model)) # Generate get method def create_get_method(model): async def get_method(self: AsyncClient, **kwargs) -> Optional[model]: async with AsyncSession(self.async_engine) as session: statement = select(model).where( *create_expression_from_kwargs(model, **kwargs) ) result = await session.exec(statement) return result.one() return create_method_with_model_signature(get_method, model) setattr(AsyncClient, f"get_{lower_model_name}", create_get_method(model)) # Generate get_all method def create_get_all_method(model): async def get_all_method(self: AsyncClient) -> List[model]: async with AsyncSession(self.async_engine) as session: statement = select(model) result = await session.exec(statement) return result.all() return create_method_with_model_signature(get_all_method, model) setattr(AsyncClient, f"get_all_{lower_model_name}s", create_get_all_method(model)) # Generate update method def create_update_method(model): async def update_method(self: AsyncClient, **kwargs) -> Optional[model]: async with AsyncSession(self.async_engine) as session: statement = select(model).where( *create_expression_from_kwargs(model, **kwargs) ) result = await session.exec(statement) db_item = result.one() if db_item: for key, value in kwargs.items(): setattr(db_item, key, value) await session.commit() await session.refresh(db_item) return db_item return create_method_with_model_signature(update_method, model) setattr(AsyncClient, f"update_{lower_model_name}", create_update_method(model)) # Generate delete method def create_delete_method(model): async def delete_method(self: AsyncClient, **kwargs) -> bool: async with AsyncSession(self.async_engine) as session: statement = select(model).where( *create_expression_from_kwargs(model, **kwargs) ) result = await session.exec(statement) db_item = result.one() if db_item: await session.delete(db_item) await session.commit() return True return False return create_method_with_model_signature(delete_method, model) setattr(AsyncClient, f"delete_{lower_model_name}", create_delete_method(model)) def create_sync_methods(Client: Type[Client], model: Type[SQLModel]): model_name = model.__name__ lower_model_name = model_name.lower() # Generate get method def create_get_method(model): def get_method(self: Client, **kwargs) -> Optional[model]: with Session(self.engine) as session: statement = select(model).where( *create_expression_from_kwargs(model, **kwargs) ) result = session.exec(statement) return result.one_or_none() return create_method_with_model_signature(get_method, model) setattr(Client, f"get_{lower_model_name}", create_get_method(model)) # Generate get_all method def create_get_all_method(model): def get_all_method(self: Client) -> List[model]: with Session(self.engine) as session: statement = select(model) result = session.exec(statement).all() return result return create_method_with_model_signature(get_all_method, model) setattr(Client, f"get_all_{lower_model_name}s", create_get_all_method(model)) # Generate create method def create_create_method(model): def create_method(self: Client, **kwargs) -> model: with Session(self.engine) as session: db_item = model(**kwargs) session.add(db_item) session.commit() session.refresh(db_item) return db_item return create_method_with_model_signature(create_method, model) setattr(Client, f"create_{lower_model_name}", create_create_method(model)) # Generate update method def create_update_method(model): def update_method(self: Client, **kwargs) -> Optional[model]: with Session(self.engine) as session: statement = select(model).where( *create_expression_from_kwargs(model, **kwargs) ) result = session.exec(statement) db_item = result.one_or_none() if db_item: for key, value in kwargs.items(): setattr(db_item, key, value) session.commit() session.refresh(db_item) return db_item return create_method_with_model_signature(update_method, model) setattr(Client, f"update_{lower_model_name}", create_update_method(model)) # Generate delete method def create_delete_method(model): def delete_method(self: Client, id: int) -> bool: with Session(self.engine) as session: statement = select(model).where(model.id == id) result = session.exec(statement) db_item = result.one_or_none() if db_item: session.delete(db_item) session.commit() return True return False return create_method_with_model_signature(delete_method, model) setattr(Client, f"delete_{lower_model_name}", create_delete_method(model))