Created
August 1, 2023 23:39
-
-
Save dejanceltra/ed4778d691448c23acaf7e42cdfc4446 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Simple KV store accessible via TCP. | |
Supports two operations:
- get(key): fetch the value stored under key, returned as bytes (None if the key is not found)
- set(key, value): store value under key
The server is single-threaded and asyncio-based, which simplifies the implementation;
a multi-threaded server would probably be slower, since the `_cache` dictionary would
have to be shared between threads.
Both sync and async clients are provided. | |
""" | |
import asyncio | |
from enum import Enum | |
import socket | |
from typing import Awaitable, Dict, Optional | |
import aiorwlock | |
class Commands(Enum):
    """One-byte opcodes exchanged over the wire.

    GET, SET and QUIT are sent by clients as request opcodes; NOT_FOUND and
    OK are sent back by the server as response status bytes.
    """

    # Client -> server requests.
    GET = b"g"
    SET = b"s"
    QUIT = b"q"
    # Server -> client response statuses.
    NOT_FOUND = b"n"
    OK = b"o"
class SyncClient:
    """Blocking TCP client for the KV store.

    Wire format: each request starts with a one-byte ``Commands`` opcode;
    keys and values follow as a 4-byte big-endian length and then the
    UTF-8 encoded bytes.
    """

    def __init__(self):
        self._socket: Optional[socket.socket] = None

    def connect(self, ip: str, port: int) -> None:
        """Open a TCP connection to the server at ``(ip, port)``."""
        self._socket = socket.socket()
        self._socket.connect((ip, port))

    def get(self, key: str) -> Optional[bytes]:
        """Fetch the value stored under ``key``.

        Returns:
            The raw value bytes, or ``None`` if the server answers NOT_FOUND.

        Raises:
            TypeError: if ``key`` is not a ``str``.
            Exception: if not connected, the server sends an unexpected
                status byte, or the connection closes mid-response.
        """
        self._ensure_connected()
        if not isinstance(key, str):
            raise TypeError(f'key must be of type str, received: {str(type(key))}')
        # Length must describe the *encoded* bytes, not the character count,
        # otherwise non-ASCII keys desynchronise the protocol.
        self._socket.sendall(Commands.GET.value + self._frame(key.encode()))
        status = self._readexactly(1)
        if status == Commands.NOT_FOUND.value:
            return None
        if status != Commands.OK.value:
            raise Exception(f'unknown response: {status}')
        length = int.from_bytes(self._readexactly(4), 'big')
        return self._readexactly(length)

    def set(self, key: str, value: str) -> None:
        """Store ``value`` under ``key``.

        Raises:
            TypeError: if ``key`` or ``value`` is not a ``str``.
            Exception: if not connected, the server does not answer OK, or
                the connection closes before the reply arrives.
        """
        self._ensure_connected()
        if not isinstance(key, str):
            raise TypeError(f'key must be of type str, received: {str(type(key))}')
        if not isinstance(value, str):
            raise TypeError(f'value must be of type str, received: {str(type(value))}')
        # A single sendall per request avoids several tiny TCP writes.
        request = (
            Commands.SET.value
            + self._frame(key.encode())
            + self._frame(value.encode())
        )
        self._socket.sendall(request)
        status = self._readexactly(1)
        if status != Commands.OK.value:
            raise Exception(f'unknown response: {status}')

    def close(self) -> None:
        """Send QUIT and close the socket; does nothing if not connected."""
        if self._socket is None:
            return
        try:
            self._socket.sendall(Commands.QUIT.value)
        finally:
            # Always release the socket, even if the QUIT send fails.
            self._socket.close()
            self._socket = None

    def _ensure_connected(self) -> None:
        """Raise if ``connect`` has not been called (or ``close`` was)."""
        if self._socket is None:
            raise Exception('client not connected')

    @staticmethod
    def _frame(payload: bytes) -> bytes:
        """Prefix ``payload`` with its length as a 4-byte big-endian integer."""
        return len(payload).to_bytes(4, 'big') + payload

    def _readexactly(self, length: int) -> bytes:
        """Read exactly ``length`` bytes from the socket.

        Raises:
            Exception: if the peer closes the connection before ``length``
                bytes have arrived.
        """
        buffer = bytearray()
        while len(buffer) < length:
            part = self._socket.recv(length - len(buffer))
            if not part:
                raise Exception(f'stream reading failed; expected {length} bytes, got {len(buffer)} bytes')
            buffer += part
        return bytes(buffer)
class AsyncClient:
    """asyncio TCP client for the KV store.

    Wire format: each request starts with a one-byte ``Commands`` opcode;
    keys and values follow as a 4-byte big-endian length and then the
    UTF-8 encoded bytes.
    """

    def __init__(self):
        self._reader: Optional[asyncio.StreamReader] = None
        self._writer: Optional[asyncio.StreamWriter] = None

    async def connect(self, ip: str, port: int) -> None:
        """Open a TCP connection to the server at ``(ip, port)``."""
        self._reader, self._writer = await asyncio.open_connection(ip, port)

    async def get(self, key: str) -> Optional[bytes]:
        """Fetch the value stored under ``key``.

        Returns:
            The raw value bytes, or ``None`` if the server answers NOT_FOUND.

        Raises:
            TypeError: if ``key`` is not a ``str``.
            asyncio.IncompleteReadError: if the connection closes mid-response.
            Exception: if not connected or the server sends an unexpected
                status byte.
        """
        self._ensure_connected()
        if not isinstance(key, str):
            raise TypeError(f'key must be of type str, received: {str(type(key))}')
        # Length must describe the encoded bytes, not the character count.
        self._writer.write(Commands.GET.value + self._frame(key.encode()))
        await self._writer.drain()
        # readexactly, not read: read(n) may legally return fewer than n bytes.
        status = await self._reader.readexactly(1)
        if status == Commands.NOT_FOUND.value:
            return None
        if status != Commands.OK.value:
            raise Exception(f'unknown response: {status}')
        length = int.from_bytes(await self._reader.readexactly(4), 'big')
        return await self._reader.readexactly(length)

    async def set(self, key: str, value: str) -> None:
        """Store ``value`` under ``key``.

        Raises:
            TypeError: if ``key`` or ``value`` is not a ``str``.
            asyncio.IncompleteReadError: if the connection closes before the reply.
            Exception: if not connected or the server does not answer OK.
        """
        self._ensure_connected()
        if not isinstance(key, str):
            raise TypeError(f'key must be of type str, received: {str(type(key))}')
        if not isinstance(value, str):
            raise TypeError(f'value must be of type str, received: {str(type(value))}')
        request = (
            Commands.SET.value
            + self._frame(key.encode())
            + self._frame(value.encode())
        )
        self._writer.write(request)
        await self._writer.drain()
        status = await self._reader.readexactly(1)
        if status != Commands.OK.value:
            raise Exception(f'unknown response: {status}')

    async def close(self) -> None:
        """Send QUIT and close the connection; does nothing if not connected."""
        if self._writer is None:
            return
        writer = self._writer
        self._reader = self._writer = None
        try:
            writer.write(Commands.QUIT.value)
            await writer.drain()
        finally:
            # Always close the transport, even if sending QUIT fails.
            writer.close()
        await writer.wait_closed()

    def _ensure_connected(self) -> None:
        """Raise if ``connect`` has not been called (or ``close`` was)."""
        if self._reader is None or self._writer is None:
            raise Exception('client not connected')

    @staticmethod
    def _frame(payload: bytes) -> bytes:
        """Prefix ``payload`` with its length as a 4-byte big-endian integer."""
        return len(payload).to_bytes(4, 'big') + payload
class AsyncServer:
    """asyncio TCP server that keeps all key/value pairs in an in-memory dict."""

    def __init__(self) -> None:
        self._cache: Dict[bytes, bytes] = {}

    @staticmethod
    async def _read_frame(reader: asyncio.StreamReader) -> bytes:
        """Read one length-prefixed frame: 4-byte big-endian length, then payload.

        Raises:
            asyncio.IncompleteReadError: if the stream ends mid-frame.
        """
        length = int.from_bytes(await reader.readexactly(4), 'big')
        return await reader.readexactly(length)

    async def _handle_client(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
        """Serve requests from one connection until QUIT or disconnect.

        No lock guards ``_cache``: each access is a single synchronous dict
        operation with no ``await`` in between, so on a single event loop no
        other connection handler can interleave with it.

        Raises:
            Exception: on an unknown command byte (after closing the writer).
        """
        try:
            while True:
                command = await reader.readexactly(1)
                if command == Commands.QUIT.value:
                    return
                if command == Commands.GET.value:
                    key = await self._read_frame(reader)
                    value = self._cache.get(key)
                    # `is None`, not falsiness: an empty stored value is still found.
                    if value is None:
                        writer.write(Commands.NOT_FOUND.value)
                    else:
                        writer.write(Commands.OK.value + len(value).to_bytes(4, 'big') + value)
                elif command == Commands.SET.value:
                    key = await self._read_frame(reader)
                    value = await self._read_frame(reader)
                    self._cache[key] = value
                    writer.write(Commands.OK.value)
                else:
                    raise Exception(f'unknown command: {command}')
                await writer.drain()
        except asyncio.IncompleteReadError:
            # The peer disconnected without QUIT (or mid-request); nothing to answer.
            return
        finally:
            writer.close()

    async def run(self, ip: str, port: int) -> None:
        """Listen on ``(ip, port)`` and serve clients until cancelled."""
        server = await asyncio.start_server(self._handle_client, ip, port)
        async with server:
            await server.serve_forever()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
from python.util.remote_kv_store import AsyncServer | |
def _main() -> None:
    """Start the KV store server on 127.0.0.1:15555 and serve until stopped."""
    asyncio.run(AsyncServer().run('127.0.0.1', 15555))


if __name__ == '__main__':
    _main()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
from enum import Enum | |
from python.util.remote_kv_store import AsyncClient, SyncClient | |
async def tcp_echo_client(message):
    """Demo: get/set/get the key 'omg' via AsyncClient on 127.0.0.1:15555.

    Args:
        message: unused; kept so existing calls keep working.
    """
    client = AsyncClient()
    await client.connect('127.0.0.1', 15555)
    try:
        print(await client.get('omg'))
        print(await client.set('omg', 'second'))
        print(await client.get('omg'))
    finally:
        # Send QUIT and release the connection even if a request fails.
        await client.close()
# Guarded so that importing this module does not open network connections.
if __name__ == '__main__':
    asyncio.run(tcp_echo_client('Hello World!'))

    client = SyncClient()
    client.connect('127.0.0.1', 15555)
    try:
        print(client.get('omg'))
        print(client.set('omg', 'second'))
        print(client.get('omg'))
    finally:
        # Send QUIT and close the socket even if a request fails.
        client.close()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment