Created
February 20, 2025 22:48
-
-
Save xtea/74647ae07880fd4d9e4c2c03e2e2492f to your computer and use it in GitHub Desktop.
Problem: Distributed Rate Limiter
You are designing a distributed rate limiter that ensures an API client does not exceed a maximum number of requests in a given time window. The system operates across multiple nodes, so the rate-limiting logic cannot rely on local memory alone; it must be consistent across all nodes.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys | |
import time | |
import http.server | |
import socketserver | |
import requests | |
from urllib.parse import urlparse, parse_qs | |
import hashlib | |
import bisect | |
############################################################################### | |
# CONFIGURABLE PARAMETERS | |
############################################################################### | |
MAX_REQUESTS = 5 # Max requests allowed per user within one window
WINDOW_SECONDS = 60 # Length of each rate-limit window, in seconds
############################################################################### | |
# CONSISTENT HASHING IMPLEMENTATION | |
############################################################################### | |
class ConsistentHashRing:
    """
    Very basic consistent hash ring that maps string keys onto nodes.

    Each real node is placed on the ring 'num_replicas' times, as virtual
    nodes named "<node>-<i>". A key belongs to the first virtual node whose
    hash is greater than or equal to the key's hash, wrapping around to the
    start of the ring when the key hashes past the last virtual node.

    We assume:
    - 'nodes' is an iterable of strings like "host:port"
    - 'num_replicas' controls the number of virtual nodes per real node
    """

    def __init__(self, nodes, num_replicas=100):
        self.num_replicas = num_replicas
        # (hash_value, node) tuples ordered by hash_value.
        self.ring = sorted(
            (
                (self._hash_str(f"{node}-{i}"), node)
                for node in nodes
                for i in range(self.num_replicas)
            ),
            key=lambda entry: entry[0],
        )
        # Hash values only, in the same order as self.ring, for bisect lookups.
        self.sorted_keys = [hash_value for hash_value, _ in self.ring]

    def _hash_str(self, key):
        """Returns a consistent integer hash (MD5-based) for the given key string."""
        return int(hashlib.md5(key.encode("utf-8")).hexdigest(), 16)

    def get_node(self, key):
        """
        Given a key (e.g., user_id), finds which node owns that key in the ring.

        Raises IndexError if the ring contains no nodes.
        """
        if not self.ring:
            # Same exception type the bare list lookup used to raise, but with
            # a message that says what is actually wrong.
            raise IndexError("cannot look up a node in an empty hash ring")
        position = bisect.bisect_left(self.sorted_keys, self._hash_str(key))
        # A key hashing past the last virtual node wraps to the first one.
        return self.ring[position % len(self.ring)][1]
############################################################################### | |
# RATE-LIMIT LOGIC (in-memory) -- each node only stores counters for its keys | |
############################################################################### | |
# Per-user fixed-window state held in this node's memory:
# user_counters[user_id] -> (window_start_timestamp, request_count)
user_counters = {}
def current_window_start():
    """
    Returns the timestamp (whole seconds since the epoch) at which the
    current time window began. Windows are aligned to multiples of
    WINDOW_SECONDS.
    """
    epoch_seconds = int(time.time())
    # Seconds already elapsed inside the current window.
    _, elapsed_in_window = divmod(epoch_seconds, WINDOW_SECONDS)
    return epoch_seconds - elapsed_in_window
def check_and_increment(user_id):
    """
    Decides whether one more request from 'user_id' fits in the current
    window and, if it does, records it in user_counters.

    Returns True if the request is allowed, False if the user has already
    made MAX_REQUESTS requests in this window.
    """
    active_window = current_window_start()
    entry = user_counters.get(user_id)

    # No state yet, or the stored state belongs to an earlier window:
    # start a fresh window that already counts this request.
    if entry is None or entry[0] != active_window:
        user_counters[user_id] = (active_window, 1)
        return True

    used = entry[1]
    if used >= MAX_REQUESTS:
        # Limit reached; the counter is left untouched.
        return False

    user_counters[user_id] = (active_window, used + 1)
    return True
############################################################################### | |
# HTTP SERVER IMPLEMENTATION | |
############################################################################### | |
class RateLimitHandler(http.server.BaseHTTPRequestHandler):
    """
    Handles GET /check?user_id=xyz
    1. Determine which node is responsible via consistent hashing.
    2. If local node is responsible, check local counters.
    3. Otherwise, forward to the appropriate node and relay its response.

    Uses the module-level names 'ring', 'local_node' and
    'check_and_increment'.
    """

    # Upstream headers that are not copied onto the relayed response:
    # - content-length / transfer-encoding / connection describe the upstream
    #   connection, not this one;
    # - content-encoding no longer matches the body, because requests has
    #   already decoded resp.content;
    # - server / date are emitted by send_response() itself.
    _SKIPPED_UPSTREAM_HEADERS = frozenset(
        {
            "content-length",
            "transfer-encoding",
            "connection",
            "content-encoding",
            "server",
            "date",
        }
    )

    def do_GET(self):
        """Validate the request, then answer it locally or forward it to the owner node."""
        parsed = urlparse(self.path)
        if parsed.path != "/check":
            self.send_error(404, "Not Found")
            return

        query = parse_qs(parsed.query)
        user_id = query.get("user_id", [""])[0]
        if not user_id:
            self.send_error(400, "Missing user_id param")
            return

        # Which node is responsible for this user_id?
        responsible_node = ring.get_node(user_id)
        if responsible_node == local_node:
            # Handle it locally
            self._respond(check_and_increment(user_id))
        else:
            # Forward to the responsible node
            self._forward_request(user_id, responsible_node)

    @staticmethod
    def _forward_url(node, user_id):
        """
        Build the /check URL on 'node' for 'user_id'.

        'user_id' has already been percent-decoded by parse_qs, so it is
        re-encoded here; otherwise characters such as '&', '#', '%' or spaces
        would change which user_id the receiving node sees.
        """
        # Function-scope import so this block does not depend on the
        # file-level import list.
        from urllib.parse import urlencode

        query_string = urlencode({"user_id": user_id})
        return f"http://{node}/check?{query_string}"

    def _forward_request(self, user_id, node):
        """
        Forward the request to the node that owns the user_id in the hash ring
        and relay its status, headers and body to our client.
        'node' is something like "localhost:5001".
        """
        url = self._forward_url(node, user_id)
        # Only the outbound call is guarded: once we start writing the relayed
        # response, a failure must not trigger a second (503) response.
        try:
            resp = requests.get(url, timeout=5)
        except requests.RequestException as e:
            # If forwarding fails (node is down?), we can't do much else
            self.send_error(503, f"Node {node} not reachable: {str(e)}")
            return

        self.send_response(resp.status_code)
        for header_name, header_value in resp.headers.items():
            if header_name.lower() not in self._SKIPPED_UPSTREAM_HEADERS:
                self.send_header(header_name, header_value)
        self.end_headers()
        self.wfile.write(resp.content)

    def _respond(self, allowed):
        """Write a plain-text 200 "ALLOWED" or 429 "RATE LIMIT EXCEEDED" response."""
        if allowed:
            status, body = 200, b"ALLOWED"
        else:
            status, body = 429, b"RATE LIMIT EXCEEDED"
        self.send_response(status)
        self.send_header("Content-type", "text/plain")
        self.end_headers()
        self.wfile.write(body)
class _ReusableTCPServer(socketserver.TCPServer):
    """TCPServer that sets SO_REUSEADDR so the port can be re-bound right after a restart."""

    allow_reuse_address = True


def run_server(host, port):
    """
    Bind an HTTP server for RateLimitHandler on (host, port) and serve
    requests until the process is stopped. This call blocks.

    Requests are handled one at a time (the server is not threaded).
    """
    with _ReusableTCPServer((host, port), RateLimitHandler) as httpd:
        print(f"Serving on {host}:{port} (local_node={local_node})")
        httpd.serve_forever()
############################################################################### | |
# MAIN ENTRY POINT | |
############################################################################### | |
if __name__ == "__main__":
    # Usage:
    #   python consistent_hash_rate_limiter.py <local_port> <node1> <node2> ... <nodeN>
    #
    # Example:
    #   # Start Node A on port 5000
    #   python consistent_hash_rate_limiter.py 5000 localhost:5000 localhost:5001
    #   # Start Node B on port 5001
    #   python consistent_hash_rate_limiter.py 5001 localhost:5000 localhost:5001
    #
    # This means each node's cluster membership is [localhost:5000, localhost:5001].
    # The first argument is the local node's port; all subsequent arguments
    # are the full list of nodes (including itself).
    if len(sys.argv) < 3:
        print("Usage: python consistent_hash_rate_limiter.py <local_port> <node1> <node2> ...")
        sys.exit(1)

    local_port = sys.argv[1]
    all_nodes = sys.argv[2:]  # e.g. ["localhost:5000", "localhost:5001"]

    # Validate the port before doing any other work, so a typo produces a
    # clear message instead of a traceback after the ring has been built.
    try:
        listen_port = int(local_port)
    except ValueError:
        print(f"Invalid local port {local_port!r}: expected an integer")
        sys.exit(1)

    # Construct the local_node string from local_port
    local_node = f"localhost:{local_port}"
    if local_node not in all_nodes:
        # The handler only counts requests when the ring's owner equals
        # local_node, so in this configuration every request is forwarded.
        print(
            f"Warning: {local_node} is not in the node list {all_nodes}; "
            "this process will forward every request and never count locally",
            file=sys.stderr,
        )

    # Build the ring
    ring = ConsistentHashRing(all_nodes, num_replicas=100)

    # Run the server
    run_server("0.0.0.0", listen_port)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment