Created
March 7, 2025 04:30
-
-
Save AivanF/4616d7e89fb5ea497deda31201db2c60 to your computer and use it in GitHub Desktop.
FastAPI / Starlette: Async WebSocket test client
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
import json | |
import typing | |
from starlette.testclient import TestClient, WebSocketTestSession | |
from starlette.types import Message | |
class AsyncWebSocketTestSession(WebSocketTestSession):
    """Starlette ``WebSocketTestSession`` with awaitable send/receive methods.

    Every ``a*`` method schedules the underlying stream operation on the
    session's anyio portal with ``portal.start_task_soon`` and awaits the
    returned ``concurrent.futures.Future`` through ``asyncio.wrap_future``.
    The calling event loop therefore stays free, and receives can be bounded
    by a timeout.

    The class attributes are per-session defaults, used whenever a receive
    method is called without explicit arguments:

    * ``receive_timeout`` -- seconds to wait for a message; ``None`` waits
      forever.
    * ``raise_timeout`` -- when true, a timeout raises ``asyncio.TimeoutError``;
      otherwise the receive method returns ``None``.

    NOTE: relies on the private ``_receive_tx``, ``_send_rx`` and
    ``_raise_on_close`` members of Starlette's ``WebSocketTestSession``;
    re-check them when upgrading Starlette.
    """

    receive_timeout: float | None = None
    raise_timeout: bool = True

    async def asend(self, message: Message) -> None:
        """Send a raw ASGI *message* to the application under test."""
        future = self.portal.start_task_soon(self._receive_tx.send, message)
        await asyncio.wrap_future(future)

    async def asend_text(self, data: str) -> None:
        """Send *data* to the application as a text frame."""
        await self.asend({"type": "websocket.receive", "text": data})

    async def asend_bytes(self, data: bytes) -> None:
        """Send *data* to the application as a binary frame."""
        await self.asend({"type": "websocket.receive", "bytes": data})

    async def asend_json(
        self, data: typing.Any, mode: typing.Literal["text", "binary"] = "text"
    ) -> None:
        """Serialize *data* as compact JSON and send it.

        Args:
            data: Any JSON-serializable value.
            mode: ``"text"`` sends a text frame; anything else sends the
                UTF-8 encoded JSON as a binary frame.
        """
        text = json.dumps(data, separators=(",", ":"), ensure_ascii=False)
        if mode == "text":
            await self.asend_text(text)
        else:
            await self.asend_bytes(text.encode("utf-8"))

    async def areceive(
        self,
        receive_timeout: float | None = None,
        raise_timeout: bool | None = None,
    ) -> Message | None:
        """Wait for the next raw ASGI message sent by the application.

        Args:
            receive_timeout: Seconds to wait. ``None`` falls back to
                ``self.receive_timeout`` (where ``None`` means wait forever).
                An explicit ``0`` is honoured instead of being treated as
                "not given".
            raise_timeout: Whether a timeout raises. ``None`` falls back to
                ``self.raise_timeout``.

        Returns:
            The message, or ``None`` if the wait timed out and timeouts are
            not raised.

        Raises:
            asyncio.TimeoutError: If the wait timed out and timeouts are raised.
        """
        # Resolve both defaults the same way (``is None``) so that falsy but
        # meaningful values such as ``0`` are not replaced by the defaults.
        if receive_timeout is None:
            receive_timeout = self.receive_timeout
        if raise_timeout is None:
            raise_timeout = self.raise_timeout
        future = self.portal.start_task_soon(self._send_rx.receive)
        try:
            return await asyncio.wait_for(
                asyncio.wrap_future(future), timeout=receive_timeout
            )
        except asyncio.TimeoutError:
            # wait_for cancels the wrapped future on timeout. Presumably that
            # cancellation reaches the portal task, so no stale receive steals
            # the next message -- TODO confirm for the anyio version in use.
            if raise_timeout:
                raise
            return None

    async def _areceive_open(
        self,
        receive_timeout: float | None,
        raise_timeout: bool | None,
    ) -> Message | None:
        """Receive a message and pass it through the parent's close check.

        Returns ``None`` only when the wait timed out and timeouts are not
        raised; otherwise the checked message.
        """
        message = await self.areceive(
            receive_timeout=receive_timeout, raise_timeout=raise_timeout
        )
        if message is not None:
            self._raise_on_close(message)
        return message

    async def areceive_text(
        self,
        receive_timeout: float | None = None,
        raise_timeout: bool | None = None,
    ) -> str | None:
        """Receive a text frame; ``None`` on a non-raising timeout.

        Arguments are as for :meth:`areceive`.
        """
        message = await self._areceive_open(receive_timeout, raise_timeout)
        if message is None:
            return None
        return typing.cast(str, message["text"])

    async def areceive_bytes(
        self,
        receive_timeout: float | None = None,
        raise_timeout: bool | None = None,
    ) -> bytes | None:
        """Receive a binary frame; ``None`` on a non-raising timeout.

        Arguments are as for :meth:`areceive`.
        """
        message = await self._areceive_open(receive_timeout, raise_timeout)
        if message is None:
            return None
        return typing.cast(bytes, message["bytes"])

    async def areceive_json(
        self,
        mode: typing.Literal["text", "binary"] = "text",
        receive_timeout: float | None = None,
        raise_timeout: bool | None = None,
    ) -> typing.Any:
        """Receive a frame and decode its JSON payload.

        Args:
            mode: ``"text"`` reads the text frame; anything else decodes the
                binary frame as UTF-8 first.
            receive_timeout: As for :meth:`areceive`.
            raise_timeout: As for :meth:`areceive`.

        Returns:
            The decoded value, or ``None`` on a non-raising timeout. Note
            that a JSON ``null`` payload is indistinguishable from a timeout
            in that mode.
        """
        message = await self._areceive_open(receive_timeout, raise_timeout)
        if message is None:
            return None
        if mode == "text":
            return json.loads(message["text"])
        return json.loads(message["bytes"].decode("utf-8"))
def make_websocket_async(
    websocket: WebSocketTestSession,
    receive_timeout: int | None = None,
    raise_timeout: bool = True,
) -> AsyncWebSocketTestSession:
    """Turn an existing Starlette WebSocket test session into an async one.

    The session's class is swapped in place rather than wrapped, so the
    returned object *is* ``websocket``. It keeps all of its state and can
    still be used as a context manager.

    Args:
        websocket: Session returned by ``TestClient.websocket_connect``.
        receive_timeout: Per-session default timeout (seconds) for the
            ``areceive*`` methods; ``None`` waits forever.
        raise_timeout: Per-session default for whether a timeout raises
            ``asyncio.TimeoutError`` (``True``) or returns ``None``.

    Returns:
        The same object, now an ``AsyncWebSocketTestSession``.
    """
    session = typing.cast(AsyncWebSocketTestSession, websocket)
    session.__class__ = AsyncWebSocketTestSession
    # Instance attributes shadow the class-level defaults.
    session.receive_timeout, session.raise_timeout = receive_timeout, raise_timeout
    return session
def async_websocket_connect(
    client: TestClient,
    *args,
    receive_timeout: int | None = None,
    raise_timeout: bool = True,
    **kwargs,
) -> AsyncWebSocketTestSession:
    """Open a WebSocket test session that has awaitable send/receive methods.

    Positional and remaining keyword arguments are forwarded unchanged to
    ``client.websocket_connect``. Use the result as a context manager, exactly
    like the session ``websocket_connect`` returns.

    Args:
        client: The Starlette/FastAPI test client.
        receive_timeout: Default timeout in seconds for the session's
            ``areceive*`` methods; ``None`` waits forever.
        raise_timeout: Default for whether a timeout raises
            ``asyncio.TimeoutError`` (``True``) or returns ``None``.

    Returns:
        The connected-to-be session, converted to ``AsyncWebSocketTestSession``.
    """
    # The timeout options are consumed here (they are not options of
    # ``websocket_connect``), so they must be forwarded explicitly to the
    # session; otherwise they would be silently dropped.
    return make_websocket_async(
        client.websocket_connect(*args, **kwargs),
        receive_timeout=receive_timeout,
        raise_timeout=raise_timeout,
    )
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
from fastapi.testclient import TestClient | |
from starlette.testclient import TestClient | |
from asynctestws import async_websocket_connect | |
async def test_async_ws(
    client: TestClient,
):
    """Send a single "answer" message and check the single reply."""
    request = {
        "type": "answer",
        "secret": 42,
    }
    with async_websocket_connect(
        client, "/api/ws/connect", receive_timeout=3
    ) as websocket:
        await websocket.asend_json(request)
        reply = await websocket.areceive_json()
        assert reply["reaction"] == "prosperity"
# @pytest.mark.timeout(10) # May be a good thing to use
async def test_async_ws_loop_nones(
    client: TestClient,
):
    """Collect 10 subscription updates, tolerating timeouts reported as None.

    With ``raise_timeout=False`` a timed-out ``areceive_json`` returns
    ``None`` instead of raising. The test passes when fewer timeouts than
    successful receives occurred.
    """
    target = 10
    with async_websocket_connect(
        client,
        "/api/ws/connect",
        receive_timeout=3,
        raise_timeout=False,  # Return None in case of timeout
    ) as websocket:
        await websocket.asend_json(
            {
                "type": "subscribe",
                "topic": "1337",
            }
        )
        done = 0
        failed = 0
        # Also stop once ``failed`` reaches ``target``. From then on the final
        # ``done > failed`` (with ``done <= target``) can no longer hold, so
        # looping further would only hang forever on a silent server.
        while done < target and failed < target:
            response = await websocket.areceive_json()
            if response is None:
                failed += 1
                continue
            assert "value" in response
            done += 1
        assert done > failed
# @pytest.mark.timeout(10) # May be a good thing to use
async def test_async_ws_loop_raise(
    client: TestClient,
):
    """Collect 10 subscription updates, counting raised timeouts as failures.

    With ``raise_timeout=True`` a timed-out ``areceive_json`` raises
    ``asyncio.TimeoutError``. The test passes when fewer timeouts than
    successful receives occurred.
    """
    target = 10
    with async_websocket_connect(
        client,
        "/api/ws/connect",
        receive_timeout=3,
        raise_timeout=True,
    ) as websocket:
        await websocket.asend_json(
            {
                "type": "subscribe",
                "topic": "1337",
            }
        )
        done = 0
        failed = 0
        # Also stop once ``failed`` reaches ``target``. From then on the final
        # ``done > failed`` (with ``done <= target``) can no longer hold, so
        # looping further would only hang forever on a silent server.
        while done < target and failed < target:
            # Only the receive itself can time out; keep the try body minimal.
            try:
                response = await websocket.areceive_json()
            except asyncio.TimeoutError:
                failed += 1
                continue
            assert "value" in response
            done += 1
        assert done > failed
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Why?
Currently, `WebSocketTestSession` lacks async methods, which are crucial for many test cases. Unlike mostly idempotent HTTP requests and their linear tests, WebSockets are usually used in asynchronous, concurrent environments, where you cannot know or rely on a specific order or number of responses. I really needed such a feature in my project, which has complex usage scenarios to test.
Story
The major problem is how to asynchronously wait for a response, if this is possible at all. It also has to work several times, without leaving the backend in a broken state after the first timeout error. I dived into the source code of the Starlette `TestClient`, which builds on `httpx.Client` and `anyio.from_thread.BlockingPortal`, so I researched those too. I experimented and started writing custom `TestClient`, `TestClientTransport` and async/non-blocking portal classes... but luckily found
the `BlockingPortal.start_task_soon` method, which returns a `concurrent.futures.Future`. That can easily be converted to an `asyncio.Future` and then awaited, so the final solution is quite simple and works fine in my project.

Another problem was the way of integration: `TestClient.websocket_connect()` — bad typing too. I was also thinking about how to name the async versions of the
`send` and `receive` methods, and considered these options:
- `async_send` / `async_receive` — too verbose;
- `asend` / `areceive` — the most concise and recognisable;
- `submit` / `expect` — brief, but maybe confusing.

I attached some usage examples, although my project's cases are much more complicated. I even made a kind of WS scenario testing framework; maybe I should publish it too?