Skip to content

Instantly share code, notes, and snippets.

@xtea
Created February 20, 2025 22:48
Show Gist options
  • Save xtea/74647ae07880fd4d9e4c2c03e2e2492f to your computer and use it in GitHub Desktop.
Save xtea/74647ae07880fd4d9e4c2c03e2e2492f to your computer and use it in GitHub Desktop.
Problem: Distributed Rate Limiter. You are designing a distributed rate limiter that ensures an API client does not exceed a maximum number of requests in a given time window. The system operates across multiple nodes, meaning the rate-limiting logic cannot rely on local memory alone — it must be consistent across all nodes.
import sys
import time
import http.server
import socketserver
import requests
from urllib.parse import urlparse, parse_qs
import hashlib
import bisect
###############################################################################
# CONFIGURABLE PARAMETERS
###############################################################################
# Maximum number of requests a single user_id may make within one window;
# check_and_increment() compares the per-user count against this value.
MAX_REQUESTS = 5
# Length of one rate-limit window in seconds. current_window_start() rounds the
# current Unix time down to a multiple of this value.
WINDOW_SECONDS = 60
###############################################################################
# CONSISTENT HASHING IMPLEMENTATION
###############################################################################
class ConsistentHashRing:
    """
    Very basic consistent hash ring.

    Each real node is placed on the ring ``num_replicas`` times (as "virtual
    nodes") so that keys spread more evenly across nodes.

    Parameters:
    - 'nodes' is an iterable of strings like "host:port"
    - 'num_replicas' controls the number of virtual nodes per real node

    Attributes:
    - 'ring' is a list of (hash_value, node) tuples sorted by hash_value
    - 'sorted_keys' is the list of hash_values in the same order as 'ring'
    """

    def __init__(self, nodes, num_replicas=100):
        self.num_replicas = num_replicas
        # One entry per virtual node; the virtual node "<node>-<i>" is hashed
        # to find its position. Sorting by hash only (not by node name) keeps
        # insertion order for the unlikely case of equal hashes.
        self.ring = sorted(
            (
                (self._hash_str(f"{node}-{i}"), node)
                for node in nodes
                for i in range(num_replicas)
            ),
            key=lambda entry: entry[0],
        )
        # Parallel list of just the hashes, so bisect can search it directly.
        self.sorted_keys = [hash_value for hash_value, _ in self.ring]

    def _hash_str(self, key):
        """Returns a consistent integer hash (MD5 digest as int) for the key string."""
        return int(hashlib.md5(key.encode("utf-8")).hexdigest(), 16)

    def get_node(self, key):
        """
        Given a key (e.g., user_id), finds which node owns that key in the ring.

        The owner is the first virtual node whose hash is >= the key's hash,
        wrapping around to the first ring entry when the key hashes past the
        last one.

        Raises IndexError if the ring has no entries (no nodes were given, or
        num_replicas was not positive).
        """
        if not self.ring:
            raise IndexError("cannot look up a node in an empty hash ring")
        idx = bisect.bisect_left(self.sorted_keys, self._hash_str(key))
        # idx == len(ring) means "past the end": wrap to position 0.
        return self.ring[idx % len(self.ring)][1]
###############################################################################
# RATE-LIMIT LOGIC (in-memory) -- each node only stores counters for its keys
###############################################################################
# user_counters: dict[user_id] -> (window_start_timestamp, request_count)
# Plain in-process dict: it is held in this node's memory only, so the counts
# are not shared with other nodes and are lost when the process exits.
user_counters = {}
def current_window_start():
    """
    Return the Unix timestamp (whole seconds) at which the current
    rate-limit window began, i.e. the current time rounded down to a
    multiple of WINDOW_SECONDS.
    """
    now = int(time.time())
    # The remainder is how far we are into the current window.
    _, seconds_into_window = divmod(now, WINDOW_SECONDS)
    return now - seconds_into_window
# Window start for which stale counters were last purged (None = never yet).
_last_purge_window = None


def check_and_increment(user_id):
    """
    Checks if the user is within the rate limit for the current window.
    If allowed, increments the user's counter.

    A user with no counter, or with a counter from a different window, starts
    the current window at zero. The limit is applied to every request,
    including the first one in a window, so MAX_REQUESTS = 0 rejects
    everything.

    Returns True if allowed, False otherwise.
    """
    global _last_purge_window
    window_start = current_window_start()

    if _last_purge_window != window_start:
        # First call seen for this window: drop every counter that belongs to
        # another window. Such entries would be reset on their next access
        # anyway; removing them keeps the dict from growing forever with
        # user_ids that never come back.
        stale_users = [
            uid
            for uid, (stored_window, _) in user_counters.items()
            if stored_window != window_start
        ]
        for uid in stale_users:
            del user_counters[uid]
        _last_purge_window = window_start

    stored_window, count = user_counters.get(user_id, (window_start, 0))
    if stored_window != window_start:
        # Window rolled over; start counting from scratch.
        count = 0

    if count >= MAX_REQUESTS:
        return False

    user_counters[user_id] = (window_start, count + 1)
    return True
###############################################################################
# HTTP SERVER IMPLEMENTATION
###############################################################################
class RateLimitHandler(http.server.BaseHTTPRequestHandler):
    """
    Handles GET /check?user_id=xyz

    1. Determine which node is responsible via consistent hashing.
    2. If local node is responsible, check local counters.
    3. Otherwise, forward to the appropriate node.

    Responses are 200 "ALLOWED", 429 "RATE LIMIT EXCEEDED", 400 when user_id
    is missing, 404 for any other path, and 503 when the owning node cannot
    be reached.

    Relies on the module-level names `ring` and `local_node`, which the
    script's __main__ block assigns before the server starts.
    """

    # Seconds to wait for the owning node before answering 503.
    FORWARD_TIMEOUT = 5

    # Upstream headers that must not be copied onto the relayed response:
    # - content-length / transfer-encoding / connection describe the upstream
    #   connection, not ours (we send our own Content-Length);
    # - content-encoding no longer applies because `resp.content` is already
    #   decoded by requests;
    # - server / date are emitted by send_response() itself, so copying them
    #   would duplicate the headers.
    _UNFORWARDED_HEADERS = frozenset(
        {
            "content-length",
            "transfer-encoding",
            "connection",
            "content-encoding",
            "server",
            "date",
        }
    )

    def do_GET(self):
        """Validate the request, then answer locally or forward to the owner."""
        parsed = urlparse(self.path)
        if parsed.path != "/check":
            self.send_error(404, "Not Found")
            return

        # parse_qs drops blank values, so "?user_id=" is treated as missing.
        user_id = parse_qs(parsed.query).get("user_id", [""])[0]
        if not user_id:
            self.send_error(400, "Missing user_id param")
            return

        # Which node is responsible for this user_id?
        responsible_node = ring.get_node(user_id)
        if responsible_node == local_node:
            # Handle it locally
            self._respond(check_and_increment(user_id))
        else:
            # Forward to the responsible node
            self._forward_request(user_id, responsible_node)

    def _forward_request(self, user_id, node):
        """
        Forward the request to the node that owns the user_id in the hash ring
        and relay its answer to our client.

        'node' is something like "localhost:5001".
        """
        try:
            # `user_id` has already been URL-decoded by parse_qs; passing it
            # through `params` re-encodes it so characters such as "&", "#",
            # "+" or "%" reach the owning node as the same key.
            resp = requests.get(
                f"http://{node}/check",
                params={"user_id": user_id},
                timeout=self.FORWARD_TIMEOUT,
            )
        except requests.RequestException as e:
            # If forwarding fails (node is down?), we can't do much else
            self.send_error(503, f"Node {node} not reachable: {str(e)}")
            return

        # Relay outside the try block: a failure while writing to our own
        # client must not trigger a second (503) response on the same socket.
        body = resp.content
        self.send_response(resp.status_code)
        for name, value in resp.headers.items():
            if name.lower() not in self._UNFORWARDED_HEADERS:
                self.send_header(name, value)
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def _respond(self, allowed):
        """Write the plain-text verdict: 200 if allowed, 429 otherwise."""
        if allowed:
            status, body = 200, b"ALLOWED"
        else:
            status, body = 429, b"RATE LIMIT EXCEEDED"
        self.send_response(status)
        self.send_header("Content-type", "text/plain")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)
class _ReusableTCPServer(socketserver.TCPServer):
    """TCPServer that may rebind its port right after a restart."""

    # Without this, restarting a node shortly after stopping it can fail with
    # "Address already in use" while old connections linger in TIME_WAIT.
    allow_reuse_address = True


def run_server(host, port):
    """
    Bind to (host, port) and serve rate-limit requests until interrupted.

    The listening socket is closed when the `with` block exits.

    NOTE(review): the server handles one request at a time. Two nodes that
    forward to each other at the same moment will each wait for the other
    until the forwarding timeout expires; switching to a threading server
    would also require making the counter updates thread-safe.
    """
    with _ReusableTCPServer((host, port), RateLimitHandler) as httpd:
        print(f"Serving on {host}:{port} (local_node={local_node})")
        httpd.serve_forever()
###############################################################################
# MAIN ENTRY POINT
###############################################################################
if __name__ == "__main__":
    # Usage:
    #   python consistent_hash_rate_limiter.py <local_port> <node1> <node2> ... <nodeN>
    #
    # Example:
    #   # Start Node A on port 5000
    #   python consistent_hash_rate_limiter.py 5000 localhost:5000 localhost:5001
    #   # Start Node B on port 5001
    #   python consistent_hash_rate_limiter.py 5001 localhost:5000 localhost:5001
    #
    # This means each node's cluster membership is [localhost:5000, localhost:5001].
    # The first argument is the local node's port; all subsequent arguments
    # are the full list of nodes (including itself).
    usage_text = "Usage: python consistent_hash_rate_limiter.py <local_port> <node1> <node2> ..."

    if len(sys.argv) < 3:
        print(usage_text)
        sys.exit(1)

    local_port = sys.argv[1]
    all_nodes = sys.argv[2:]  # e.g. ["localhost:5000", "localhost:5001"]

    # Validate the port before building anything, so a typo produces a usage
    # message instead of a traceback at bind time.
    try:
        port_number = int(local_port)
    except ValueError:
        print(f"Invalid local port: {local_port!r}", file=sys.stderr)
        print(usage_text)
        sys.exit(1)

    # Construct the local_node string from local_port
    local_node = f"localhost:{local_port}"
    if local_node not in all_nodes:
        # The handler compares ring owners against this exact string; if it
        # never matches, this node forwards every request instead of serving.
        print(
            f"Warning: {local_node} is not in the node list {all_nodes}; "
            "this node will forward every request.",
            file=sys.stderr,
        )

    # Build the ring
    ring = ConsistentHashRing(all_nodes, num_replicas=100)

    # Run the server
    run_server("0.0.0.0", port_number)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment