Created
January 5, 2024 06:41
-
-
Save danielomiya/c1f2758dc96f8f1f3ddeccf3bc493ea0 to your computer and use it in GitHub Desktop.
ASGI wrapper for Cloud Functions
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
| This is an example of how to wrap an ASGI application to run it as a Cloud | |
| Function. I wrote it purely for academic purposes, so I'm not sure how | |
| well it performs, but for my use case (simply running my FastAPI code on | |
| the GCP serverless platform) it worked great. | |
| Credit also goes to jordaneremieff/mangum: much of this is inspired by | |
| their work, which makes it possible to run ASGI applications on AWS Lambda. | |
| """ | |
| from __future__ import annotations | |
| import asyncio | |
| import logging | |
| from enum import Enum, auto | |
| from io import BytesIO | |
| from typing import TYPE_CHECKING, Any, Awaitable, Callable, MutableMapping | |
| from flask import make_response | |
| from app import app # a FastAPI instance | |
if TYPE_CHECKING:
    # Flask request/response classes are referenced only in annotations.
    from flask import Request, Response

# Type aliases describing the ASGI 3 single-callable interface.
Message = MutableMapping[str, Any]
Scope = MutableMapping[str, Any]
Receive = Callable[[], Awaitable[Message]]
Send = Callable[[Message], Awaitable[None]]
ASGIApp = Callable[[Scope, Receive, Send], Awaitable[None]]
class ResponseState(Enum):
    """Progress of the ASGI response as tracked by ``ResponseBuilder.send``."""

    # No ``http.response.start`` message has been handled yet.
    WAITING = 1
    # Status and headers received; body chunks may still follow.
    STARTED = 2
    # A body message with a false ``more_body`` has been handled.
    COMPLETE = 3
class LoggingMixin:
    """Mixin providing a lazily created logger named ``<module>.<ClassName>``."""

    # Per-instance cache; the class-level ``None`` means "not created yet".
    _logger: logging.Logger | None = None

    @property
    def logger(self) -> logging.Logger:
        """Return this instance's logger, creating and caching it on first use."""
        cached = self._logger
        if cached is not None:
            return cached
        logger_name = f"{self.__module__}.{self.__class__.__name__}"
        cached = self._logger = logging.getLogger(logger_name)
        return cached
class ResponseBuilder(LoggingMixin):
    """Run one ASGI HTTP request/response cycle and turn it into a Flask response.

    The whole request body is handed to the application as a single
    ``http.request`` message. The ``http.response.start`` and
    ``http.response.body`` messages sent by the application are collected
    until the final body chunk arrives. An ``http.disconnect`` message is then
    queued for any ``receive()`` call still waiting.
    """

    def __init__(self, scope: Scope, body: bytes) -> None:
        """Prepare a cycle.

        Args:
            scope: ASGI HTTP connection scope passed to the application.
            body: Complete request body, delivered as one ``http.request``.
        """
        self.scope = scope
        self.state = ResponseState.WAITING
        self.status_code: int | None = None
        self.response_body: bytes | None = None
        self.headers: list[tuple[bytes, bytes]] | None = None
        self._request_body = body
        self._buffer = BytesIO()
        # Created lazily by ``_get_queue`` from inside the running event loop.
        # On Python < 3.10 ``asyncio.Queue()`` binds to ``get_event_loop()`` at
        # construction time. That is not the loop ``asyncio.run`` creates
        # later, and in a worker thread with no loop it raises outright.
        self._queue: asyncio.Queue[Message] | None = None

    def _get_queue(self) -> asyncio.Queue[Message]:
        """Return the receive queue, creating it pre-loaded with the request."""
        if self._queue is None:
            self._queue = asyncio.Queue()
            self._queue.put_nowait(
                {
                    "type": "http.request",
                    "body": self._request_body,
                    "more_body": False,
                }
            )
        return self._queue

    async def build(self, app: ASGIApp) -> Response:
        """Run ``app`` to completion and return the collected Flask response.

        If the application raises, or returns without completing its
        response, the problem is logged and a plain-text 500 response is
        returned instead. A response that had already completed is kept.

        Args:
            app: ASGI 3 application to invoke with this cycle's scope.

        Returns:
            The Flask response built by ``make_response``.
        """
        try:
            await app(self.scope, self.receive, self.send)
        except BaseException:
            # Broad catch (as in mangum) so a response is always produced;
            # the error is logged instead of being silently discarded.
            self.logger.exception(
                "Exception in ASGI application while handling %s %s",
                self.scope.get("method"),
                self.scope.get("path"),
            )
        else:
            if self.state is not ResponseState.COMPLETE:
                self.logger.error(
                    "ASGI application returned without completing the "
                    "response for %s %s",
                    self.scope.get("method"),
                    self.scope.get("path"),
                )

        if self.state is ResponseState.WAITING:
            # Nothing was sent: emit a 500 through the normal send path.
            await self.send(
                {
                    "type": "http.response.start",
                    "status": 500,
                    "headers": [
                        (b"content-type", b"text/plain; charset=utf-8")
                    ],
                }
            )
            await self.send(
                {
                    "type": "http.response.body",
                    "body": b"Internal Server Error",
                    "more_body": False,
                }
            )
        elif self.state is not ResponseState.COMPLETE:
            # Status/headers were sent but the body never finished: replace
            # the partial response with a 500.
            self.status_code = 500
            self.response_body = b"Internal Server Error"
            self.headers = [(b"content-type", b"text/plain; charset=utf-8")]

        # ASGI header names/values are bytes; latin-1 round-trips every byte.
        headers = [
            (name.decode("latin-1"), value.decode("latin-1"))
            for name, value in self.headers
        ]
        return make_response(self.response_body, self.status_code, headers)

    async def send(self, message: Message) -> None:
        """Handle one ASGI ``send`` message from the application.

        Raises:
            RuntimeError: if the message type is not valid in the current
                state. Examples: a body before ``http.response.start``, or
                any message after the response completed.
        """
        message_type = message["type"]
        if (
            self.state is ResponseState.WAITING
            and message_type == "http.response.start"
        ):
            self.status_code = message["status"]
            self.headers = list(message.get("headers", []))
            self.state = ResponseState.STARTED
        elif (
            self.state is ResponseState.STARTED
            and message_type == "http.response.body"
        ):
            self._buffer.write(message.get("body", b""))
            if not message.get("more_body", False):
                self.response_body = self._buffer.getvalue()
                self._buffer.close()
                self.state = ResponseState.COMPLETE
                # Release any receive() still waiting on the request stream.
                await self._get_queue().put({"type": "http.disconnect"})
                self.logger.info(
                    "%s %s %s",
                    self.scope["method"],
                    self.scope["path"],
                    self.status_code,
                )
        else:
            raise RuntimeError(f"Unexpected {message['type']}")

    async def receive(self) -> Message:
        """Return the next ASGI ``receive`` message (request, then disconnect)."""
        return await self._get_queue().get()
def asgi_wrap(asgi: ASGIApp) -> Callable[[Request], Response]:
    """Adapt an ASGI application into a Flask request handler.

    Args:
        asgi: ASGI 3 application invoked once per request.

    Returns:
        A function that takes a Flask ``Request`` and translates it into an
        ASGI HTTP scope. It then runs ``asgi`` to completion in a new event
        loop via ``asyncio.run`` and returns the resulting Flask ``Response``.
    """

    def handle(request: Request) -> Response:
        environ = request.environ
        remote_addr = environ.get("REMOTE_ADDR")
        scope = {
            "type": "http",
            "asgi": {"version": "3.0"},
            "method": request.method,
            "http_version": "1.1",
            # ASGI requires lowercased byte-string header names. WSGI header
            # strings are latin-1 decoded (PEP 3333), so latin-1 round-trips.
            "headers": [
                (name.lower().encode("latin-1"), value.encode("latin-1"))
                for name, value in request.headers.items()
            ],
            "path": request.path,
            "raw_path": None,
            "root_path": "",
            "scheme": request.scheme,
            "query_string": request.query_string,
            # ASGI expects (host, port) pairs with an integer port, while
            # WSGI provides the port as a string.
            "server": (environ["SERVER_NAME"], int(environ["SERVER_PORT"])),
            "client": (
                (remote_addr, int(environ.get("REMOTE_PORT") or 0))
                if remote_addr
                else None
            ),
        }
        # get_data() returns the raw body even for form/multipart requests;
        # ``request.data`` is empty for those.
        request_body = request.get_data()

        async def run() -> Response:
            # Build inside the loop created by asyncio.run so the builder's
            # asyncio primitives belong to that loop.
            response_builder = ResponseBuilder(scope, request_body)
            return await response_builder.build(asgi)

        return asyncio.run(run())

    return handle
# Module-level callable that takes a Flask ``Request`` and returns a Flask
# ``Response`` by running the ``app`` imported above (a FastAPI instance)
# through ``asgi_wrap``.
# NOTE(review): presumably this is the name configured as the Cloud Function
# entry point; confirm against the deployment configuration.
handler: Callable[[Request], Response] = asgi_wrap(app)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment