Skip to content

Instantly share code, notes, and snippets.

@AivanF
Created March 7, 2025 04:30
Show Gist options
  • Save AivanF/4616d7e89fb5ea497deda31201db2c60 to your computer and use it in GitHub Desktop.
Save AivanF/4616d7e89fb5ea497deda31201db2c60 to your computer and use it in GitHub Desktop.
FastAPI / Starlette: Async WebSocket test client
import asyncio
import json
import typing
from starlette.testclient import TestClient, WebSocketTestSession
from starlette.types import Message
class AsyncWebSocketTestSession(WebSocketTestSession):
    """``WebSocketTestSession`` extended with awaitable send/receive methods.

    Each ``a*`` method schedules the corresponding channel operation on the
    session's anyio portal via ``start_task_soon`` and awaits the returned
    ``concurrent.futures.Future`` through ``asyncio.wrap_future``. This lets an
    async test wait for server messages -- optionally with a timeout --
    instead of calling the base class's synchronous methods.

    Sessions are usually obtained by converting an existing
    ``WebSocketTestSession`` with ``make_websocket_async``.
    """

    # Default timeout in seconds for the ``areceive*`` methods; ``None`` means
    # no timeout. A per-call ``receive_timeout`` argument overrides it.
    receive_timeout: float | None = None
    # Default for the per-call ``raise_timeout``: True re-raises the timeout
    # error, False makes the ``areceive*`` methods return ``None`` instead.
    raise_timeout: bool = True

    async def asend(self, message: Message) -> None:
        """Deliver a raw ASGI ``message`` to the application and await completion."""
        future = self.portal.start_task_soon(self._receive_tx.send, message)
        await asyncio.wrap_future(future)

    async def asend_text(self, data: str) -> None:
        """Send ``data`` to the application as a text frame."""
        await self.asend({"type": "websocket.receive", "text": data})

    async def asend_bytes(self, data: bytes) -> None:
        """Send ``data`` to the application as a binary frame."""
        await self.asend({"type": "websocket.receive", "bytes": data})

    async def asend_json(
        self, data: typing.Any, mode: typing.Literal["text", "binary"] = "text"
    ) -> None:
        """Serialize ``data`` as compact JSON and send it.

        ``mode="text"`` sends a text frame; any other value sends the UTF-8
        encoded JSON as a binary frame.
        """
        text = json.dumps(data, separators=(",", ":"), ensure_ascii=False)
        if mode == "text":
            await self.asend_text(text)
        else:
            await self.asend_bytes(text.encode("utf-8"))

    async def areceive(
        self,
        receive_timeout: float | None = None,
        raise_timeout: bool | None = None,
    ) -> Message | None:
        """Await the next raw ASGI message sent by the application.

        Args:
            receive_timeout: Seconds to wait. ``None`` falls back to the
                instance default. Any other value, including ``0``, is passed
                to ``asyncio.wait_for`` as-is rather than being treated as unset.
            raise_timeout: Whether a timeout is re-raised. ``None`` falls back
                to the instance default.

        Returns:
            The message, or ``None`` if the wait timed out and timeouts are not
            re-raised.

        Raises:
            asyncio.TimeoutError: On timeout when ``raise_timeout`` resolves to
                True.
        """
        if receive_timeout is None:
            receive_timeout = self.receive_timeout
        if raise_timeout is None:
            raise_timeout = self.raise_timeout
        future = self.portal.start_task_soon(self._send_rx.receive)
        try:
            # On timeout, wait_for cancels the asyncio wrapper and wrap_future
            # forwards that cancellation to the portal future. That is expected
            # to cancel the pending receive, so later calls start clean.
            return await asyncio.wait_for(
                asyncio.wrap_future(future), timeout=receive_timeout
            )
        except asyncio.TimeoutError:
            if raise_timeout:
                raise
            return None

    async def _areceive_checked(
        self, receive_timeout: float | None, raise_timeout: bool | None
    ) -> Message | None:
        """Run ``areceive`` plus the base-class close check; ``None`` only on a swallowed timeout."""
        message = await self.areceive(
            receive_timeout=receive_timeout, raise_timeout=raise_timeout
        )
        if message is not None:
            # Starlette's base-class check. Per its implementation, it raises if
            # the server closed the connection instead of sending data.
            self._raise_on_close(message)
        return message

    async def areceive_text(
        self,
        receive_timeout: float | None = None,
        raise_timeout: bool | None = None,
    ) -> str | None:
        """Await a text frame; return ``None`` on a non-raising timeout."""
        message = await self._areceive_checked(receive_timeout, raise_timeout)
        if message is None:
            return None
        return typing.cast(str, message["text"])

    async def areceive_bytes(
        self,
        receive_timeout: float | None = None,
        raise_timeout: bool | None = None,
    ) -> bytes | None:
        """Await a binary frame; return ``None`` on a non-raising timeout."""
        message = await self._areceive_checked(receive_timeout, raise_timeout)
        if message is None:
            return None
        return typing.cast(bytes, message["bytes"])

    async def areceive_json(
        self,
        mode: typing.Literal["text", "binary"] = "text",
        receive_timeout: float | None = None,
        raise_timeout: bool | None = None,
    ) -> typing.Any:
        """Await a JSON frame and decode it; return ``None`` on a non-raising timeout.

        ``mode="text"`` reads the text payload; any other value reads the
        binary payload and decodes it as UTF-8 before parsing.
        """
        message = await self._areceive_checked(receive_timeout, raise_timeout)
        if message is None:
            return None
        if mode == "text":
            text = message["text"]
        else:
            text = message["bytes"].decode("utf-8")
        return json.loads(text)
def make_websocket_async(
    websocket: WebSocketTestSession,
    receive_timeout: int | None = None,
    raise_timeout: bool = True,
) -> AsyncWebSocketTestSession:
    """Convert an existing test session in place so it gains the async API.

    The object keeps all of its state; only its class is swapped to
    ``AsyncWebSocketTestSession``. The given receive defaults are stored as
    instance attributes, shadowing the class-level defaults.

    Args:
        websocket: A session, e.g. one returned by
            ``TestClient.websocket_connect``.
        receive_timeout: Default timeout for the ``areceive*`` methods.
        raise_timeout: Default for whether a receive timeout is re-raised.

    Returns:
        The very same object, now an ``AsyncWebSocketTestSession``.
    """
    session = websocket
    session.__class__ = AsyncWebSocketTestSession
    session.receive_timeout, session.raise_timeout = receive_timeout, raise_timeout
    return session
def async_websocket_connect(
    client: TestClient,
    *args,
    receive_timeout: int | None = None,
    raise_timeout: bool = True,
    **kwargs,
) -> AsyncWebSocketTestSession:
    """Open a WebSocket test session that also exposes awaitable methods.

    Args:
        client: The test client used to open the connection.
        *args: Forwarded to ``client.websocket_connect`` (e.g. the URL).
        receive_timeout: Default timeout for the session's ``areceive*``
            methods.
        raise_timeout: Default for whether a receive timeout is re-raised
            (True) or reported as ``None`` (False).
        **kwargs: Forwarded to ``client.websocket_connect``.

    Returns:
        The connected session, converted with ``make_websocket_async``.
    """
    websocket = client.websocket_connect(*args, **kwargs)
    # Forward the receive defaults. Previously they were accepted but dropped,
    # so every session silently used "no timeout, always raise".
    return make_websocket_async(
        websocket,
        receive_timeout=receive_timeout,
        raise_timeout=raise_timeout,
    )
import asyncio
from fastapi.testclient import TestClient
from starlette.testclient import TestClient
from asynctestws import async_websocket_connect
async def test_async_ws(
    client: TestClient,
):
    """Send one ``answer`` message and expect a ``prosperity`` reaction back.

    Assumes ``client`` is a fixture supplying a TestClient and that the app
    under test serves a WebSocket at ``/api/ws/connect`` (neither is shown
    here -- TODO confirm against the project's conftest).
    """
    with async_websocket_connect(
        client, "/api/ws/connect", receive_timeout=3
    ) as ws:
        request = {"type": "answer", "secret": 42}
        await ws.asend_json(request)
        reply = await ws.areceive_json()
        assert reply["reaction"] == "prosperity"
# @pytest.mark.timeout(10) # May be a good thing to use
async def test_async_ws_loop_nones(
    client: TestClient,
):
    """Subscribe to a topic and collect values, tolerating timeouts as ``None``.

    With ``raise_timeout=False``, a timed-out receive yields ``None`` and is
    counted as a failure. The loop also stops once ``failed`` reaches the
    target. At that point the final ``done > failed`` check can no longer pass,
    so continuing could only hang (e.g. if the server stops publishing) or
    fail later.
    """
    target = 10
    with async_websocket_connect(
        client,
        "/api/ws/connect",
        receive_timeout=3,
        raise_timeout=False,  # Return None in case of timeout
    ) as websocket:
        await websocket.asend_json(
            {
                "type": "subscribe",
                "topic": "1337",
            }
        )
        done = 0
        failed = 0
        while done < target and failed < target:
            response = await websocket.areceive_json()
            if response is None:
                failed += 1
                continue
            assert "value" in response
            done += 1
        assert done > failed
# @pytest.mark.timeout(10) # May be a good thing to use
async def test_async_ws_loop_raise(
    client: TestClient,
):
    """Subscribe to a topic and collect values, counting raised timeouts.

    With ``raise_timeout=True``, a timed-out receive raises
    ``asyncio.TimeoutError``, which is counted as a failure. The loop also
    stops once ``failed`` reaches the target. At that point the final
    ``done > failed`` check can no longer pass, so continuing could only hang
    or fail later.
    """
    target = 10
    with async_websocket_connect(
        client,
        "/api/ws/connect",
        receive_timeout=3,
        raise_timeout=True,
    ) as websocket:
        await websocket.asend_json(
            {
                "type": "subscribe",
                "topic": "1337",
            }
        )
        done = 0
        failed = 0
        while done < target and failed < target:
            # Only the receive can time out; keep the assertion outside the try.
            try:
                response = await websocket.areceive_json()
            except asyncio.TimeoutError:
                failed += 1
                continue
            assert "value" in response
            done += 1
        assert done > failed
@AivanF
Copy link
Author

AivanF commented Mar 7, 2025

Why?

Currently, WebSocketTestSession lacks async methods, which are crucial for many test cases. In contrast to mostly idempotent HTTP requests and their linear tests, WebSockets are usually used in asynchronous, concurrent environments, where you don't know and cannot rely on the specific order or number of responses. I really needed such a feature in my project, which has complex usage scenarios to test.

Story

The major problem was how to wait for a response asynchronously, if that is possible at all – and how to wait several times, i.e. without leaving the backend in a broken state after the first timeout error. I dived into the source code of the Starlette TestClient, which is built on httpx.Client and uses anyio.from_thread.BlockingPortal, so I researched those too. I experimented and started writing custom TestClient, TestClientTransport and Async/NonBlockingPortal classes... but luckily found the BlockingPortal.start_task_soon method. It returns a concurrent.futures.Future, which can easily be converted to an asyncio.Future (via asyncio.wrap_future) and then awaited. So the final solution is quite simple and works fine in my project 🙂

Another problem was the way of integration:

  • The classic subclassing approach requires subclassing and overriding many TestClient methods – a waste of code.
  • Dynamically adding new methods to WebSocketTestSession – breaks type hints.
  • Monkey patching TestClient.websocket_connect() to change its result – also bad for typing.
  • Replacing the original WebSocketTestSession object's class with my subclass, which has proper method signatures – seems the best option to me.
  • A custom version of the Starlette framework – but my project is too small to maintain a fork. That said, I also plan to open a pull request to the Starlette repo a bit later.

I also thought about how to name the async versions of the send and receive methods. The options I considered:

  • async_send / async_receive – too verbose.
  • asend / areceive – the most concise and recognisable.
  • submit / expect – brief but maybe confusing.

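I attached some usage examples, although my project's cases are much more complicated (I even made a kind of WS scenario testing framework; maybe I should publish it too?). The examples assume a `client` pytest fixture that provides a TestClient, plus an async-capable test runner (e.g. pytest-asyncio or the anyio pytest plugin).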

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment