Last active
February 10, 2025 16:38
-
-
Save Hammer2900/45698014028ce1f4373141a1aa3692b5 to your computer and use it in GitHub Desktop.
A simple asynchronous executor built on asyncio. It is similar to concurrent.futures.ThreadPoolExecutor, but runs work in asyncio tasks instead of threads. It lets you run asynchronous functions in a pool while limiting the number of concurrent tasks, and it supports an optional async initializer and graceful shutdown. I wrote it for myself, but maybe it will be useful to someone.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
import contextvars | |
from typing import Callable, Awaitable, Iterable, Any, AsyncIterator, Protocol, runtime_checkable | |
@runtime_checkable
class Initializer(Protocol):
    """Structural type for the optional executor initializer.

    Any async callable that can be awaited with the executor's ``initargs``
    qualifies. The ``(*args: Any, **kwargs: Any)`` signature is the typing
    spec's gradual ("accepts anything") form, so initializers with a concrete
    signature -- including ones that take no arguments -- type-check.
    """

    async def __call__(self, *args: Any, **kwargs: Any) -> None: ...
@runtime_checkable
class TaskFunc(Protocol):
    """Structural type for jobs accepted by the executor: any async callable."""

    async def __call__(self, *args: Any, **kwargs: Any) -> Any:
        ...
class ExecutorError(Exception):
    """Root of the exception hierarchy raised by the Executor."""
class ShutdownError(ExecutorError):
    """Signals an attempt to submit work to an executor that was already shut down."""
class InitializationError(ExecutorError):
    """Signals that the executor's initializer raised an exception."""
class WorkerError(ExecutorError):
    """Signals that a worker task hit an exception it did not handle."""
class Executor:
    """
    An asynchronous executor that manages a pool of worker tasks.

    This executor is similar in concept to `concurrent.futures.ThreadPoolExecutor`
    but uses asyncio tasks instead of threads. It provides a simple interface
    for submitting asynchronous tasks and retrieving their results. At most
    `max_workers` jobs run at the same time; further jobs wait in a FIFO queue.
    Each job runs in its own asyncio task, in a copy of the context it was
    submitted with.

    Args:
        max_workers: The maximum number of worker tasks to create.
        task_name_prefix: A prefix for the names of the worker tasks
            (``'Executor'`` when empty).
        initializer: An optional asynchronous callable that is awaited once,
            during the first submission, before the worker tasks are created.
            If it raises, that submission and every later one raise
            `InitializationError`.
        initargs: Arguments to pass to the initializer.

    Raises:
        ValueError: If `max_workers` is not a positive integer.
    """

    def __init__(
        self,
        max_workers: int = 100,
        task_name_prefix: str = '',
        initializer: Initializer | None = None,
        initargs: tuple[Any, ...] = (),
    ) -> None:
        if max_workers <= 0:
            raise ValueError('max_workers must be greater than 0')
        self._max_workers = max_workers
        self._task_name_prefix = task_name_prefix or 'Executor'
        self._initializer = initializer
        self._initargs = initargs
        # Items are (future, context, fn, args, kwargs) work items; ``None`` is a
        # stop signal that makes exactly one worker exit.
        self._jobs: asyncio.Queue[
            tuple[asyncio.Future[Any], contextvars.Context, TaskFunc, tuple[Any, ...], dict[str, Any]] | None
        ] = asyncio.Queue()
        self._tasks: list[asyncio.Task] = []
        self._shutdown = False
        self._initialized = False
        # Shared startup task (initializer + worker creation), created on first submission.
        self._init_task: asyncio.Task | None = None

    async def submit(
        self,
        fn: TaskFunc,
        /,
        *args: Any,
        **kwargs: Any,
    ) -> asyncio.Future[Any]:
        """
        Submit an asynchronous task to the executor.

        The task runs in a snapshot of the caller's current context.

        Args:
            fn: The asynchronous callable to execute.
            *args: Positional arguments to pass to the callable.
            **kwargs: Keyword arguments to pass to the callable.

        Returns:
            An asyncio.Future representing the result of the task.

        Raises:
            ShutdownError: If the executor has been shut down.
            InitializationError: If the initializer failed.
        """
        return await self.submit_with_context(None, fn, *args, **kwargs)

    async def submit_with_context(
        self,
        context: contextvars.Context | None,
        fn: TaskFunc,
        /,
        *args: Any,
        **kwargs: Any,
    ) -> asyncio.Future[Any]:
        """
        Submit an asynchronous task to the executor with a specific context.

        Args:
            context: The contextvars.Context to run the task in. The task runs
                in a copy of it, so changes the task makes are not written back.
                If None, a snapshot of the caller's current context is used.
            fn: The asynchronous callable to execute.
            *args: Positional arguments to pass to the callable.
            **kwargs: Keyword arguments to pass to the callable.

        Returns:
            An asyncio.Future representing the result of the task. Cancelling
            it also cancels the job if it is already running.

        Raises:
            ShutdownError: If the executor has been shut down.
            InitializationError: If the initializer failed.
        """
        if self._shutdown:
            raise ShutdownError('Cannot schedule new futures after shutdown')
        await self._initialize()
        # shutdown() may have run while we waited for startup; re-check so the job
        # is never queued behind the workers' stop signals (it would never run).
        if self._shutdown:
            raise ShutdownError('Cannot schedule new futures after shutdown')
        if context is None:
            context = contextvars.copy_context()
        future = asyncio.get_running_loop().create_future()
        self._jobs.put_nowait((future, context, fn, args, kwargs))
        return future

    def map(
        self,
        fn: TaskFunc,
        /,
        *iterables: Iterable[Any],
    ) -> AsyncIterator[Any]:
        """
        Apply an asynchronous function to each item in iterables.

        This is similar to the built-in `map` function, but operates
        asynchronously. The argument tuples are collected immediately. Because
        `submit` is a coroutine, the jobs are submitted when iteration starts.
        Results are yielded in input order. If the iterator is closed early,
        jobs that have not finished are cancelled.

        Args:
            fn: The asynchronous callable to apply.
            *iterables: Iterables yielding arguments for the function.

        Returns:
            An asynchronous iterator yielding the results.

        Raises:
            ShutdownError: If the executor has been shut down.
            ValueError: If the iterables have different lengths.
        """
        if self._shutdown:
            raise ShutdownError('Cannot schedule new futures after shutdown')
        # Materialize now so a length mismatch raises here, not mid-iteration.
        arg_tuples = list(zip(*iterables, strict=True))

        async def result_iterator():
            futures: list[asyncio.Future[Any]] = []
            try:
                for args in arg_tuples:
                    futures.append(await self.submit(fn, *args))
                for future in futures:
                    yield await future
            finally:
                for future in futures:
                    if not future.done():
                        future.cancel()

        return result_iterator()

    async def shutdown(self, wait: bool = True, *, cancel_futures: bool = False) -> None:
        """
        Shut down the executor.

        New submissions are rejected from now on. One stop signal per worker is
        queued *behind* the jobs that are already waiting, so by default every
        submitted job still runs before the workers exit.

        Args:
            wait: If True, wait until startup (if still in progress) and all
                worker tasks have finished.
            cancel_futures: If True, cancel the futures of jobs that are still
                queued; jobs that are already running are allowed to finish.

        Raises:
            ExceptionGroup: If any worker task ended with an exception.
        """
        self._shutdown = True
        if not self._initialized:
            return
        if cancel_futures:
            self._cancel_queued()
        for _ in range(self._max_workers):
            self._jobs.put_nowait(None)
        if not wait:
            return
        if self._init_task is not None:
            # Workers only exist once startup has finished. asyncio.wait neither
            # re-raises a startup failure (already reported to submitters) nor cancels it.
            await asyncio.wait((self._init_task,))
        results = await asyncio.gather(*self._tasks, return_exceptions=True)
        exceptions = [res for res in results if isinstance(res, Exception)]
        if exceptions:
            raise ExceptionGroup('Exceptions occurred during shutdown', exceptions)

    def _cancel_queued(self) -> None:
        """Drain the job queue, cancelling the future of every job removed."""
        while True:
            try:
                item = self._jobs.get_nowait()
            except asyncio.QueueEmpty:
                return
            if item is not None:
                item[0].cancel()

    async def _initialize(self) -> None:
        """Start the executor once: run the initializer, then create the workers.

        Concurrent first submissions share one startup task, so the initializer
        runs exactly once and no submission proceeds before the workers exist.
        If startup failed, awaiting the finished task re-raises the failure to
        every later caller.
        """
        if self._init_task is None:
            self._initialized = True
            self._init_task = asyncio.create_task(self._start(), name=f'{self._task_name_prefix}-init')
        # shield: a cancelled submitter must not abort startup for everyone else.
        await asyncio.shield(self._init_task)

    async def _start(self) -> None:
        """Await the initializer (if any), then create the worker tasks."""
        if self._initializer is not None:
            try:
                await self._initializer(*self._initargs)
            except Exception as e:
                raise InitializationError(f'Initializer failed: {e}') from e
        for i in range(self._max_workers):
            task_name = f'{self._task_name_prefix}-{i}'
            self._tasks.append(asyncio.create_task(self._worker(), name=task_name))

    async def _worker(self) -> None:
        """Worker loop: run queued jobs one at a time until a ``None`` stop signal arrives."""
        while True:
            item = await self._jobs.get()
            if item is None:  # Stop signal queued by shutdown()
                return
            future, context, fn, args, kwargs = item
            if future.done():
                # Cancelled by the caller (or shutdown) while still queued.
                continue
            await self._run_job(future, context, fn, args, kwargs)

    async def _run_job(
        self,
        future: asyncio.Future[Any],
        context: contextvars.Context,
        fn: TaskFunc,
        args: tuple[Any, ...],
        kwargs: dict[str, Any],
    ) -> None:
        """Run one job in a copy of `context` and mirror its outcome onto `future`."""
        try:
            # Creating the task inside context.run makes it copy `context`, so the
            # job sees the submitter's context variables.
            job = context.run(self._start_job, fn, args, kwargs)
        except Exception as e:
            # fn raised synchronously or did not return an awaitable.
            future.set_exception(e)
            return
        job.add_done_callback(lambda done: self._copy_outcome(done, future))
        # Cancelling the caller's future also cancels the running job.
        future.add_done_callback(lambda fut: job.cancel() if fut.cancelled() else None)
        try:
            # asyncio.wait never raises the job's exception; that goes to `future`.
            await asyncio.wait((job,))
        finally:
            if not job.done():
                # This worker itself was cancelled; don't leave the job unsupervised.
                job.cancel()

    @staticmethod
    def _start_job(fn: TaskFunc, args: tuple[Any, ...], kwargs: dict[str, Any]) -> asyncio.Future[Any]:
        """Call `fn` and schedule its awaitable as a task in the current context."""
        return asyncio.ensure_future(fn(*args, **kwargs))

    @staticmethod
    def _copy_outcome(job: asyncio.Future[Any], future: asyncio.Future[Any]) -> None:
        """Copy the result, exception or cancellation of `job` onto `future`."""
        if future.done():
            return
        if job.cancelled():
            future.cancel()
        elif job.exception() is not None:
            future.set_exception(job.exception())
        else:
            future.set_result(job.result())
def _set_context_vars(context: contextvars.Context): | |
"""Helper function to set context variables.""" | |
for var, value in context.items(): | |
var.set(value) | |
async def my_task(task_id: int, delay: float = 0.1) -> None:
    """Example task that simulates work by counting from 1 to 10.

    Args:
        task_id: Identifier printed on every progress line.
        delay: Seconds to sleep between counts (the simulated work).
    """
    for i in range(1, 11):
        print(f'Task {task_id}: Count {i}')
        await asyncio.sleep(delay)
    print(f'Task {task_id}: Finished')
async def my_initializer():
    """Demo initializer: announce setup, pretend to do it for half a second, announce completion."""
    setup_seconds = 0.5
    print('Initializing...')
    await asyncio.sleep(setup_seconds)
    print('Initialization complete.')
async def main():
    """Demo entry point: run 13 example tasks on a 10-worker Executor, then shut it down."""
    executor = Executor(max_workers=10, task_name_prefix='MyExecutor', initializer=my_initializer)
    # submit() is a coroutine here, so each submission is awaited in order.
    futures = [await executor.submit(my_task, task_id) for task_id in range(1, 14)]
    await asyncio.gather(*futures)
    await executor.shutdown()


if __name__ == '__main__':
    asyncio.run(main())
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
import contextvars | |
from typing import Callable, Awaitable, Iterable, Any, AsyncIterator | |
class Executor: | |
def __init__( | |
self, | |
max_workers: int = 100, | |
task_name_prefix: str = '', | |
initializer: Callable[..., Awaitable[None]] | None = None, | |
initargs: tuple[Any, ...] = (), | |
) -> None: | |
if max_workers <= 0: | |
raise ValueError('max_workers must be greater than 0') | |
self._max_workers = max_workers | |
self._task_name_prefix = task_name_prefix or 'Executor' | |
self._initializer = initializer | |
self._initargs = initargs | |
self._jobs: asyncio.Queue[Callable[[], Awaitable[Any]]] = asyncio.Queue() | |
self._tasks: list[asyncio.Task] = [] | |
self._shutdown = False | |
self._initialized = False | |
self._init_task: asyncio.Task | None = None # Store the init task | |
def submit( | |
self, | |
fn: Callable[..., Awaitable[Any]], | |
/, | |
*args: Any, | |
**kwargs: Any, | |
) -> asyncio.Future[Any]: | |
return self.submit_with_context(None, fn, *args, **kwargs) | |
def submit_with_context( | |
self, | |
context: contextvars.Context | None, | |
fn: Callable[..., Awaitable[Any]], | |
/, | |
*args: Any, | |
**kwargs: Any, | |
) -> asyncio.Future[Any]: | |
if self._shutdown: | |
raise RuntimeError('Cannot schedule new futures after shutdown') | |
self._lazy_init() | |
loop = asyncio.get_running_loop() | |
future = loop.create_future() | |
async def wrapped_fn(): | |
if context: | |
context.run(_set_context_vars, contextvars.copy_context()) | |
try: | |
result = await fn(*args, **kwargs) | |
if not future.done(): | |
future.set_result(result) | |
except BaseException as e: | |
if not future.done(): | |
future.set_exception(e) | |
self._jobs.put_nowait(wrapped_fn) | |
return future | |
def map( | |
self, | |
fn: Callable[..., Awaitable[Any]], | |
/, | |
*iterables: Iterable[Any], | |
) -> AsyncIterator[Any]: | |
futures = [self.submit(fn, *args) for args in zip(*iterables, strict=True)] | |
async def result_iterator(): | |
try: | |
for future in futures: | |
yield await future | |
finally: | |
for future in futures: | |
if not future.done(): | |
future.cancel() | |
return result_iterator() | |
async def shutdown(self, wait: bool = True, *, cancel_futures: bool = False) -> None: | |
self._shutdown = True | |
if not self._initialized: | |
return | |
if cancel_futures: | |
while not self._jobs.empty(): | |
try: | |
self._jobs.get_nowait() | |
except asyncio.QueueEmpty: | |
pass | |
for task in self._tasks: | |
if not task.done(): # Check if task is done before canceling | |
task.cancel() | |
# Instead of self._jobs.shutdown(), we'll let the workers exit naturally | |
# by finishing the queue, or by being cancelled. We *cannot* use | |
# self._jobs.join() here, because the workers might be blocked on | |
# self._jobs.get(), and we'd deadlock. | |
if wait: | |
to_await = self._tasks[:] # Copy the list | |
if self._init_task: | |
to_await.append(self._init_task) # Add init_task to await | |
await asyncio.gather(*to_await, return_exceptions=True) | |
# No need to clear self._tasks, they've already been gathered. | |
def _lazy_init(self) -> None: | |
if self._initialized: | |
return | |
self._initialized = True | |
if self._initializer: | |
# Create a task for the initializer, but don't await it yet. | |
self._init_task = asyncio.create_task(self._initializer(*self._initargs)) | |
for i in range(self._max_workers): | |
task_name = f'{self._task_name_prefix}-{i}' | |
task = asyncio.create_task(self._worker(), name=task_name) | |
self._tasks.append(task) | |
async def _worker(self) -> None: | |
while True: | |
try: | |
job = await self._jobs.get() | |
if job is None: # Shutdown signal | |
break | |
await job() | |
except asyncio.CancelledError: | |
# Task was cancelled, exit gracefully | |
return | |
except Exception as e: | |
print(f'Worker encountered an exception: {e}') | |
# Consider logging the exception properly here. | |
def _set_context_vars(context: contextvars.Context): | |
for var, value in context.items(): | |
var.set(value) | |
async def my_task(task_id: int, delay: float = 0.1) -> None:
    """Example task that simulates work by counting from 1 to 10.

    Args:
        task_id: Identifier printed on every progress line.
        delay: Seconds to sleep between counts (the simulated work). A short
            default keeps the demo quick: with 13 tasks on 10 workers, a 3.1 s
            step made the demo run for about a minute.
    """
    for i in range(1, 11):
        print(f'Task {task_id}: Count {i}')
        await asyncio.sleep(delay)  # Simulate some work
    print(f'Task {task_id}: Finished')
async def my_initializer():
    """Demo initializer: print a start message, pause to mimic setup, print a done message."""
    pause = 0.5  # stand-in for real initialization time
    print('Initializing...')
    await asyncio.sleep(pause)
    print('Initialization complete.')
async def main():
    """Demo entry point: fan 13 example tasks out over a 10-worker Executor and wait for them."""
    executor = Executor(max_workers=10, task_name_prefix='MyExecutor', initializer=my_initializer)
    # submit() is synchronous in this version and returns a future straight away.
    futures = [executor.submit(my_task, task_id) for task_id in range(1, 14)]
    await asyncio.gather(*futures)  # Wait for all tasks to complete
    await executor.shutdown()


if __name__ == '__main__':
    asyncio.run(main())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment