from __future__ import annotations

import inspect
from functools import wraps
from typing import List, Optional, Type

from sqlalchemy.ext.asyncio import create_async_engine
from sqlmodel import Session, SQLModel, create_engine, select
from sqlmodel.ext.asyncio.session import AsyncSession

class BaseClient:
    """Common base for the generated database clients.

    Builds a synchronous SQLAlchemy engine from ``connection_string`` and keeps
    an initially empty ``models`` list.
    """

    def __init__(self, connection_string: str):
        sync_engine = create_engine(connection_string)
        self.engine = sync_engine
        # Starts empty; nothing visible in this module appends to it.
        self.models: list[Type[SQLModel]] = []


class Client(BaseClient):
    """Synchronous database client.

    Per-model CRUD methods are attached to this class at runtime by
    ``create_sync_methods`` (via ``generate_sync_client_class``).
    """

    def create_db_and_tables(self):
        """Create every table registered on ``SQLModel.metadata`` using ``self.engine``."""
        create_all = SQLModel.metadata.create_all
        create_all(self.engine)


class AsyncClient(BaseClient):
    """Asyncio database client.

    Keeps the synchronous ``engine`` inherited from ``BaseClient`` and adds
    ``async_engine``, built from the same connection string, which the
    generated coroutine methods open their sessions on.
    """

    def __init__(self, connection_string: str):
        super().__init__(connection_string)
        # NOTE(review): BaseClient.__init__ also builds a *sync* engine from this
        # same URL; with an async-only driver that engine is presumably unusable
        # — confirm it is intended.
        self.async_engine = create_async_engine(connection_string)

    async def create_db_and_tables(self):
        """Create every table registered on ``SQLModel.metadata`` using ``async_engine``.

        ``MetaData.create_all`` is synchronous, so it is executed through
        ``run_sync`` on a connection opened with ``begin()``.
        """
        metadata = SQLModel.metadata
        async with self.async_engine.begin() as connection:
            await connection.run_sync(metadata.create_all)


def create_method_with_model_signature(func, model: Type[SQLModel]):
    """Wrap ``func`` so that ``inspect.signature`` reports the model's fields.

    The returned wrapper forwards every call unchanged to ``func``. Only its
    advertised ``__signature__`` differs: ``self`` (positional-only) followed by
    one keyword-only parameter per model field. Fields whose ``is_required()``
    is true get no default; the others get ``field.default``. The return
    annotation is taken from ``func``'s own signature.

    When ``func`` is a coroutine function the wrapper is one too, so
    ``inspect.iscoroutinefunction`` still recognises the generated method.

    Args:
        func: the function to wrap; its first positional parameter is ``self``.
        model: a pydantic-v2 style model class exposing ``model_fields``
            (falls back to the deprecated ``__fields__`` alias).

    Returns:
        The wrapper function.
    """
    if inspect.iscoroutinefunction(func):

        @wraps(func)
        async def wrapper(self, *args, **kwargs):
            return await func(self, *args, **kwargs)

    else:

        @wraps(func)
        def wrapper(self, *args, **kwargs):
            return func(self, *args, **kwargs)

    # ``field.is_required()`` below is pydantic-v2 API; v2 names the mapping
    # ``model_fields`` and keeps ``__fields__`` only as a deprecated alias.
    model_fields = getattr(model, "model_fields", None)
    if model_fields is None:
        model_fields = model.__fields__

    new_params = [inspect.Parameter("self", inspect.Parameter.POSITIONAL_ONLY)]
    for field_name, field in model_fields.items():
        default_kw = {} if field.is_required() else {"default": field.default}
        new_params.append(
            inspect.Parameter(
                field_name,
                inspect.Parameter.KEYWORD_ONLY,
                annotation=field.annotation,
                **default_kw,
            )
        )

    wrapper.__signature__ = inspect.signature(func).replace(parameters=new_params)
    return wrapper


def generate_async_client_class(db_models: List[Type[SQLModel]]) -> Type[AsyncClient]:
    """Attach async CRUD methods for every model in ``db_models``; return ``AsyncClient``.

    The module-level ``AsyncClient`` class itself is modified (methods are
    ``setattr``-ed onto it), so repeated calls accumulate methods and every
    ``AsyncClient`` instance sees them.
    """
    client_class = AsyncClient
    for db_model in db_models:
        create_async_methods(client_class, db_model)
    return client_class


def generate_sync_client_class(db_models: List[Type[SQLModel]]) -> Type[Client]:
    """Attach sync CRUD methods for every model in ``db_models``; return ``Client``.

    The module-level ``Client`` class itself is modified (methods are
    ``setattr``-ed onto it), so repeated calls accumulate methods and every
    ``Client`` instance sees them.
    """
    client_class = Client
    for db_model in db_models:
        create_sync_methods(client_class, db_model)
    return client_class


def create_expression_from_kwargs(model, **kwargs):
    """Yield one filter expression per keyword argument.

    Each ``key`` names an attribute of ``model`` (normally a mapped column).
    The expression built depends on the type of ``value``:

    * ``tuple`` -- ``value[1](column, value[0])``, i.e. an ``(operand, operator)``
      pair; e.g. ``age=(18, operator.ge)`` yields ``model.age >= 18``.
    * ``list`` -- ``column.in_(value)``.
    * anything else -- ``column == value``.

    Attributes are resolved with ``getattr`` so that attributes inherited from a
    parent class are found as well; a class's own ``__dict__`` only holds the
    attributes defined directly on it.

    Raises:
        KeyError: if ``model`` has no attribute named ``key`` (kept as
            ``KeyError`` for compatibility with existing callers).
    """
    for key, value in kwargs.items():
        try:
            column = getattr(model, key)
        except AttributeError:
            raise KeyError(key) from None
        if isinstance(value, tuple):
            yield value[1](column, value[0])
        elif isinstance(value, list):
            yield column.in_(value)
        else:
            yield column == value


def create_async_methods(AsyncClient: Type[AsyncClient], model: Type[SQLModel]):
    """Attach async CRUD methods for ``model`` to the given client class.

    For a model named ``Hero`` the following coroutine methods are set on
    ``AsyncClient`` with ``setattr``:

    * ``create_hero(*args, **kwargs)`` -- build ``Hero(*args, **kwargs)``,
      insert it, and return the refreshed instance.
    * ``get_hero(**filters)`` -- return the single matching row, or ``None``.
    * ``get_all_heros()`` -- return every row.
    * ``update_hero(**kwargs)`` -- find the row by the primary-key keyword(s)
      in ``kwargs``, assign every remaining keyword as a new attribute value,
      commit, and return the refreshed row (``None`` if no row matches).
      Raises ``ValueError`` if no primary-key keyword is supplied.
    * ``delete_hero(**filters)`` -- delete the single matching row and return
      ``True``, or return ``False`` if nothing matches.

    Filters follow ``create_expression_from_kwargs`` (plain value -> equality,
    list -> ``IN``, ``(operand, operator)`` tuple -> custom comparison). Lookups
    use ``one_or_none()``, so SQLAlchemy raises ``MultipleResultsFound`` when a
    filter matches more than one row. Every call opens and closes its own
    ``AsyncSession`` on ``self.async_engine``.

    Args:
        AsyncClient: the client class to modify (this parameter shadows the
            module-level class of the same name).
        model: the SQLModel table class to generate methods for.
    """
    lower_model_name = model.__name__.lower()

    def _split_primary_key(kwargs):
        # Separate the keywords that identify the row (primary-key columns)
        # from the new values to assign. The table is read at call time, so
        # generating the methods never touches ``model.__table__``.
        pk_names = {column.key for column in model.__table__.primary_key.columns}
        lookup = {key: value for key, value in kwargs.items() if key in pk_names}
        changes = {key: value for key, value in kwargs.items() if key not in pk_names}
        if not lookup:
            raise ValueError(
                f"update_{lower_model_name} requires a primary-key keyword "
                f"({', '.join(sorted(pk_names))}) to identify the row"
            )
        return lookup, changes

    async def create_method(self, *args, **kwargs) -> model:
        async with AsyncSession(self.async_engine) as session:
            db_item = model(*args, **kwargs)
            session.add(db_item)
            await session.commit()
            await session.refresh(db_item)
            return db_item

    async def get_method(self, **kwargs) -> Optional[model]:
        async with AsyncSession(self.async_engine) as session:
            statement = select(model).where(
                *create_expression_from_kwargs(model, **kwargs)
            )
            result = await session.exec(statement)
            # one_or_none(): a miss returns None (as the Optional return type and
            # the sync client promise) instead of raising NoResultFound.
            return result.one_or_none()

    async def get_all_method(self) -> List[model]:
        async with AsyncSession(self.async_engine) as session:
            result = await session.exec(select(model))
            return result.all()

    async def update_method(self, **kwargs) -> Optional[model]:
        # Filtering on *all* kwargs (including the new values) would only ever
        # match rows that already hold those values, making update a no-op.
        lookup, changes = _split_primary_key(kwargs)
        async with AsyncSession(self.async_engine) as session:
            statement = select(model).where(
                *create_expression_from_kwargs(model, **lookup)
            )
            result = await session.exec(statement)
            db_item = result.one_or_none()
            if db_item is not None:
                for key, value in changes.items():
                    setattr(db_item, key, value)
                await session.commit()
                await session.refresh(db_item)
            return db_item

    async def delete_method(self, **kwargs) -> bool:
        async with AsyncSession(self.async_engine) as session:
            statement = select(model).where(
                *create_expression_from_kwargs(model, **kwargs)
            )
            result = await session.exec(statement)
            db_item = result.one_or_none()
            if db_item is None:
                return False
            await session.delete(db_item)
            await session.commit()
            return True

    setattr(
        AsyncClient,
        f"create_{lower_model_name}",
        create_method_with_model_signature(create_method, model),
    )
    setattr(
        AsyncClient,
        f"get_{lower_model_name}",
        create_method_with_model_signature(get_method, model),
    )
    # get_all accepts no arguments, so it keeps its real signature rather than
    # advertising the model's fields as parameters it would reject.
    setattr(AsyncClient, f"get_all_{lower_model_name}s", get_all_method)
    setattr(
        AsyncClient,
        f"update_{lower_model_name}",
        create_method_with_model_signature(update_method, model),
    )
    setattr(
        AsyncClient,
        f"delete_{lower_model_name}",
        create_method_with_model_signature(delete_method, model),
    )


def create_sync_methods(Client: Type[Client], model: Type[SQLModel]):
    """Attach synchronous CRUD methods for ``model`` to the given client class.

    For a model named ``Hero`` the following methods are set on ``Client``
    with ``setattr``:

    * ``get_hero(**filters)`` -- return the single matching row, or ``None``.
    * ``get_all_heros()`` -- return every row.
    * ``create_hero(**kwargs)`` -- build ``Hero(**kwargs)``, insert it, and
      return the refreshed instance.
    * ``update_hero(**kwargs)`` -- find the row by the primary-key keyword(s)
      in ``kwargs``, assign every remaining keyword as a new attribute value,
      commit, and return the refreshed row (``None`` if no row matches).
      Raises ``ValueError`` if no primary-key keyword is supplied.
    * ``delete_hero(id)`` -- delete the row whose ``id`` equals ``id`` and
      return ``True``, or return ``False`` if none exists.

    Filters follow ``create_expression_from_kwargs`` (plain value -> equality,
    list -> ``IN``, ``(operand, operator)`` tuple -> custom comparison). Lookups
    use ``one_or_none()``, so SQLAlchemy raises ``MultipleResultsFound`` when a
    filter matches more than one row. Every call opens and closes its own
    ``Session`` on ``self.engine``.

    Args:
        Client: the client class to modify (this parameter shadows the
            module-level class of the same name).
        model: the SQLModel table class to generate methods for.
    """
    lower_model_name = model.__name__.lower()

    def _split_primary_key(kwargs):
        # Separate the keywords that identify the row (primary-key columns)
        # from the new values to assign. The table is read at call time, so
        # generating the methods never touches ``model.__table__``.
        pk_names = {column.key for column in model.__table__.primary_key.columns}
        lookup = {key: value for key, value in kwargs.items() if key in pk_names}
        changes = {key: value for key, value in kwargs.items() if key not in pk_names}
        if not lookup:
            raise ValueError(
                f"update_{lower_model_name} requires a primary-key keyword "
                f"({', '.join(sorted(pk_names))}) to identify the row"
            )
        return lookup, changes

    def get_method(self, **kwargs) -> Optional[model]:
        with Session(self.engine) as session:
            statement = select(model).where(
                *create_expression_from_kwargs(model, **kwargs)
            )
            return session.exec(statement).one_or_none()

    def get_all_method(self) -> List[model]:
        with Session(self.engine) as session:
            return session.exec(select(model)).all()

    def create_method(self, **kwargs) -> model:
        with Session(self.engine) as session:
            db_item = model(**kwargs)
            session.add(db_item)
            session.commit()
            session.refresh(db_item)
            return db_item

    def update_method(self, **kwargs) -> Optional[model]:
        # Filtering on *all* kwargs (including the new values) would only ever
        # match rows that already hold those values, making update a no-op.
        lookup, changes = _split_primary_key(kwargs)
        with Session(self.engine) as session:
            statement = select(model).where(
                *create_expression_from_kwargs(model, **lookup)
            )
            db_item = session.exec(statement).one_or_none()
            if db_item is not None:
                for key, value in changes.items():
                    setattr(db_item, key, value)
                session.commit()
                session.refresh(db_item)
            return db_item

    def delete_method(self, id: int) -> bool:
        # Assumes the model has an ``id`` column (accessed as ``model.id``).
        with Session(self.engine) as session:
            db_item = session.exec(select(model).where(model.id == id)).one_or_none()
            if db_item is None:
                return False
            session.delete(db_item)
            session.commit()
            return True

    setattr(
        Client,
        f"get_{lower_model_name}",
        create_method_with_model_signature(get_method, model),
    )
    # get_all takes no arguments and delete takes a single ``id``; both keep
    # their real signatures instead of advertising the model's fields.
    setattr(Client, f"get_all_{lower_model_name}s", get_all_method)
    setattr(
        Client,
        f"create_{lower_model_name}",
        create_method_with_model_signature(create_method, model),
    )
    setattr(
        Client,
        f"update_{lower_model_name}",
        create_method_with_model_signature(update_method, model),
    )
    setattr(Client, f"delete_{lower_model_name}", delete_method)