Skip to content

Instantly share code, notes, and snippets.

@jrialland
Created October 29, 2024 13:44
Show Gist options
  • Save jrialland/39db7d32e807f9fceba7ae039123faf4 to your computer and use it in GitHub Desktop.
Save jrialland/39db7d32e807f9fceba7ae039123faf4 to your computer and use it in GitHub Desktop.
Proxies HTTP and WebSocket traffic to an upstream server.
import re
import requests
import logging
from urllib.parse import urljoin
import asyncio
import websockets
from http import HTTPStatus
from asgiref.typing import (
ASGI3Application,
Scope,
HTTPScope,
WebSocketScope,
ASGIReceiveCallable,
ASGISendCallable,
)
from typing import BinaryIO
# ------------------------------------------------------------------------------
class BodyToFileLike(BinaryIO):
"""Utility class that allow to stream the body of an HTTP request to a file-like object, such as a file on disk.
Used to be able to stream large payloads to the upstream server using the requests library.
"""
def __init__(self, receive: ASGIReceiveCallable, content_length: int):
self.receive = receive
self.content_length = content_length
self.more_body = True
self.buffer = b""
def __len__(self):
return self.content_length
def readall(self) -> bytes:
data = b""
while self.more_body:
event = self.receive()
if event["type"] == "http.request":
data += event.get("body", b"")
self.more_body = event.get("more_body", False)
else:
raise RuntimeError("Unexpected event type")
return data
def read(self, size: int = -1) -> bytes:
if size < 0:
return self.readall()
else:
if not self.more_body and len(self.buffer) == 0:
raise EOFError("Reached end of body")
# first send what we have in the buffer
while self.more_body and len(self.buffer) < size:
event = self.receive()
if event["type"] == "http.request":
self.buffer += event.get("body", b"")
self.more_body = event.get("more_body", False)
else:
raise RuntimeError("Unexpected event type")
if len(self.buffer) >= size:
data = self.buffer[:size]
self.buffer = self.buffer[size:]
return data
else:
data = self.buffer
self.buffer = b""
return data
# ------------------------------------------------------------------------------
class Proxy(ASGI3Application):
    """
    ASGI application that proxies both http and websocket requests to an upstream server.

    HTTP requests are forwarded with the ``requests`` library (executed in a
    worker thread so the event loop is never blocked); websocket connections
    are relayed in both directions with the ``websockets`` library.
    """

    # Headers that only describe a single transport hop and therefore must not
    # be relayed; the ASGI server applies its own framing to the response.
    _HOP_BY_HOP_HEADERS = frozenset(
        {
            "connection",
            "keep-alive",
            "proxy-authenticate",
            "proxy-authorization",
            "te",
            "trailer",
            "trailers",
            "transfer-encoding",
            "upgrade",
        }
    )

    # Close codes that are reserved for local use and may not be sent in a
    # close frame (no status received, abnormal closure, TLS failure).
    _RESERVED_CLOSE_CODES = frozenset({1005, 1006, 1015})

    def __init__(self, upstream_url: str):
        """Create a proxy forwarding to ``upstream_url`` (an http:// or https:// URL)."""
        self.logger = logging.getLogger(__name__)
        self.upstream_url = upstream_url

    async def __call__(
        self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable
    ):
        """ASGI entry point: dispatch on the scope type; other scope types are ignored."""
        if scope["type"] == "http":
            await self.proxy_http(scope, receive, send)
        elif scope["type"] == "websocket":
            await self.proxy_websocket(scope, receive, send)

    def _upstream_target(self, scope, base_url: str) -> str:
        """Build the upstream URL for ``scope``: base + path below the root path + query string."""
        root_path = scope.get("app_root_path", "")
        target = f"{base_url}{scope['path'][len(root_path):]}"
        query = scope.get("query_string", b"")
        if query:
            # The query string is already percent-encoded by the client.
            target = f"{target}?{query.decode('latin-1')}"
        return target

    async def proxy_http(
        self, scope: HTTPScope, receive: ASGIReceiveCallable, send: ASGISendCallable
    ):
        """Proxy the request to the upstream server and stream its response back.

        Answers with ``502 Bad Gateway`` when the upstream server cannot be
        reached. Errors happening after the response has started are
        propagated, because a second response can no longer be sent.
        """
        # latin-1 maps every byte to one code point, so header values survive
        # the bytes -> str -> bytes round trip unchanged.
        headers = {
            k.decode("latin-1"): v.decode("latin-1") for k, v in scope["headers"]
        }
        client = scope.get("client")
        if client:
            host = str(client[0])
            # IPv6 addresses must be bracketed and quoted in a Forwarded header.
            element = f'for="[{host}]"' if ":" in host else f"for={host}"
            # ASGI header names are lowercase; append to an existing header
            # instead of emitting a second one under a different spelling.
            previous = headers.get("forwarded")
            headers["forwarded"] = f"{previous}, {element}" if previous else element

        url = self._upstream_target(scope, self.upstream_url)
        request_content_length = int(headers.get("content-length", 0))
        loop = asyncio.get_running_loop()

        async def next_event():
            return await receive()

        def receive_blocking():
            # Runs in the worker thread: schedule receive() on the event loop
            # and block this thread (not the loop) until the event is there.
            return asyncio.run_coroutine_threadsafe(next_event(), loop).result()

        # perform the request to the upstream server in a worker thread
        self.logger.debug("%s %s", scope["method"], url)
        try:
            response = await asyncio.to_thread(
                requests.request,
                scope["method"],
                url,
                headers=headers,
                data=BodyToFileLike(receive_blocking, request_content_length),
                stream=True,
            )
        except requests.ConnectionError:
            self.logger.exception("Error during HTTP proxying")
            await self.make_simple_response(
                send,
                status=HTTPStatus.BAD_GATEWAY,
                headers={"Content-Type": "text/plain"},
                body=b"502 Bad Gateway",
            )
            return

        try:
            # send the response status and headers back to the client
            await send(
                {
                    "type": "http.response.start",
                    "status": response.status_code,
                    "headers": [
                        [k.lower().encode("latin-1"), v.encode("latin-1")]
                        for k, v in response.headers.items()
                        if k.lower() not in self._HOP_BY_HOP_HEADERS
                    ],
                }
            )
            # Relay the body exactly as the upstream sent it (no content
            # decoding) so that the forwarded Content-Encoding and
            # Content-Length headers keep describing the bytes on the wire.
            chunks = response.raw.stream(1024, decode_content=False)
            while True:
                # Reading from the upstream socket blocks, so do it off-loop.
                chunk = await asyncio.to_thread(next, chunks, None)
                if chunk is None:
                    break
                await send(
                    {
                        "type": "http.response.body",
                        "body": chunk,
                        "more_body": True,
                    }
                )
            # send the final empty body
            await send(
                {
                    "type": "http.response.body",
                    "body": b"",
                    "more_body": False,
                }
            )
        finally:
            # Always give the upstream connection back, even on errors.
            response.close()

    async def make_simple_response(
        self,
        send,
        status: int = 200,
        headers: dict[str, str] | None = None,
        body: str | bytes | None = None,
    ):
        """Send a complete, non-streamed HTTP response.

        Args:
            send: ASGI send callable.
            status: HTTP status code.
            headers: Response headers; ``Content-Length`` is added when absent.
                The mapping passed by the caller is not modified.
            body: Response payload; ``str`` is encoded as UTF-8, ``None`` means empty.
        """
        if body is None:
            payload = b""
        elif isinstance(body, str):
            payload = body.encode()
        else:
            payload = bytes(body)
        out_headers = dict(headers or {})
        if not any(name.lower() == "content-length" for name in out_headers):
            out_headers["Content-Length"] = str(len(payload))
        await send(
            {
                "type": "http.response.start",
                "status": int(status),
                "headers": [
                    [k.lower().encode("latin-1"), v.encode("latin-1")]
                    for k, v in out_headers.items()
                ],
            }
        )
        await send(
            {
                "type": "http.response.body",
                "body": payload,
                "more_body": False,
            }
        )

    async def proxy_websocket(
        self,
        scope: WebSocketScope,
        receive: ASGIReceiveCallable,
        send: ASGISendCallable,
    ):
        """Accept the client websocket and relay messages to and from the upstream server.

        Raises:
            RuntimeError: if the first event is not ``websocket.connect``.
        """
        # the first event of a websocket scope must be the connection request
        message = await receive()
        if message["type"] != "websocket.connect":
            raise RuntimeError(
                f"Expected websocket.connect, got {message['type']}"
            )
        # Same path/query mapping as for HTTP, with http(s) turned into ws(s).
        ws_url = self._upstream_target(
            scope, re.sub("^http", "ws", self.upstream_url)
        )
        self.logger.debug("Connecting to %s", ws_url)
        async with websockets.connect(
            ws_url,
            subprotocols=scope.get("subprotocols") or None,
        ) as websocket:
            # accept the connection using the negotiated subprotocol if any
            accept_event = {"type": "websocket.accept"}
            if websocket.subprotocol:
                accept_event["subprotocol"] = websocket.subprotocol
            await send(accept_event)
            # one relay task per direction
            tasks = [
                asyncio.create_task(self.client_to_server(receive, websocket)),
                asyncio.create_task(self.server_to_client(websocket, send)),
            ]
            try:
                # the relay is over as soon as either side stops
                done, _ = await asyncio.wait(
                    tasks, return_when=asyncio.FIRST_COMPLETED
                )
            finally:
                for task in tasks:
                    if not task.done():
                        task.cancel()
                # Wait for cancellations to take effect so no task outlives
                # the upstream connection.
                await asyncio.gather(*tasks, return_exceptions=True)
            for task in done:
                if not task.cancelled() and task.exception() is not None:
                    self.logger.error(
                        "Websocket relay task failed", exc_info=task.exception()
                    )

    async def client_to_server(
        self,
        receive: ASGIReceiveCallable,
        upstream_websocket: "websockets.WebSocketClientProtocol",
    ):
        """Forward websocket messages from the client to the server.

        Returns when the client disconnects, after closing the upstream connection.

        Raises:
            NotImplementedError: on an event type other than
                ``websocket.receive`` / ``websocket.disconnect``.
        """
        while True:
            message = await receive()
            if message["type"] == "websocket.receive":
                # Test against None (not truthiness) so that empty frames are
                # forwarded too; binary payload takes precedence over text.
                data = message.get("bytes")
                if data is None:
                    data = message.get("text")
                if data is not None:
                    await upstream_websocket.send(data)
            elif message["type"] == "websocket.disconnect":
                await upstream_websocket.close()
                break
            else:
                raise NotImplementedError(f"Unknown message type {message['type']}")

    async def server_to_client(
        self,
        upstream_websocket: "websockets.WebSocketClientProtocol",
        send: ASGISendCallable,
    ):
        """Forward websocket messages from the server to the client.

        When iterating the upstream connection raises ``ConnectionClosed``,
        the client connection is closed with the same code and reason.
        """
        try:
            # for each message received from the server, forward it to the client
            async for data in upstream_websocket:
                fmt = "text" if isinstance(data, str) else "bytes"
                await send(
                    {
                        "type": "websocket.send",
                        fmt: data,
                    }
                )
        except websockets.ConnectionClosed as e:
            code = e.code
            if code in self._RESERVED_CLOSE_CODES:
                # Reserved codes cannot be put in a close frame; report an
                # internal error to the client instead.
                code = 1011
            await send(
                {
                    "type": "websocket.close",
                    "code": code,
                    "reason": e.reason,
                }
            )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment