Last active
March 27, 2025 19:40
-
-
Save dmontagu/87e9d3d7795b14b63388d4b16054f0ff to your computer and use it in GitHub Desktop.
FastAPI CBV
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import inspect | |
from typing import Any, Callable, List, Type, TypeVar, Union, get_type_hints | |
from fastapi import APIRouter, Depends | |
from pydantic.typing import is_classvar | |
from starlette.routing import Route, WebSocketRoute | |
# Type variable for the class passed to, and returned by, the `cbv` decorator.
T = TypeVar("T")

# Name of the attribute set on a class once `_init_cbv` has processed it, so
# that the same class is not wrapped a second time.
CBV_CLASS_KEY = "__cbv_class__"
def cbv(router: APIRouter) -> Callable[[Type[T]], Type[T]]:
    """Return a class decorator that turns a class into a class-based view.

    Args:
        router: The router on which the methods of the decorated class have
            been registered as endpoints (for example with
            ``@router.get(...)`` inside the class body).

    Returns:
        A decorator that hands the class to ``_cbv`` together with ``router``
        and returns the class it gets back.
    """

    def decorator(cls: Type[T]) -> Type[T]:
        return _cbv(router, cls)

    return decorator
def _cbv(router: APIRouter, cls: Type[T]) -> Type[T]:
    """Rebind the routes of ``router`` that point at functions defined on ``cls``.

    The class is first prepared by ``_init_cbv``. Every route on ``router``
    whose endpoint is a function member of ``cls`` is then removed from the
    router, has its endpoint signature rewritten by
    ``_update_cbv_route_endpoint_signature``, and is added back through a
    temporary router.

    Routes are moved in the order in which they appear on ``router``, and
    every route sharing an endpoint is moved. Declaration order is therefore
    kept (instead of the name-sorted order of ``inspect.getmembers``), and a
    method registered under several paths is handled for each of them.

    Args:
        router: Router holding the routes declared inside the class body.
        cls: The class being decorated; it is modified in place.

    Returns:
        ``cls`` itself.
    """
    _init_cbv(cls)
    cbv_router = APIRouter()
    endpoints = {func for _, func in inspect.getmembers(cls, inspect.isfunction)}
    # Collect the matching routes up front: `router.routes` is mutated below,
    # so it must not be iterated while routes are being removed.
    cbv_routes = [
        route
        for route in router.routes
        if isinstance(route, (Route, WebSocketRoute)) and route.endpoint in endpoints
    ]
    for route in cbv_routes:
        router.routes.remove(route)
        _update_cbv_route_endpoint_signature(cls, route)
        cbv_router.routes.append(route)
    # Re-include the moved routes; presumably this is what makes FastAPI read
    # the rewritten endpoint signatures (confirm against `include_router`).
    # NOTE(review): if `router` was created with a prefix, confirm that
    # re-including does not apply that prefix a second time.
    router.include_router(cbv_router)
    return cls
def _update_cbv_route_endpoint_signature(cls: Type[Any], route: Union[Route, WebSocketRoute]) -> None:
    """Rewrite the signature reported for ``route.endpoint``.

    The first parameter of the endpoint (``self`` for a method) is given
    ``Depends(cls)`` as its default, and every remaining parameter is made
    keyword-only. The result is stored on the endpoint function as
    ``__signature__``, which ``inspect.signature`` returns in preference to
    the signature derived from the function's code.

    Args:
        cls: The class to be injected as the first argument via ``Depends``.
        route: The route whose endpoint function is modified in place.

    Raises:
        IndexError: If the endpoint takes no parameters at all.
    """
    endpoint = route.endpoint
    signature = inspect.signature(endpoint)
    parameters: List[inspect.Parameter] = list(signature.parameters.values())
    self_parameter = parameters[0].replace(default=Depends(cls))
    # `self` now has a default, and a positional parameter without a default
    # may not follow one that has it, so the rest are made keyword-only.
    other_parameters = [
        parameter.replace(kind=inspect.Parameter.KEYWORD_ONLY) for parameter in parameters[1:]
    ]
    new_signature = signature.replace(parameters=[self_parameter, *other_parameters])
    # setattr rather than plain assignment, presumably to keep static type
    # checkers from objecting to an unknown attribute on a callable.
    setattr(endpoint, "__signature__", new_signature)
def _init_cbv(cls: Type[Any]) -> None:
    """Rewrite ``cls.__init__`` and ``cls.__signature__`` in place.

    Nothing is done when ``CBV_CLASS_KEY`` is already truthy on ``cls``. The
    lookup uses ``getattr``, so a marker inherited from a base class counts
    as well.

    Otherwise:

    * ``cls.__signature__`` is set to the parameters of the original
      ``__init__`` (without ``self``, ``*args`` and ``**kwargs``) followed by
      one keyword-only parameter per annotated class attribute that is not a
      ``ClassVar``. The class-level value of the attribute becomes that
      parameter's default; when there is none the default is ``...``.
    * ``cls.__init__`` is replaced by a wrapper that pops those attribute
      values from the keyword arguments, sets them on the instance, and then
      calls the original ``__init__`` with the remaining arguments. The
      wrapper raises ``KeyError`` if one of those keyword arguments is
      missing.
    * ``CBV_CLASS_KEY`` is set to ``True`` on ``cls``.

    Args:
        cls: The class to modify.
    """
    if getattr(cls, CBV_CLASS_KEY, False):  # pragma: no cover
        return  # Already initialized
    old_init: Callable[..., Any] = cls.__init__
    old_signature = inspect.signature(old_init)
    old_parameters = list(old_signature.parameters.values())[1:]  # drop `self` parameter
    new_parameters = [
        x for x in old_parameters if x.kind not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    ]
    dependency_names: List[str] = []
    for name, hint in get_type_hints(cls).items():
        if is_classvar(hint):
            continue
        dependency_names.append(name)
        new_parameters.append(
            inspect.Parameter(
                name=name,
                kind=inspect.Parameter.KEYWORD_ONLY,
                annotation=hint,
                default=getattr(cls, name, Ellipsis),
            )
        )
    new_signature = old_signature.replace(parameters=new_parameters)

    def new_init(self: Any, *args: Any, **kwargs: Any) -> None:
        # Strip the attribute values before delegating: the original
        # `__init__` does not declare them as parameters.
        for dep_name in dependency_names:
            dep_value = kwargs.pop(dep_name)
            setattr(self, dep_name, dep_value)
        old_init(self, *args, **kwargs)

    setattr(cls, "__signature__", new_signature)
    setattr(cls, "__init__", new_init)
    setattr(cls, CBV_CLASS_KEY, True)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from fastapi import APIRouter, Depends, FastAPI | |
from starlette.testclient import TestClient | |
from fastapi_cbv import cbv | |
# Router on which the example class below registers its endpoint.
router = APIRouter()
def dependency() -> int:
    """Return the constant ``1``; serves as the example dependency."""
    return 1
@cbv(router)
class CBV:
    """Example class-based view exercising the ``cbv`` decorator.

    Values reach the instance in three ways: ``x`` is declared as an
    annotated class attribute, ``z`` is an ``__init__`` parameter, and ``y``
    is set directly inside ``__init__``.
    """

    # Declared at class level with a `Depends` default.
    x: int = Depends(dependency)

    def __init__(self, z: int = Depends(dependency)):
        self.y = 1
        self.z = z

    @router.get("/", response_model=int)
    def f(self) -> int:
        """Return the sum of ``x``, ``y`` and ``z``."""
        return self.x + self.y + self.z
# Mount the router on an application and check the view end to end:
# `x`, `y` and `z` are each 1, so the endpoint is expected to answer with 3.
app = FastAPI()
app.include_router(router)
client = TestClient(app)
assert client.get("/").content == b"3"
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment