Last active
September 14, 2024 19:40
-
-
Save ckcollab/cd1788fb95933f2065c4b66f21f572a6 to your computer and use it in GitHub Desktop.
LastWriteTimestampRouter
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging | |
import time | |
from utils.middleware import get_last_write_cookie_timestamp | |
logger = logging.getLogger(__name__)


class LastWriteTimestampRouter:
    """Route reads to the replica unless this request's client wrote recently.

    The last-write timestamp (epoch seconds) comes from
    ``get_last_write_cookie_timestamp()``. Reads go to ``'default'`` (the
    primary) when that value is missing/falsy or younger than
    ``replica_lag_seconds``; otherwise they go to ``'replica'``. Writes always
    go to ``'default'``.
    """

    # Seconds after a write during which this client's reads stay on the
    # primary so it can read its own writes.
    # NOTE(review): presumably sized to cover replication lag; confirm.
    replica_lag_seconds = 30

    def db_for_read(self, model, **hints):
        """Return the alias to read ``model`` from: ``'default'`` or ``'replica'``."""
        current_time = int(time.time())
        last_write_timestamp = get_last_write_cookie_timestamp()
        logger.debug("Last write timestamp: %s", last_write_timestamp)
        return self._choose_read_db(last_write_timestamp, current_time)

    def _choose_read_db(self, last_write_timestamp, current_time: int) -> str:
        """Pick the read alias for a last-write time and current time (epoch seconds).

        A falsy ``last_write_timestamp`` (``None`` or ``0``) means nothing is
        known about this client, so the primary is used to be safe.
        """
        if not last_write_timestamp:
            logger.debug("No recent write data at all?? go to default database, may be important!")
            return 'default'
        # A timestamp in the future gives a negative age, which also stays on
        # the primary.
        if current_time - last_write_timestamp < self.replica_lag_seconds:
            logger.debug("We wrote to database recently, use primary")
            return 'default'
        logger.debug("We haven't written recently, use replica")
        return 'replica'

    def db_for_write(self, model, **hints):
        """Always write to the primary (``'default'``)."""
        return 'default'

    def allow_relation(self, obj1, obj2, **hints):
        """Allow every relation; this router places no cross-database restriction."""
        return True

    def allow_migrate(self, db, app_label, model_name=None, **hints):
        """Allow migrations on every alias.

        NOTE(review): this includes ``'replica'``. If that alias is a real
        read-only replica, ``return db == 'default'`` may be wanted; confirm
        against the test and deploy settings before changing.
        """
        return True
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time
from contextvars import ContextVar

from channels.middleware import BaseMiddleware
# Per-request holder for the epoch-seconds timestamp of the client's last
# write (None when unknown). The middleware sets it at the start of each HTTP
# request and resets it when the request finishes.
LAST_WRITE_COOKIE_TIMESTAMP_CONTEXT_VAR = ContextVar(
    'last_write_timestamp',
    default=None,
)

# HTTP methods that count as writes and refresh the last-write cookie.
WRITE_METHODS = {'DELETE', 'PATCH', 'POST', 'PUT'}
def _read_last_write_timestamp(headers):
    """Return the ``last_write_timestamp`` cookie from ASGI request headers as an int.

    ``headers`` is the ASGI iterable of ``(name, value)`` byte pairs. Every
    ``cookie`` header is scanned and each pair is split on its first ``=``.
    Returns ``None`` when the cookie is absent, empty, or not a plain ASCII
    decimal number.
    """
    raw_value = None
    for name, value in headers:
        if name.lower() != b'cookie':
            continue
        # latin-1 maps every byte to a character, so a malformed (non-UTF-8)
        # header cannot raise UnicodeDecodeError and turn into a 500.
        for pair in value.decode('latin-1').split(';'):
            key, sep, val = pair.partition('=')
            if sep and key.strip() == 'last_write_timestamp':
                raw_value = val.strip()
    # str.isdigit() alone accepts characters such as '²' that int() rejects,
    # so require ASCII as well.
    if raw_value and raw_value.isascii() and raw_value.isdigit():
        return int(raw_value)
    return None


class LastWriteCookieTimestampMiddleware(BaseMiddleware):
    """ASGI middleware that records when the client last issued a write request.

    For each HTTP request it stores a timestamp in
    ``LAST_WRITE_COOKIE_TIMESTAMP_CONTEXT_VAR`` for the duration of the request.
    The value depends on the method:

    * write methods (``WRITE_METHODS``): the current epoch time. It is also
      returned to the client in a ``last_write_timestamp`` cookie.
    * other methods: the incoming ``last_write_timestamp`` cookie value.
      If the cookie is missing or invalid, the value is ``None`` and the
      response sets the cookie to ``1``, a truthy but ancient timestamp.

    Non-HTTP scopes are passed straight to the inner application.
    """

    async def __call__(self, scope, receive, send):
        if scope['type'] != 'http':
            return await super().__call__(scope, receive, send)

        if scope['method'] in WRITE_METHODS:
            last_write_timestamp = int(time.time())
            set_cookie = f'last_write_timestamp={last_write_timestamp}; HttpOnly; Path=/'
        else:
            last_write_timestamp = _read_last_write_timestamp(scope.get('headers', []))
            if last_write_timestamp is None:
                # No usable cookie: hand out "1" (truthy, old as hell) so the
                # client's following reads can hit the replica servers.
                set_cookie = 'last_write_timestamp=1; HttpOnly; Path=/'
            else:
                set_cookie = None

        token = LAST_WRITE_COOKIE_TIMESTAMP_CONTEXT_VAR.set(last_write_timestamp)

        async def wrapped_send(event):
            if event['type'] == 'http.response.start' and set_cookie is not None:
                # Copy instead of mutating: the app may have passed a tuple or a
                # list it reuses. ASGI requires lowercase header names.
                headers = list(event.get('headers', []))
                headers.append((b'set-cookie', set_cookie.encode()))
                event = {**event, 'headers': headers}
            await send(event)

        # Run the inner application, then reset the context var so the value
        # does not leak into other requests.
        try:
            await super().__call__(scope, receive, wrapped_send)
        finally:
            LAST_WRITE_COOKIE_TIMESTAMP_CONTEXT_VAR.reset(token)
def get_last_write_cookie_timestamp():
    """Return the last-write timestamp recorded for the current request.

    The middleware sets this value from the current epoch time on write
    requests, or from the parsed cookie on reads. It is ``None`` outside a
    request, or when no usable cookie was sent.
    """
    timestamp = LAST_WRITE_COOKIE_TIMESTAMP_CONTEXT_VAR.get()
    return timestamp
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json | |
import time | |
import pytest | |
from channels.testing import HttpCommunicator | |
from django.urls import reverse | |
from factories import UserFactory | |
from tests.utils import CkcAPITestCase | |
from asgi import application | |
from utils.middleware import get_last_write_cookie_timestamp | |
class TestLastWriteTimestampRouter(CkcAPITestCase):
    """Exercise the last-write cookie middleware through the ASGI ``application``.

    Requests are sent with channels' ``HttpCommunicator`` so the ASGI middleware
    stack runs and its ``Set-Cookie`` response headers can be inspected.
    """

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls.user = UserFactory()

    def setUp(self):
        self.client.force_authenticate(self.user)
        # NOTE(review): force_authenticate() probably does not create a session.
        # If so, session_key is None and the ASGI requests below are
        # unauthenticated, hence the 401 TODOs. The cookie assertions do not
        # depend on the status code.
        self.session_key = self.client.session.session_key

    async def make_request(self, method, url, data=None):
        """Send one request through the ASGI app and return the communicator's response dict."""
        headers = [
            (b"cookie", f"sessionid={self.session_key}".encode()),
        ]
        body = b''
        if method in ['POST', 'PUT', 'PATCH', 'DELETE']:
            headers.append((b"content-type", b"application/json"))
            body = json.dumps(data).encode() if data else b''
        communicator = HttpCommunicator(
            application,
            method,
            url,
            headers=headers,
            body=body,
        )
        return await communicator.get_response()

    @staticmethod
    def _last_write_cookie_values(response):
        """Return the int value of every ``last_write_timestamp`` Set-Cookie in ``response``."""
        values = []
        for key, value in response['headers']:
            # Compare case-insensitively so this works whether the middleware
            # emits ``Set-Cookie`` or ``set-cookie``.
            if key.lower() != b'set-cookie':
                continue
            cookie_name, _, cookie_value = value.decode().split(';', 1)[0].partition('=')
            if cookie_name.strip() == 'last_write_timestamp':
                values.append(int(cookie_value))
        return values

    @pytest.mark.asyncio
    async def test_middleware_doesnt_set_timestamp_cookie_on_read_calls(self):
        """A read with no last-write cookie gets the default value 1, not a real timestamp."""
        url = reverse('user-me')
        response = await self.make_request("GET", url)
        # TODO: currently gets 401 (see setUp), but the cookie is set regardless.
        # assert response["status"] == 200
        # Compare the exact value. A prefix check on b'last_write_timestamp=1'
        # would also match a real epoch timestamp such as 1726...
        assert self._last_write_cookie_values(response) == [1]
        # Ensure the context var is reset between requests.
        # NOTE(review): the communicator presumably runs the app in its own
        # task with a copied context, which makes this a weak check.
        assert get_last_write_cookie_timestamp() is None

    @pytest.mark.asyncio
    async def test_middleware_sets_timestamp_cookie_on_write_calls(self):
        """A write request gets a fresh ``last_write_timestamp`` cookie."""
        url = reverse("user-detail", args=[self.user.id])
        payload = {
            "tos_agree": True,
            "data_consent": True,
            "date_of_birth": "04/12/1990",
        }
        response = await self.make_request("PATCH", url, payload)
        now = int(time.time())
        # TODO: currently gets 401 (see setUp), but the cookie is set regardless.
        # assert response["status"] == 200
        values = self._last_write_cookie_values(response)
        # Fail with a clear message (not a TypeError on None) if the cookie is missing.
        assert len(values) == 1, f"expected one last_write_timestamp cookie, got {values}"
        # Should be within 5 seconds of now, and not in the future.
        assert now - 5 <= values[0] <= now
        # Ensure the context var is reset between requests.
        assert get_last_write_cookie_timestamp() is None
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time | |
import pytest | |
from django.urls import reverse | |
from unittest.mock import patch | |
from factories import UserFactory, ForumFactory, BoardFactory | |
from tests.utils import CkcAPITestCase | |
from utils.database import LastWriteTimestampRouter | |
@pytest.mark.django_db(transaction=True, databases=['default', 'replica'])
@patch('utils.database.get_last_write_cookie_timestamp')
def test_board_listing_uses_replica_then_primary_after_write(mock_get_last_write_timestamp, client, mocker):
    """Board listing reads hit the replica after a stale write and the primary after a fresh one."""
    member = UserFactory()
    official_forum = ForumFactory(is_official=True)
    BoardFactory(forum=official_forum)

    # Stale write: 31 s ago, just past the router's 30 s window.
    # Set before logging in so every router call sees a real int.
    mock_get_last_write_timestamp.return_value = int(time.time()) - 31
    read_spy = mocker.spy(LastWriteTimestampRouter, 'db_for_read')

    client.force_login(member)
    boards_url = reverse('boards-list')

    resp = client.get(boards_url)
    assert resp.status_code == 200
    # spy_return holds the value returned by the most recent db_for_read call.
    assert read_spy.spy_return == 'replica'

    # Fresh write: 5 s ago, inside the window, so reads go to the primary.
    mock_get_last_write_timestamp.return_value = int(time.time()) - 5
    resp = client.get(boards_url)
    assert resp.status_code == 200
    assert read_spy.spy_return == 'default'
@pytest.mark.django_db(transaction=True, databases=['default', 'replica'])
@patch('utils.database.get_last_write_cookie_timestamp')
def test_thread_creation_uses_primary(mock_get_last_write_timestamp, client, mocker):
    """After creating a thread, reads use the primary until the write falls outside the window."""
    author = UserFactory()
    official_forum = ForumFactory(is_official=True)
    target_board = BoardFactory(forum=official_forum)

    # Give the patched getter a real (stale, 31 s old) timestamp so the router
    # receives an int instead of a MagicMock.
    mock_get_last_write_timestamp.return_value = int(time.time()) - 31

    client.force_login(author)
    new_thread = {
        'board': target_board.id,
        'title': 'Test Thread',
        'content': 'This is a test thread.',
    }
    # NOTE(review): format='json' is a DRF APIClient argument. If ``client`` is
    # Django's plain test Client, it is not recognised and the data is sent
    # form-encoded. Confirm which client fixture is in use.
    create_resp = client.post(reverse('threads-list'), new_thread, format='json')
    assert create_resp.status_code == 201

    # The ASGI middleware isn't active under this test client, so simulate the
    # cookie it would have set: a write that happened just now.
    mock_get_last_write_timestamp.return_value = int(time.time())

    read_spy = mocker.spy(LastWriteTimestampRouter, 'db_for_read')
    boards_url = reverse('boards-list')

    # Straight after the write, reads must be served by the primary.
    listing = client.get(boards_url)
    assert listing.status_code == 200
    assert read_spy.spy_return == 'default'

    # Pretend the write was 35 s ago: outside the window, so the replica is used.
    mock_get_last_write_timestamp.return_value = int(time.time()) - 35
    listing = client.get(boards_url)
    assert listing.status_code == 200
    assert read_spy.spy_return == 'replica'
class TestLastWriteTimestampRouter(CkcAPITestCase):
    """Unit tests for ``LastWriteTimestampRouter`` with the cookie timestamp patched."""

    databases = {'default', 'replica'}

    def setUp(self):
        self.router = LastWriteTimestampRouter()

    @patch('utils.database.get_last_write_cookie_timestamp')
    def test_database_router_logic(self, mock_get_last_write_timestamp):
        # Simulate a request with no recent write
        mock_get_last_write_timestamp.return_value = None
        assert self.router.db_for_read(None) == 'default'
        # Simulate a write operation
        current_time = int(time.time())
        mock_get_last_write_timestamp.return_value = current_time
        assert self.router.db_for_write(None) == 'default'
        # Simulate a read immediately after write
        assert self.router.db_for_read(None) == 'default'
        # Simulate a read after 31 seconds
        mock_get_last_write_timestamp.return_value = current_time - 31
        assert self.router.db_for_read(None) == 'replica'

    @patch('utils.database.get_last_write_cookie_timestamp')
    def test_read_exactly_at_threshold_uses_replica(self, mock_get_last_write_timestamp):
        # The router's comparison is a strict "< 30", so a write exactly 30 s old
        # is already served by the replica. Using 30 (not 29) keeps the test
        # stable if the clock ticks between here and the router's time.time().
        mock_get_last_write_timestamp.return_value = int(time.time()) - 30
        assert self.router.db_for_read(None) == 'replica'

    @patch('utils.database.get_last_write_cookie_timestamp')
    def test_read_with_future_timestamp_uses_primary(self, mock_get_last_write_timestamp):
        # A future timestamp (clock skew or a hand-edited cookie) has a negative
        # age, which is below the threshold, so reads stay on the primary.
        mock_get_last_write_timestamp.return_value = int(time.time()) + 3600
        assert self.router.db_for_read(None) == 'default'

    @patch('utils.database.get_last_write_cookie_timestamp')
    def test_read_with_zero_timestamp_uses_primary(self, mock_get_last_write_timestamp):
        # 0 is falsy, so the router treats it like "no timestamp at all".
        mock_get_last_write_timestamp.return_value = 0
        assert self.router.db_for_read(None) == 'default'

    def test_db_for_write_always_default(self):
        assert self.router.db_for_write(None) == 'default'

    def test_allow_relation_and_migrate_permit_everything(self):
        assert self.router.allow_relation(None, None) is True
        assert self.router.allow_migrate('default', 'app') is True
        assert self.router.allow_migrate('replica', 'app') is True
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment