Last active
October 22, 2023 09:59
-
-
Save lemon24/64704ced769c5723a75ad64b5d023883 to your computer and use it in GitHub Desktop.
Distributed key-value store prototype, with no kind of consistency.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Distributed key-value store prototype, with no kind of consistency. | |
--- | |
A demo (from before we had bootstrapping): https://asciinema.org/a/616231
On one machine: | |
>>> d = dkv.DKV() | |
>>> d.start() | |
>>> | |
>>> d.get('one', timeout=1) # wait 1s for a response from others | |
>>> d.set('one', b'111') | |
>>> d.get('one') | |
b'111' | |
On a second machine: | |
>>> d = dkv.DKV() | |
>>> d.start() | |
>>> | |
>>> d.get('one') # retrieve value from the first machine | |
b'111' | |
>>> d.set('two', b'222') | |
Back on the first machine: | |
>>> d.get('two', timeout=0) # value already received from the second machine | |
b'222' | |
--- | |
The current network architecture is as follows: | |
* each node has a SUB socket to receive messages from all other nodes | |
* each node has a PUB socket to send messages to all other nodes | |
.--------------------------. | |
| v | |
.-----------. .-----------. | |
| PUB | SUB |<-------| PUB | SUB | | |
'-----------' '-----------' | |
| ^ .----------' ^ | |
| '---------. | | |
| v | | | |
| .-----------. | | |
'----->| SUB | PUB |-------' | |
'-----------' | |
This is wildly inefficient; for example, when A asks for a key, | |
all its peers respond to all their peers, not only to A. | |
A has a subscription filter for messages intended for itself, | |
so this doesn't need to be handled in code, | |
but the network traffic still happens underneath. | |
Note the ZeroMQ book already has a [shared key-value store] example, | |
but I wanted to see if I can cobble together something on my own. | |
--- | |
Discovery works via IPv4 local network UDP broadcast, | |
and works even if the nodes move networks / change IPs. | |
This is needed because while ZeroMQ (known) peers can come and go, | |
there's no way of discovering them. | |
This is based on the same idea as [zbeacon],
but cobbled together independently from StackOverflow examples | |
(zbeacon comes from the C binding, and does not exist in the Python one). | |
Also see the section on [discovery] in the ZeroMQ book. | |
--- | |
Bootstrapping starts whenever a node gets network connectivity back. | |
The node requests a list of keys from all other peers, | |
then requests keys one by one with a small delay. | |
--- | |
[zbeacon]: http://api.zeromq.org/czmq1-4:zbeacon | |
[discovery]: https://zguide.zeromq.org/docs/chapter8/#Discovery | |
[shared key-value store]: https://zguide.zeromq.org/docs/chapter5/#Reliable-Pub-Sub-Clone-Pattern | |
""" | |
import threading | |
import random | |
import socket | |
import queue | |
import errno | |
import time | |
import zmq | |
import sys | |
from dataclasses import dataclass, field | |
from functools import wraps, partial | |
# errno values that mean "the local network is unavailable";
# used to tell transient connectivity loss apart from real socket errors
ERRNO_NET = {errno.ENETUNREACH, errno.EADDRNOTAVAIL, errno.ENETDOWN}
class Beacon:
    """IPv4 local network broadcast beacon.

    Send payload to local network nodes every second.
    Call event(...) with messages from other nodes.

    Emitted events:

    * ('net_up', address) -- we (re)gained network connectivity
    * ('net_down', exception) -- we lost network connectivity
    * ('ping', address, payload) -- heartbeat from another node
    """

    # seconds between broadcasts (also the receive timeout)
    interval = 1
    # every datagram starts with this, so unrelated traffic is ignored
    prefix = b'beacon '
    port = 5005

    def __init__(self, payload, event):
        self.payload = payload
        self.event = event
        # our current local address; None means "network is down"
        self.address = None
        self.sender = threading.Thread(target=self._send, daemon=True)
        self.receiver = threading.Thread(target=self._receive, daemon=True)
        self.done = False

    def start(self):
        self.sender.start()
        self.receiver.start()

    def shutdown(self):
        self.done = True
        self.sender.join()
        self.receiver.join()

    # https://stackoverflow.com/a/64067297 + comment by Tamir Adler

    def _send(self):
        # close the socket when the loop exits instead of leaking it
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP) as s:
            s.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
            message = self.prefix + self.payload
            while not self.done:
                try:
                    s.sendto(message, ('255.255.255.255', self.port))
                except OSError as e:
                    if e.errno not in ERRNO_NET:
                        raise
                    # network went away; tell listeners once
                    if self.address:
                        self.event(('net_down', e))
                    self.address = None
                time.sleep(self.interval)

    def _receive(self):
        # close the socket when the loop exits instead of leaking it
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
            s.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
            # only needed when running more than one per host
            # https://gist.github.com/Crtrpt/616eaae1ec00810c1d04474f188bcebd
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
            s.bind(('0.0.0.0', self.port))
            # time out so the `done` flag is checked at least once per interval
            s.settimeout(self.interval)
            while not self.done:
                if not self.address:
                    try:
                        self.address = address = get_local_address()
                        self.event(('net_up', address))
                    except OSError as e:
                        if e.errno not in ERRNO_NET:
                            raise
                try:
                    data, (address, _) = s.recvfrom(1024)
                except TimeoutError:
                    continue
                # don't report pings while we believe the network is down
                if not self.address:
                    continue
                if not data.startswith(self.prefix):
                    continue
                payload = data.removeprefix(self.prefix)
                # ignore our own broadcasts
                if address == self.address and payload == self.payload:
                    continue
                self.event(('ping', address, payload))
def get_local_address():
    """Return this host's primary local IPv4 address.

    "Connects" a UDP socket to a public address (UDP connect does not
    actually send packets) and reads back the source address the OS picked.

    https://stackoverflow.com/a/166589
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        probe.connect(('8.8.8.8', 80))
        address, _ = probe.getsockname()
        return address
    finally:
        probe.close()
class Tracker:
    """IPv4 local network presence tracker.

    Broadcast own presence using a Beacon.
    Call event(('peer_up', address, payload)) on the first heartbeat from a peer.
    Call event(('peer_down', address, payload)) after timeout seconds without one.
    Other Beacon events ('net_up', 'net_down') are passed through unchanged.
    """

    # seconds without a heartbeat before a peer is considered gone
    timeout = 4

    def __init__(self, payload, event):
        self.event = event
        # (address, payload) -> monotonic timestamp of the last heartbeat
        self.peers = {}
        self.queue = queue.Queue()
        self.beacon = Beacon(payload, self.queue.put)
        self.worker = threading.Thread(target=self._worker, daemon=True)
        self.done = False

    def start(self):
        self.beacon.start()
        self.worker.start()

    def shutdown(self):
        self.done = True
        self.beacon.shutdown()
        self.worker.join()

    def _worker(self):
        while not self.done:
            now = time.monotonic()
            # expire peers we haven't heard from in a while
            for peer, last_seen in list(self.peers.items()):
                if now - last_seen > self.timeout:
                    self.event(('peer_down', *peer))
                    del self.peers[peer]
            try:
                event = self.queue.get(timeout=self.beacon.interval)
            except queue.Empty:
                continue
            if event[0] != 'ping':
                # pass net_up / net_down through unchanged
                self.event(event)
                continue
            peer = event[1:]
            if peer not in self.peers:
                self.event(('peer_up', *peer))
            # take a fresh timestamp: queue.get() may have blocked for up to
            # beacon.interval seconds after `now` was read, and a stale
            # timestamp would make peers expire early
            self.peers[peer] = time.monotonic()
        # on shutdown, report all remaining peers as gone
        for peer in self.peers:
            self.event(('peer_down', *peer))
class DKV:
    """Distributed key-value store node, with no kind of consistency.

    Values are pushed to all peers on set(), and pulled from peers on a
    get() miss via a question/answer exchange over ZeroMQ PUB/SUB.
    Keys are str, values are bytes.
    """

    def __init__(self, log=None):
        self.log = log or (lambda *_: None)
        # key (str) -> value (bytes); the local copy of the store
        self.data = {}
        self.worker = threading.Thread(target=self._worker, daemon=True)
        self.done = False
        # https://pyzmq.readthedocs.io/en/latest/howto/morethanbindings.html#thread-safety
        # after __init__, we either lock around socket method calls,
        # or only use a socket from a single thread
        self.ctx = ctx = zmq.Context()
        # socket for sending messages to other peers;
        # used from anywhere, so it needs a lock
        self.pub = ctx.socket(zmq.PUB)
        self.port = self.pub.bind_to_random_port('tcp://*')
        self.pub_lock = threading.Lock()
        # broadcast pub's port to other nodes;
        # send presence changes from others to an internal socket
        sender = ctx.socket(zmq.PAIR)
        sender.bind("inproc://events")
        self.tracker = Tracker(str(self.port).encode(), sender.send_pyobj)
        # ...so we can receive them in the worker thread
        self.tracker_events = ctx.socket(zmq.PAIR)
        self.tracker_events.connect("inproc://events")
        # socket for receiving messages from other peers;
        # only used from the worker thread
        self.sub = ctx.socket(zmq.SUB)
        # we care about value updates and questions from anyone
        self.sub.subscribe(b'set')
        self.sub.subscribe(b'question')
        # but only about answers intended for us
        self.id = random.randbytes(8)
        self.sub.subscribe(b'answer ' + self.id)
        # one threading.Event for each key we've asked about,
        # so get() can wait for an answer
        self.pending_questions = {}
        self.bootstrap = None
        # note the trailing space: ZeroMQ subscriptions are prefix filters,
        # so a bare b'list' would also match every peer's
        # b'list_answer <other-id>' frames and deliver them to us
        self.sub.subscribe(b'list ')
        self.sub.subscribe(b'list_answer ' + self.id)

    def start(self):
        self.tracker.start()
        self.worker.start()

    def shutdown(self):
        self.done = True
        self.tracker.shutdown()
        self.worker.join()
        self.ctx.destroy()

    def _worker(self):
        """Receive loop: dispatch tracker events and peer messages to _handle_* methods."""
        self.log(f"publishing on tcp://*:{self.port} with id {self.id.hex()}")
        poller = zmq.Poller()
        poller.register(self.sub, zmq.POLLIN)
        poller.register(self.tracker_events, zmq.POLLIN)
        while not self.done:
            for sock, _ in poller.poll(100):
                if sock is self.tracker_events:
                    name, *args = sock.recv_pyobj()
                    self.log("tracker:", name, *args)
                elif sock is self.sub:
                    name, *args = sock.recv_multipart()
                    self.log("sub received:", name, *args)
                    # topic frames look like b'name' or b'name <node id>'
                    name, _, id = name.partition(b' ')
                    name = name.decode()
                    if id:
                        args = (id, *args)
                else:
                    # only the two sockets above are registered
                    assert False
                try:
                    meth = getattr(self, f'_handle_{name}')
                except AttributeError:
                    self.log("UNHANDLED!", name)
                else:
                    meth(*args)
            # advance bootstrapping (if any) from the worker thread
            if self.bootstrap:
                if self.bootstrap.done:
                    self.bootstrap = None
                else:
                    self.bootstrap.step()

    def _handle_peer_up(self, ip, port_bytes):
        address = f"tcp://{ip}:{port_bytes.decode()}"
        self.sub.connect(address)

    def _handle_peer_down(self, ip, port_bytes):
        address = f"tcp://{ip}:{port_bytes.decode()}"
        self.sub.disconnect(address)

    def _handle_set(self, key, value):
        # a peer set a value; store it locally
        self.data[key.decode()] = value

    def _request_question(self, key):
        # ask everyone for a key; answers come back tagged with our id
        self._pub_send_multipart((b'question ' + self.id, key.encode()))

    def _handle_question(self, id, key):
        value = self.data.get(key.decode())
        # `is None` (not truthiness) so empty values like b'' still get answered
        if value is None:
            return
        self._pub_send_multipart((b'answer ' + id, key, value))

    def _handle_answer(self, id, key, value):
        # subscription filtering should only let our own answers through,
        # but don't kill the worker over a stray frame
        if id != self.id:
            return
        self.data[key.decode()] = value
        # notify any waiting get() calls that we have an answer
        have_answer = self.pending_questions.pop(key.decode(), None)
        if have_answer:
            have_answer.set()

    def _handle_net_up(self, ip):
        # connectivity is back; re-sync keys from peers
        self.bootstrap = BootstrapState(
            self._request_list,
            self._request_question,
            # give peers time to connect
            self.tracker.beacon.interval * 2,
            log=self.log,
        )

    def _request_list(self):
        self._pub_send_multipart((b'list ' + self.id,))

    def _handle_list(self, id):
        # a peer asked for our keys; send them all back, tagged with its id
        message = list(map(str.encode, self.data))
        message.insert(0, b'list_answer ' + id)
        self._pub_send_multipart(message)

    def _handle_list_answer(self, id, *keys):
        # see _handle_answer for why this is a check, not an assert
        if id != self.id:
            return
        if not self.bootstrap:
            return
        self.bootstrap.handle_list(map(bytes.decode, keys))

    def _pub_send_multipart(self, message):
        # pub is used from multiple threads, so serialize access
        with self.pub_lock:
            self.pub.send_multipart(message)
        self.log('pub sent:', *message)

    def set(self, key, value):
        """Set key (str) to value (bytes) locally and on all peers."""
        self.data[key] = value
        # tell everyone about the new value
        self._pub_send_multipart((b'set', key.encode(), value))

    def get(self, key, *, timeout=.1):
        """Return the value for key, or None.

        On a local miss, ask peers and wait up to `timeout` seconds
        for an answer.
        """
        # if we have it, return it
        value = self.data.get(key)
        if value is not None:
            return value
        # if we don't have it, ask others,
        # and wait `timeout` seconds for an answer
        # setdefault() is likely atomic
        # https://bugs.python.org/issue13521
        # https://mail.python.org/pipermail/python-list/2018-July/885957.html
        have_answer = self.pending_questions.setdefault(key, threading.Event())
        self._request_question(key)
        if have_answer.wait(timeout):
            self.log(f"get({key!r}): got answer")
        else:
            self.log(f"get({key!r}): timed out")
        # if someone answered, the value is already in self.data
        return self.data.get(key)
@dataclass
class BootstrapState:
    """State machine driving the key re-sync sequence.

    After list_delay seconds, ask peers for their key lists
    (request_list); then request the collected keys one at a time,
    key_delay seconds apart (request_key); once no keys have remained
    for another list_delay seconds, flag the whole process as done.
    """

    request_list: callable
    request_key: callable
    list_delay: float
    key_delay: float = 0.005
    list_after: float = 0
    key_after: float = 0
    done_after: float = 0
    done: bool = False
    remaining_keys: set = field(default_factory=set)
    done_keys: set = field(default_factory=set)
    log: callable = lambda *_: None

    # clock used for all deadlines (monotonic: immune to wall-clock jumps);
    # plain class attribute, not a dataclass field
    time = time.monotonic

    def __post_init__(self):
        self.list_after = self.time() + self.list_delay
        self.log(f"bootstrap: waiting {self.list_delay:.1f}s before starting")

    def step(self):
        """Advance the state machine; meant to be called periodically."""
        t = self.time()

        # phase 1: ask everyone for their key lists
        if self.list_after and self.list_after <= t:
            self.request_list()
            self.list_after = 0
            self.key_after = t + self.key_delay
            self.log(f"bootstrap: started")
            return

        # phase 2: request one collected key per step
        if self.key_after and self.key_after <= t:
            if self.remaining_keys:
                key = self.remaining_keys.pop()
                self.request_key(key)
                self.done_keys.add(key)
            if self.remaining_keys:
                self.key_after = t + self.key_delay
            else:
                self.key_after = 0
                # should use a separate delay for this, but eh
                self.done_after = t + self.list_delay
                self.log(f"bootstrap: no keys remaining, waiting another {self.list_delay:.1f}s")
            return

        # phase 3: grace period elapsed with no new keys arriving
        if self.done_after and self.done_after <= t:
            self.done = True
            self.log(f"bootstrap: done")

    def handle_list(self, keys):
        """Merge a peer's key list, skipping keys we already requested."""
        self.remaining_keys.update(k for k in keys if k not in self.done_keys)
        self.key_after = self.time() + self.key_delay
if __name__ == '__main__':
    # demo / smoke test: start a node, publish a node-specific key,
    # then race with any peers to set (and read back) a shared key
    # (note: `random` is already imported at module level)
    dkv = DKV(print)
    dkv.start()
    dkv.set(f'one-{dkv.port}', b'111')
    # stagger the nodes so their sets interleave differently on each run
    time.sleep(1 + random.random())
    if random.random() < 0.5:
        dkv.set('hello', f"from {dkv.port}".encode())
    time.sleep(random.random())
    print(time.time())
    print(dkv.get('hello'))
    print(time.time())
    # stay up for a while so peers can keep querying this node
    time.sleep(100)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment