Created
October 29, 2024 13:44
-
-
Save jrialland/39db7d32e807f9fceba7ae039123faf4 to your computer and use it in GitHub Desktop.
Proxies HTTP and WebSocket traffic to an upstream server.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio
import io
import logging
import re
from http import HTTPStatus
from typing import BinaryIO
from urllib.parse import urljoin

import requests
import websockets
from asgiref.typing import (
    ASGI3Application,
    ASGIReceiveCallable,
    ASGISendCallable,
    HTTPScope,
    Scope,
    WebSocketScope,
)
# ------------------------------------------------------------------------------ | |
class BodyToFileLike(BinaryIO):
    """Expose the body of an ASGI HTTP request as a blocking, read-only file-like object.

    Used to stream large payloads to the upstream server with the ``requests``
    library, which consumes the body synchronously by calling :meth:`read`
    (or by iterating) until it gets ``b""``.

    ``receive`` may be a plain callable returning ASGI events, or the ASGI
    ``receive`` coroutine function itself. In the latter case every event is
    awaited on ``loop`` (by default the loop running when this object is
    created), so reads must then happen on another thread, e.g. inside
    ``asyncio.to_thread``.

    :param receive: callable producing ``http.request`` events.
    :param content_length: announced body size, reported by ``len()``.
    :param loop: event loop used to await an async ``receive``.
    :raises RuntimeError: on an event that is not ``http.request``, or when an
        async ``receive`` cannot be awaited safely from the calling thread.
    """

    # Size of the pieces produced when the body is iterated instead of read().
    iter_chunk_size = 64 * 1024

    def __init__(
        self,
        receive: "ASGIReceiveCallable",
        content_length: int,
        loop: "asyncio.AbstractEventLoop | None" = None,
    ):
        self.receive = receive
        self.content_length = content_length
        # Becomes False once the event carrying more_body=False was consumed.
        self.more_body = True
        # Bytes received from the client but not yet returned to the reader.
        self.buffer = b""
        # Number of bytes already returned to the reader; reported by tell().
        self.position = 0
        if loop is None:
            try:
                loop = asyncio.get_running_loop()
            except RuntimeError:  # constructed outside of any running loop
                loop = None
        self.loop = loop

    def __len__(self) -> int:
        return self.content_length

    def __iter__(self):
        # typing.IO.__iter__ is an empty stub returning None; iterate for real.
        while chunk := self.read(self.iter_chunk_size):
            yield chunk

    def readable(self) -> bool:
        return True

    def seekable(self) -> bool:
        return False

    def tell(self) -> int:
        # The inherited typing.IO stub returns None. requests sizes file-like
        # bodies from len() and tell(), so report the real read position.
        return self.position

    def seek(self, offset: int, whence: int = 0) -> int:
        # The body is a one-shot stream; fail loudly (UnsupportedOperation is
        # an OSError) rather than silently "rewinding" an exhausted stream.
        raise io.UnsupportedOperation("the request body stream cannot be rewound")

    def readall(self) -> bytes:
        """Return everything not read yet, receiving the rest of the body first."""
        parts = [self.buffer]  # bytes buffered by an earlier read(size) come first
        self.buffer = b""
        while self.more_body:
            event = self._next_event()
            parts.append(event.get("body", b""))
            self.more_body = event.get("more_body", False)
        data = b"".join(parts)
        self.position += len(data)
        return data

    def read(self, size: int = -1) -> bytes:
        """Return up to *size* bytes (all remaining bytes if *size* is negative or None).

        Blocks until *size* bytes are buffered or the body is complete, and
        returns ``b""`` once the body is exhausted, as the file protocol requires.
        """
        if size is None or size < 0:
            return self.readall()
        while self.more_body and len(self.buffer) < size:
            event = self._next_event()
            self.buffer += event.get("body", b"")
            self.more_body = event.get("more_body", False)
        data, self.buffer = self.buffer[:size], self.buffer[size:]
        self.position += len(data)
        return data

    def _next_event(self):
        """Return the next ``http.request`` event, awaiting it when ``receive`` is async."""
        event = self.receive()
        if hasattr(event, "__await__"):
            event = self._await_on_loop(event)
        if event["type"] != "http.request":
            raise RuntimeError("Unexpected event type")
        return event

    def _await_on_loop(self, awaitable):
        """Run *awaitable* on ``self.loop`` and block the calling thread until it completes."""
        try:
            current = asyncio.get_running_loop()
        except RuntimeError:
            current = None
        if self.loop is None or current is self.loop:
            # Blocking the loop's own thread on its own work would deadlock.
            if asyncio.iscoroutine(awaitable):
                awaitable.close()  # avoid a "coroutine was never awaited" warning
            raise RuntimeError(
                "An async receive callable must be read from a worker thread "
                "while its event loop is running"
            )

        async def wait_for_event():
            return await awaitable

        return asyncio.run_coroutine_threadsafe(wait_for_event(), self.loop).result()
# ------------------------------------------------------------------------------ | |
class Proxy(ASGI3Application):
    """
    ASGI application that proxies both http and websocket requests to an upstream server.

    For both protocols, the part of ``scope["path"]`` below
    ``scope["app_root_path"]`` (when that key is present) is appended to
    ``upstream_url``, and the query string is forwarded unchanged. HTTP requests
    are performed with ``requests`` in worker threads so the event loop is never
    blocked. WebSocket messages are relayed in both directions with ``websockets``.
    """

    # Connection-specific ("hop-by-hop") headers, which a proxy must not forward.
    _HOP_BY_HOP_HEADERS = frozenset(
        {
            "connection",
            "keep-alive",
            "proxy-authenticate",
            "proxy-authorization",
            "proxy-connection",
            "te",
            "trailer",
            "transfer-encoding",
            "upgrade",
        }
    )
    # Request headers not copied upstream: the hop-by-hop ones, plus
    # content-length, which requests derives from the body object it is given.
    _REQUEST_HEADERS_TO_DROP = _HOP_BY_HOP_HEADERS | {"content-length"}
    # Close codes that may legally appear in a close frame (RFC 6455 section 7.4).
    _SENDABLE_CLOSE_CODES = frozenset(
        {1000, 1001, 1002, 1003, 1007, 1008, 1009, 1010, 1011, 1012, 1013, 1014}
    )

    # Maximum number of bytes read from the upstream response body per chunk.
    chunk_size = 1024

    def __init__(self, upstream_url: str):
        self.logger = logging.getLogger(__name__)
        self.upstream_url = upstream_url

    async def __call__(
        self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable
    ):
        if scope["type"] == "http":
            await self.proxy_http(scope, receive, send)
        elif scope["type"] == "websocket":
            await self.proxy_websocket(scope, receive, send)
        # any other scope type is ignored

    def _upstream_target(self, scope: Scope, base_url: str) -> str:
        """Return *base_url* + the request path below app_root_path + the query string."""
        root_path = scope.get("app_root_path", "")
        path = scope["path"][len(root_path):]
        if base_url.endswith("/") and path.startswith("/"):
            base_url = base_url[:-1]  # avoid "//" between base and path
        url = f"{base_url}{path}"
        query_string = scope.get("query_string", b"")
        if query_string:
            url = f"{url}?{query_string.decode('latin-1')}"
        return url

    @staticmethod
    def _merge_request_headers(raw_headers) -> dict[str, str]:
        """Decode ASGI request headers into a dict keyed by lowercase name.

        Repeated headers are combined into one value (cookies with "; ",
        everything else with ", ") instead of keeping only the last one.
        """
        merged: dict[str, str] = {}
        for raw_name, raw_value in raw_headers:
            # latin-1 maps every byte to one character: never fails, round-trips exactly.
            name = raw_name.decode("latin-1").lower()
            value = raw_value.decode("latin-1")
            if name in merged:
                separator = "; " if name == "cookie" else ", "
                merged[name] = f"{merged[name]}{separator}{value}"
            else:
                merged[name] = value
        return merged

    @staticmethod
    def _forwarded_for(host: str) -> str:
        """Build a ``Forwarded`` value; IPv6 addresses are quoted and bracketed (RFC 7239)."""
        if ":" in host:
            return f'for="[{host}]"'
        return f"for={host}"

    @classmethod
    def _client_close_code(cls, code: int | None) -> int:
        """Map an upstream close code to one that may be sent in a close frame."""
        if code in cls._SENDABLE_CLOSE_CODES or (code is not None and 3000 <= code <= 4999):
            return code
        # 1005 ("no status received") means a clean close without a code.
        return 1000 if code in (None, 1005) else 1011

    async def proxy_http(
        self, scope: HTTPScope, receive: ASGIReceiveCallable, send: ASGISendCallable
    ):
        """Proxy the request to the frontend application.

        The blocking ``requests`` call and every read of the upstream body run
        in worker threads. Redirects are passed through to the client instead of
        being followed. The response body is relayed as received (still
        content-encoded), so it matches the forwarded headers. A 502 response
        is sent if the upstream request fails before a response is obtained.
        """
        headers = self._merge_request_headers(scope["headers"])
        request_content_length = int(headers.get("content-length", 0))
        # An HTTP/1.x request without content-length or transfer-encoding has
        # no body; HTTP/2+ requests may stream one without either header.
        has_body = (
            request_content_length > 0
            or "transfer-encoding" in headers
            or scope.get("http_version", "1.1") not in ("1.0", "1.1")
        )
        upstream_headers: dict[str, str | None] = {
            name: value
            for name, value in headers.items()
            if name not in self._REQUEST_HEADERS_TO_DROP
        }
        client = scope.get("client")  # optional in the ASGI scope
        if client:
            upstream_headers["forwarded"] = self._forwarded_for(client[0])
        # A None value makes requests omit its default Accept-Encoding, so the
        # upstream does not compress a response the client did not ask for.
        upstream_headers.setdefault("accept-encoding", None)
        url = self._upstream_target(scope, self.upstream_url)

        loop = asyncio.get_running_loop()

        async def receive_event():
            return await receive()

        def receive_blocking():
            # Called from the requests worker thread: await the ASGI receive
            # callable on the event loop and wait for the resulting event.
            return asyncio.run_coroutine_threadsafe(receive_event(), loop).result()

        body = (
            BodyToFileLike(receive_blocking, request_content_length)
            if has_body
            else None
        )

        self.logger.debug("%s %s", scope["method"], url)
        try:
            response = await asyncio.to_thread(
                requests.request,
                scope["method"],
                url,
                headers=upstream_headers,
                data=body,
                stream=True,
                allow_redirects=False,
            )
        except requests.RequestException:
            self.logger.exception("Error during HTTP proxying")
            await self.make_simple_response(
                send,
                status=HTTPStatus.BAD_GATEWAY,
                headers={"Content-Type": "text/plain"},
                body=b"502 Bad Gateway",
            )
            return

        try:
            # send the response status and headers back to the client; iteritems()
            # yields repeated headers (e.g. Set-Cookie) separately
            await send(
                {
                    "type": "http.response.start",
                    "status": response.status_code,
                    "headers": [
                        [name.lower().encode("latin-1"), value.encode("latin-1")]
                        for name, value in response.raw.headers.iteritems()
                        if name.lower() not in self._HOP_BY_HOP_HEADERS
                    ],
                }
            )
            # stream the raw (undecoded) body back to the client
            chunks = response.raw.stream(self.chunk_size, decode_content=False)
            while (chunk := await asyncio.to_thread(next, chunks, None)) is not None:
                if chunk:
                    await send(
                        {"type": "http.response.body", "body": chunk, "more_body": True}
                    )
            await send({"type": "http.response.body", "body": b"", "more_body": False})
        finally:
            # with stream=True the upstream connection is held until closed
            response.close()

    async def make_simple_response(
        self,
        send,
        status: int = 200,
        headers: dict[str, str] | None = None,
        body: str | bytes | None = None,
    ):
        """Send a complete, non-streamed HTTP response through ``send``.

        :param status: HTTP status code.
        :param headers: response headers; the dict passed in is not modified.
            A content-length matching the body is added unless one is given.
        :param body: response body; ``str`` is UTF-8 encoded, ``None`` is empty.
        """
        if body is None:
            payload = b""
        elif isinstance(body, str):
            payload = body.encode("utf-8")
        else:
            payload = bytes(body)
        # Lowercase copy: never mutate the caller's dict, and detect a
        # caller-supplied content-length whatever its case.
        response_headers = {name.lower(): value for name, value in (headers or {}).items()}
        response_headers.setdefault("content-length", str(len(payload)))
        await send(
            {
                "type": "http.response.start",
                "status": int(status),
                "headers": [[k.encode(), v.encode()] for k, v in response_headers.items()],
            }
        )
        await send(
            {
                "type": "http.response.body",
                "body": payload,
                "more_body": False,
            }
        )

    async def proxy_websocket(
        self,
        scope: WebSocketScope,
        receive: ASGIReceiveCallable,
        send: ASGISendCallable,
    ):
        """Relay a websocket connection between the client and the upstream server.

        The upstream URL is built like the HTTP one, with ``http`` replaced by
        ``ws`` (so ``https`` becomes ``wss``). If the upstream connection cannot
        be opened, the client's handshake is rejected by closing before accepting.
        """
        # only called for a websocket scope, whose first event is the connection request
        message = await receive()
        if message["type"] != "websocket.connect":
            raise RuntimeError(
                f"Expected a 'websocket.connect' event, got {message['type']!r}"
            )
        ws_url = self._upstream_target(scope, re.sub(r"^http", "ws", self.upstream_url))
        self.logger.debug("Connecting to %s", ws_url)
        try:
            # subprotocols is optional in the scope; None (not []) offers none upstream
            upstream = await websockets.connect(
                ws_url, subprotocols=scope.get("subprotocols") or None
            )
        except (
            OSError,
            asyncio.TimeoutError,
            websockets.InvalidHandshake,
            websockets.InvalidURI,
        ):
            self.logger.exception("Error connecting to upstream websocket %s", ws_url)
            # Closing before accepting rejects the client's handshake (ASGI spec).
            await send({"type": "websocket.close", "code": 1011})
            return

        try:
            # accept the connection using the accepted subprotocol if any
            accept_event = {"type": "websocket.accept"}
            if upstream.subprotocol:
                accept_event["subprotocol"] = upstream.subprotocol
            await send(accept_event)

            tasks = [
                asyncio.create_task(self.client_to_server(receive, upstream)),
                asyncio.create_task(self.server_to_client(upstream, send)),
            ]
            try:
                await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
            finally:
                for task in tasks:
                    task.cancel()
                # Await the cancellations so neither task outlives the connection,
                # and collect outcomes so no task exception goes unretrieved.
                outcomes = await asyncio.gather(*tasks, return_exceptions=True)
            for outcome in outcomes:
                if isinstance(outcome, Exception):
                    self.logger.error("Error during websocket proxying", exc_info=outcome)
        finally:
            await upstream.close()

    async def client_to_server(
        self,
        receive: ASGIReceiveCallable,
        upstream_websocket: websockets.WebSocketClientProtocol,
    ):
        """Forward websocket messages from the client to the server until the client disconnects."""
        while True:
            message = await receive()
            if message["type"] == "websocket.receive":
                # Test for None explicitly: an empty binary frame (b"") is valid
                # and must not fall through to the absent "text" field.
                data = message.get("bytes")
                if data is None:
                    data = message.get("text")
                await upstream_websocket.send(data)
            elif message["type"] == "websocket.disconnect":
                await upstream_websocket.close()
                break
            else:
                raise NotImplementedError(f"Unknown message type {message['type']}")

    async def server_to_client(
        self,
        upstream_websocket: websockets.WebSocketClientProtocol,
        send: ASGISendCallable,
    ):
        """Forward websocket messages from the server to the client, then relay the server's close."""
        try:
            async for data in upstream_websocket:
                fmt = "text" if isinstance(data, str) else "bytes"
                await send({"type": "websocket.send", fmt: data})
        except websockets.ConnectionClosed:
            # An abnormal closure raises here, while a clean one just ends the
            # loop; both are reported to the client below.
            pass
        await send(
            {
                "type": "websocket.close",
                "code": self._client_close_code(upstream_websocket.close_code),
                "reason": upstream_websocket.close_reason or "",
            }
        )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment