Skip to content

Instantly share code, notes, and snippets.

@ckcollab
Last active September 14, 2024 19:40
Show Gist options
  • Save ckcollab/cd1788fb95933f2065c4b66f21f572a6 to your computer and use it in GitHub Desktop.
Save ckcollab/cd1788fb95933f2065c4b66f21f572a6 to your computer and use it in GitHub Desktop.
LastWriteTimestampRouter
import logging
import time
from utils.middleware import get_last_write_cookie_timestamp
logger = logging.getLogger(__name__)
class LastWriteTimestampRouter:
    """Django database router that keeps a client's reads on the primary right after it writes.

    All writes go to ``'default'``. Reads go to ``'replica'`` unless the last-write
    timestamp for the current request (from ``get_last_write_cookie_timestamp()``)
    is missing/falsy or less than ``RECENT_WRITE_WINDOW_SECONDS`` old. In those
    cases they stay on ``'default'`` so the client can read its own recent writes.
    """

    # How long (in seconds) after a write this client's reads stay pinned to the primary.
    RECENT_WRITE_WINDOW_SECONDS = 30

    @classmethod
    def _choose_read_db(cls, last_write_timestamp, now=None):
        """Return the database alias to read from, given the client's last write time.

        :param last_write_timestamp: Unix timestamp (seconds) of the last write, or a
            falsy value (``None``/``0``) when nothing is known.
        :param now: current Unix time in seconds; defaults to ``int(time.time())``.
        :returns: ``'default'`` or ``'replica'``.
        """
        if not last_write_timestamp:
            # No write information at all: err on the side of the primary.
            return 'default'
        if now is None:
            now = int(time.time())
        # A timestamp in the future (clock skew or a hand-edited cookie) gives a
        # negative age, so it is also treated as a recent write.
        if now - last_write_timestamp < cls.RECENT_WRITE_WINDOW_SECONDS:
            return 'default'
        return 'replica'

    def db_for_read(self, model, **hints):
        """Return ``'default'`` after a recent (or unknown) write, otherwise ``'replica'``."""
        last_write_timestamp = get_last_write_cookie_timestamp()
        # This runs for every read query, so use lazy %-style logging arguments
        # instead of formatting the message eagerly with an f-string.
        logger.debug("Last write timestamp: %s", last_write_timestamp)
        alias = self._choose_read_db(last_write_timestamp)
        if not last_write_timestamp:
            logger.debug("No recent write data at all?? go to default database, may be important!")
        elif alias == 'default':
            logger.debug("We wrote to database recently, use primary")
        else:
            logger.debug("We haven't written recently, use replica")
        return alias

    def db_for_write(self, model, **hints):
        """Always write to the primary (``'default'``) database."""
        return 'default'

    def allow_relation(self, obj1, obj2, **hints):
        """Allow relations between any two objects."""
        return True

    def allow_migrate(self, db, app_label, model_name=None, **hints):
        """Allow migrations on every database alias."""
        return True
from contextvars import ContextVar
from channels.middleware import BaseMiddleware
# Per-request "last write" Unix timestamp (seconds), set by the middleware below;
# None when the current context has no value.
LAST_WRITE_COOKIE_TIMESTAMP_CONTEXT_VAR = ContextVar('last_write_timestamp', default=None)
# HTTP methods treated as database writes. frozenset so importers cannot mutate it.
WRITE_METHODS = frozenset({'POST', 'PUT', 'PATCH', 'DELETE'})
class LastWriteCookieTimestampMiddleware(BaseMiddleware):
    """ASGI middleware that tracks when this client last wrote to the database.

    For HTTP requests it publishes a "last write" Unix timestamp (seconds) in
    ``LAST_WRITE_COOKIE_TIMESTAMP_CONTEXT_VAR`` for the duration of the request:

    * write methods (``WRITE_METHODS``) use the current time and answer with a
      ``last_write_timestamp=<now>`` cookie;
    * other methods use the integer value of the incoming ``last_write_timestamp``
      cookie. A missing, malformed or zero value publishes ``None`` or ``0``, and
      the response then sets the cookie to ``1``, a truthy but very old timestamp.

    Non-HTTP scopes (e.g. websockets) are passed through untouched.
    """

    @staticmethod
    def _parse_cookie_header(raw_headers):
        """Return the request's cookies as a ``{name: value}`` dict from ASGI headers."""
        cookies = {}
        for name, value in raw_headers:
            if name.lower() != b'cookie':
                continue
            # latin-1 decodes any byte sequence, so a malformed (non-UTF-8) cookie
            # header cannot raise UnicodeDecodeError and fail the whole request.
            for pair in value.decode('latin-1').split(';'):
                key, sep, val = pair.partition('=')
                if sep:
                    # partition keeps any '=' inside the value (e.g. base64 padding).
                    cookies[key.strip()] = val.strip()
        return cookies

    @staticmethod
    def _parse_timestamp(value):
        """Return *value* as an int if it is a non-empty ASCII digit string, else None."""
        # str.isdigit() alone also accepts characters such as '²' that int()
        # rejects with ValueError, so restrict the check to ASCII.
        if value and value.isascii() and value.isdigit():
            return int(value)
        return None

    async def __call__(self, scope, receive, send):
        if scope['type'] != 'http':
            return await super().__call__(scope, receive, send)

        is_write = scope['method'] in WRITE_METHODS
        if is_write:
            last_write_timestamp = int(time.time())
        else:
            cookies = self._parse_cookie_header(scope['headers'])
            last_write_timestamp = self._parse_timestamp(cookies.get('last_write_timestamp'))
        token = LAST_WRITE_COOKIE_TIMESTAMP_CONTEXT_VAR.set(last_write_timestamp)

        async def wrapped_send(event):
            if event['type'] == 'http.response.start':
                # Copy into a list: ASGI allows any iterable here (e.g. a tuple).
                headers = list(event.get('headers', []))
                if is_write:
                    cookie_value = f'last_write_timestamp={last_write_timestamp}; HttpOnly; Path=/'
                    # ASGI requires lowercased response header names.
                    headers.append((b'set-cookie', cookie_value.encode()))
                elif not last_write_timestamp:
                    # No usable cookie: set it to "1" (truthy, but far older than the
                    # router's window) so this client's reads go to the replica.
                    cookie_value = 'last_write_timestamp=1; HttpOnly; Path=/'
                    headers.append((b'set-cookie', cookie_value.encode()))
                event['headers'] = headers
            await send(event)

        # Run the inner application, but always reset the context var so the
        # value does not leak into other requests.
        try:
            await super().__call__(scope, receive, wrapped_send)
        finally:
            LAST_WRITE_COOKIE_TIMESTAMP_CONTEXT_VAR.reset(token)
def get_last_write_cookie_timestamp():
    """Return the last-write timestamp published for the current request context.

    This is whatever is stored in ``LAST_WRITE_COOKIE_TIMESTAMP_CONTEXT_VAR``: an
    int Unix timestamp in seconds, or ``None`` (the variable's default) when
    nothing has been set.
    """
    current_value = LAST_WRITE_COOKIE_TIMESTAMP_CONTEXT_VAR.get()
    return current_value
import json
import time
import pytest
from channels.testing import HttpCommunicator
from django.urls import reverse
from factories import UserFactory
from tests.utils import CkcAPITestCase
from asgi import application
from utils.middleware import get_last_write_cookie_timestamp
class TestLastWriteTimestampRouter(CkcAPITestCase):
    """Exercise the last-write cookie middleware end to end through the ASGI ``application``.

    NOTE(review): despite its name, this class tests the cookie middleware rather
    than the router, and it shares its name with the router unit-test class. If the
    two ever live in one module, the later definition silently shadows this one.
    """

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls.user = UserFactory()

    def setUp(self):
        super().setUp()
        self.client.force_authenticate(self.user)
        # NOTE(review): force_authenticate may not create a saved session, which
        # would leave session_key as None. That is a plausible cause of the 401s
        # noted in the tests below; confirm.
        self.session_key = self.client.session.session_key

    async def make_request(self, method, url, data=None):
        """Send one request through the ASGI app and return the communicator's response dict.

        Write methods get a JSON body built from *data*, or an empty body when *data* is falsy.
        """
        headers = [
            (b"cookie", f"sessionid={self.session_key}".encode()),
        ]
        body = b''
        if method in ['POST', 'PUT', 'PATCH', 'DELETE']:
            headers.append((b"content-type", b"application/json"))
            body = json.dumps(data).encode() if data else b''
        communicator = HttpCommunicator(
            application,
            method,
            url,
            headers=headers,
            body=body,
        )
        return await communicator.get_response()

    @staticmethod
    def _last_write_cookies(response):
        """Return the values of every ``last_write_timestamp`` Set-Cookie header in *response*."""
        return [
            value
            for key, value in response['headers']
            if key.lower() == b'set-cookie' and value.startswith(b'last_write_timestamp=')
        ]

    @pytest.mark.asyncio
    async def test_middleware_doesnt_set_timestamp_cookie_on_read_calls(self):
        """A read without a timestamp cookie gets the placeholder value ``1``, not a real timestamp."""
        url = reverse('user-me')
        response = await self.make_request("GET", url)
        # TODO: assert response["status"] == 200 -- currently a 401, though the cookie is still set.
        cookies = self._last_write_cookies(response)
        # Include the ';' so a real timestamp that merely starts with "1" does not match.
        assert any(value.startswith(b'last_write_timestamp=1;') for value in cookies), cookies
        # Ensure context var is reset between requests.
        # NOTE(review): the app may run in its own task/context under HttpCommunicator,
        # in which case this holds even without the reset; confirm.
        assert get_last_write_cookie_timestamp() is None

    @pytest.mark.asyncio
    async def test_middleware_sets_timestamp_cookie_on_write_calls(self):
        """A write responds with a cookie holding the current Unix timestamp."""
        url = reverse("user-detail", args=[self.user.id])
        payload = {
            "tos_agree": True,
            "data_consent": True,
            "date_of_birth": "04/12/1990",
        }
        response = await self.make_request("PATCH", url, payload)
        # TODO: assert response["status"] == 200 -- currently a 401, though the cookie is still set.
        cookies = self._last_write_cookies(response)
        # Fail with a clear message instead of a TypeError when comparing None below.
        assert cookies, "response did not set a last_write_timestamp cookie"
        # "last_write_timestamp=<value>; HttpOnly; Path=/" -> <value>
        timestamp_cookie_value = int(cookies[0].decode().split(';', 1)[0].split('=', 1)[1])
        now = int(time.time())
        # Must be a real timestamp from this request: within the last 5 seconds, not in the future.
        assert now - 5 <= timestamp_cookie_value <= now
        # Ensure context var is reset between requests.
        assert get_last_write_cookie_timestamp() is None
import time
import pytest
from django.urls import reverse
from unittest.mock import patch
from factories import UserFactory, ForumFactory, BoardFactory
from tests.utils import CkcAPITestCase
from utils.database import LastWriteTimestampRouter
@pytest.mark.django_db(transaction=True, databases=['default', 'replica'])
@patch('utils.database.get_last_write_cookie_timestamp')
def test_board_listing_uses_replica_then_primary_after_write(mock_get_last_write_timestamp, client, mocker):
    """Board listing reads use the replica after a stale write and the primary after a fresh one."""
    member = UserFactory()
    official_forum = ForumFactory(is_official=True)
    BoardFactory(forum=official_forum)

    def wrote_seconds_ago(seconds):
        # Stand in for the cookie middleware by faking the last-write timestamp.
        mock_get_last_write_timestamp.return_value = int(time.time()) - seconds

    def assert_boards_read_from(expected_db):
        listing = client.get(boards_url)
        assert listing.status_code == 200
        assert read_spy.spy_return == expected_db

    # Last write 31 seconds ago: outside the window, so reads should hit the replica.
    wrote_seconds_ago(31)
    read_spy = mocker.spy(LastWriteTimestampRouter, 'db_for_read')
    client.force_login(member)
    boards_url = reverse('boards-list')
    assert_boards_read_from('replica')

    # Last write 5 seconds ago: inside the window, so reads should hit the primary.
    wrote_seconds_ago(5)
    assert_boards_read_from('default')
@pytest.mark.django_db(transaction=True, databases=['default', 'replica'])
@patch('utils.database.get_last_write_cookie_timestamp')
def test_thread_creation_uses_primary(mock_get_last_write_timestamp, client, mocker):
    """Reads right after creating a thread use the primary; 35 seconds later they use the replica."""
    author = UserFactory()
    official_forum = ForumFactory(is_official=True)
    target_board = BoardFactory(forum=official_forum)

    # Give the mock a real, stale timestamp up front instead of letting it return a MagicMock.
    mock_get_last_write_timestamp.return_value = int(time.time()) - 31
    client.force_login(author)

    threads_url = reverse('threads-list')
    new_thread = {
        'board': target_board.id,
        'title': 'Test Thread',
        'content': 'This is a test thread.',
    }
    created = client.post(threads_url, new_thread, format='json')
    assert created.status_code == 201

    # The ASGI middleware doesn't run in these tests, so record the new write timestamp by hand.
    mock_get_last_write_timestamp.return_value = int(time.time())

    read_spy = mocker.spy(LastWriteTimestampRouter, 'db_for_read')
    boards_url = reverse('boards-list')

    # Immediately after the write, reads should go to the primary database.
    listing = client.get(boards_url)
    assert listing.status_code == 200
    assert read_spy.spy_return == 'default'

    # 35 seconds later, reads should go to the replica database.
    mock_get_last_write_timestamp.return_value = int(time.time()) - 35
    listing = client.get(boards_url)
    assert listing.status_code == 200
    assert read_spy.spy_return == 'replica'
class TestLastWriteTimestampRouter(CkcAPITestCase):
    """Unit tests for LastWriteTimestampRouter's read/write routing decisions."""

    databases = {'default', 'replica'}

    def setUp(self):
        super().setUp()
        self.router = LastWriteTimestampRouter()

    @patch('utils.database.get_last_write_cookie_timestamp')
    def test_database_router_logic(self, mock_get_last_write_timestamp):
        # No timestamp at all: reads stay on the primary.
        mock_get_last_write_timestamp.return_value = None
        assert self.router.db_for_read(None) == 'default'

        # Simulate a write operation.
        current_time = int(time.time())
        mock_get_last_write_timestamp.return_value = current_time
        assert self.router.db_for_write(None) == 'default'

        # Read immediately after the write: primary.
        assert self.router.db_for_read(None) == 'default'

        # Exactly 30 seconds later: the window is exclusive, so the replica is used.
        # (Any clock tick during the test only makes the write older, so this is stable.)
        mock_get_last_write_timestamp.return_value = current_time - 30
        assert self.router.db_for_read(None) == 'replica'

        # Read after 31 seconds: replica.
        mock_get_last_write_timestamp.return_value = current_time - 31
        assert self.router.db_for_read(None) == 'replica'

        # The placeholder value 1 that the middleware sets for clients without a cookie: replica.
        mock_get_last_write_timestamp.return_value = 1
        assert self.router.db_for_read(None) == 'replica'

    def test_db_for_write_always_default(self):
        assert self.router.db_for_write(None) == 'default'
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment