Last active
October 23, 2021 21:47
-
-
Save amencke/0cffa2c2df55825af0b94e13dd316738 to your computer and use it in GitHub Desktop.
threadsafe connection pool implementation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
import concurrent.futures | |
import threading | |
import uuid | |
from timeit import default_timer as timer | |
class NoConnectionAvailableError(Exception):
    """Raised when no pooled connection could be obtained.

    In this file it is raised by ``ConnectionPool.get`` when the pool's
    semaphore cannot be acquired within the configured timeout.
    """

    def __init__(self, msg):
        super().__init__(msg)
class ClientConnectionError(Exception):
    """Raised when an operation on a connection fails.

    In this file it is raised by ``Connection.work`` when the simulated
    network operation raises an exception.
    """

    def __init__(self, msg):
        super().__init__(msg)
# Serializes writes so that output lines from different threads do not interleave.
print_lock = threading.Lock()
# Running count of Connection.work() invocations; printed at the end of the run.
completed_work = 0


def threadsafe_print(*args, **kwargs):
    """Print while holding ``print_lock``.

    Accepts the same positional and keyword arguments as the built-in
    ``print`` and forwards them unchanged.
    """
    with print_lock:
        print(*args, **kwargs)
class Connection:
    """A stand-in for a pooled network connection, identified by ``id_``."""

    # Guards the module-level ``completed_work`` counter: ``+=`` on a global
    # is a read-modify-write, and work() is driven from several threads.
    _counter_lock = threading.Lock()

    def __init__(self, id_):
        self._id = id_

    async def work(self):
        """Simulate one I/O-bound operation on this connection.

        Increments the global ``completed_work`` counter as soon as the
        operation starts, logs a line, then sleeps for one second.

        Raises:
            ClientConnectionError: if the simulated operation raises any
                ``Exception``; the original error is kept as ``__cause__``.
        """
        global completed_work
        with Connection._counter_lock:
            completed_work += 1
        threadsafe_print(f"Connection {self._id} doing work...")
        try:
            await asyncio.sleep(1)  # I/O bound network operation
        except Exception as err:
            raise ClientConnectionError("Something went wrong") from err
class ConnectionPool:
    """A bounded pool of reusable ``Connection`` objects shared between threads.

    At most ``max_connections`` connections are checked out at any time.
    Connections are created lazily on demand and reused after release, so
    there are never more than min(max_connections, number of concurrent
    callers) connections created.

    Args:
        max_connections: upper bound on concurrently checked-out connections.
        timeout: seconds ``get`` waits for a free slot before giving up.
    """

    def __init__(self, max_connections, timeout=1):
        self._pool_sema = threading.BoundedSemaphore(max_connections)
        self._connections = []  # idle connections, ready for reuse
        self._connection_tracker = set()  # every connection this pool created
        self._timeout = timeout
        # Protects the two containers above. The semaphore only limits how
        # many callers hold a connection; it does not serialize access.
        self._lock = threading.Lock()

    def get(self):
        """Check out a connection, waiting up to ``timeout`` seconds for a slot.

        Returns:
            An idle connection if one exists, otherwise a newly created one.

        Raises:
            NoConnectionAvailableError: if no slot frees up within the timeout.
        """
        # block until a connection becomes available (or the timeout expires)
        if not self._pool_sema.acquire(blocking=True, timeout=self._timeout):
            raise NoConnectionAvailableError("No connection available!")
        try:
            # Check-and-pop must be atomic: without the lock two threads can
            # both see one idle connection and the loser pops an empty list.
            with self._lock:
                if self._connections:
                    return self._connections.pop()
                conn = Connection(uuid.uuid4())
                self._connection_tracker.add(conn)
                return conn
        except BaseException:
            # Do not leak the slot if handing out a connection failed.
            self._pool_sema.release()
            raise

    def release(self, conn):
        """Return ``conn`` to the pool and free its slot.

        Raises:
            ValueError: if ``conn`` was not created by this pool.
        """
        with self._lock:
            # Explicit check instead of ``assert`` so it survives ``python -O``.
            if conn not in self._connection_tracker:
                raise ValueError("connection does not belong to this pool")
            self._connections.append(conn)
        self._pool_sema.release()
async def work_and_release(pool, num_tasks=1000):
    """Check out one connection, run concurrent jobs on it, then return it.

    Args:
        pool: object providing ``get()`` and ``release(conn)``.
        num_tasks: how many ``conn.work()`` coroutines to run concurrently
            on the checked-out connection (previously hard-coded to 1000).

    A ``NoConnectionAvailableError`` from ``pool.get()`` is logged and the
    function returns without doing any work. A ``ClientConnectionError``
    from an individual job is logged and does not affect the other jobs.
    """
    try:
        conn = pool.get()
    except NoConnectionAvailableError:
        threadsafe_print("Handling connection pool error...")
        return

    async def _do_work():
        try:
            await conn.work()
        except ClientConnectionError:
            threadsafe_print("Handling client connection error...")

    try:
        await asyncio.gather(*(_do_work() for _ in range(num_tasks)))
    finally:
        # Always hand the connection back, even if gather() raises, so the
        # pool does not permanently lose a slot.
        pool.release(conn)
def main():
    """Run the demo: 40 jobs on 8 worker threads sharing a 10-connection pool.

    Each job runs ``work_and_release`` in its own event loop via
    ``asyncio.run``. Prints the elapsed wall-clock time and the global
    ``completed_work`` counter once every job has finished.
    """
    pool = ConnectionPool(max_connections=10, timeout=3)
    start = timer()
    with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor:
        futures = [
            executor.submit(lambda: asyncio.run(work_and_release(pool)))
            for _ in range(40)
        ]
    # Leaving the ``with`` block waits for every submitted job to finish.
    end = timer()
    # Surface failures that would otherwise vanish with the discarded futures.
    for future in futures:
        error = future.exception()
        if error is not None:
            threadsafe_print(f"Worker failed: {error!r}")
    print(f"time taken: {end - start}")  # ~ requested connections / min(max connections, max workers)
    print(f"completed work: {completed_work}")


if __name__ == '__main__':
    main()
# 40 iterations, 1000 coroutines per iteration, max_connections=10, max_workers=8, 1 second sleep
# ...
# Connection 938aa726-57cf-4737-86a5-a30fb02f5668 doing work...
# Connection fad94190-84a8-4497-b53a-308f24c1d1f0 doing work...
# Connection ff21e605-da6c-4748-90d1-89c9ef60d9db doing work...
# Connection 1d224283-905c-4abf-ac96-bc9301128450 doing work...
# Connection 41afe594-055a-4293-b240-14959467c5bc doing work...
# time taken: 6.165483556
# completed work: 40000
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment