Skip to content

Instantly share code, notes, and snippets.

@Hammer2900
Last active February 10, 2025 16:38
Show Gist options
  • Save Hammer2900/45698014028ce1f4373141a1aa3692b5 to your computer and use it in GitHub Desktop.
Save Hammer2900/45698014028ce1f4373141a1aa3692b5 to your computer and use it in GitHub Desktop.
A simple asynchronous executor based on asyncio. It is similar to concurrent.futures.ThreadPoolExecutor, but uses asyncio tasks instead of threads. It lets you run asynchronous functions in a pool while limiting the number of concurrently running tasks, and it supports an optional initializer and graceful shutdown. I wrote it for myself, but maybe it will be useful to someone.
import asyncio
import contextvars
from typing import Callable, Awaitable, Iterable, Any, AsyncIterator, Protocol, runtime_checkable
@runtime_checkable
class Initializer(Protocol):
    """Structural type for the executor's optional asynchronous start-up hook.

    At runtime ``isinstance`` only checks that a ``__call__`` attribute exists.
    """

    async def __call__(self, *args: Any) -> None:
        """Perform set-up work using the positional ``initargs``."""
@runtime_checkable
class TaskFunc(Protocol):
    """Structural type for the coroutine functions the executor runs.

    At runtime ``isinstance`` only checks that a ``__call__`` attribute exists.
    """

    async def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Run the job and return its result."""
class ExecutorError(Exception):
    """Common base class for every error raised by the Executor."""


class ShutdownError(ExecutorError):
    """Signals that a task was submitted after the executor was shut down."""


class InitializationError(ExecutorError):
    """Signals that the executor's initializer raised an exception."""


class WorkerError(ExecutorError):
    """Signals that a worker task hit an exception it could not handle."""
class Executor:
    """
    An asynchronous executor that manages a pool of worker tasks.

    This executor is similar in concept to `concurrent.futures.ThreadPoolExecutor`
    but uses asyncio tasks instead of threads. It provides a simple interface
    for submitting asynchronous tasks and retrieving their results.

    The initializer (if any) and the worker tasks are started lazily by the
    first submission. Each submitted call runs in its own task inside a copy of
    the submitter's context (or inside an explicitly supplied context), so
    context variables behave as they do with `asyncio.create_task`.

    Args:
        max_workers: The number of worker tasks, i.e. the maximum number of
            calls running at the same time.
        task_name_prefix: A prefix for the names of the worker tasks
            (defaults to 'Executor').
        initializer: An optional asynchronous callable that is awaited once,
            on the first submission, before the worker tasks are started.
        initargs: Positional arguments to pass to the initializer.

    Raises:
        ValueError: If `max_workers` is not greater than 0.
    """

    def __init__(
        self,
        max_workers: int = 100,
        task_name_prefix: str = '',
        initializer: Initializer | None = None,
        initargs: tuple[Any, ...] = (),
    ) -> None:
        if max_workers <= 0:
            raise ValueError('max_workers must be greater than 0')
        self._max_workers = max_workers
        self._task_name_prefix = task_name_prefix or 'Executor'
        self._initializer = initializer
        self._initargs = initargs
        # Each item is a (job, future) pair, or None telling one worker to exit.
        # The future travels with its job so shutdown(cancel_futures=True) can
        # cancel work that never started instead of leaving awaiters hanging.
        self._jobs: asyncio.Queue[
            tuple[Callable[[], Awaitable[None]], asyncio.Future[Any]] | None
        ] = asyncio.Queue()
        self._tasks: list[asyncio.Task] = []
        self._shutdown = False
        self._initialized = False
        # Shared start-up task (initializer + worker creation). Every submitter
        # awaits this same task, so the initializer runs exactly once.
        self._init_task: asyncio.Task | None = None

    async def submit(
        self,
        fn: TaskFunc,
        /,
        *args: Any,
        **kwargs: Any,
    ) -> asyncio.Future[Any]:
        """
        Submit an asynchronous task to the executor.

        The call runs inside a copy of the caller's current context.

        Args:
            fn: The asynchronous callable to execute.
            *args: Positional arguments to pass to the callable.
            **kwargs: Keyword arguments to pass to the callable.

        Returns:
            An asyncio.Future representing the result of the task.

        Raises:
            ShutdownError: If the executor has been shut down.
            InitializationError: If the initializer failed.
        """
        return await self.submit_with_context(None, fn, *args, **kwargs)

    async def submit_with_context(
        self,
        context: contextvars.Context | None,
        fn: TaskFunc,
        /,
        *args: Any,
        **kwargs: Any,
    ) -> asyncio.Future[Any]:
        """
        Submit an asynchronous task to the executor with a specific context.

        Args:
            context: The contextvars.Context to run the task in. If None,
                a copy of the caller's current context is used.
            fn: The asynchronous callable to execute.
            *args: Positional arguments to pass to the callable.
            **kwargs: Keyword arguments to pass to the callable.

        Returns:
            An asyncio.Future resolved with the call's result or exception.
            Cancelling it skips the call if it has not started yet, or
            cancels the call if it is running.

        Raises:
            ShutdownError: If the executor has been shut down.
            InitializationError: If the initializer failed.
        """
        if self._shutdown:
            raise ShutdownError('Cannot schedule new futures after shutdown')
        await self._initialize()
        # shutdown() may have run while we were waiting for start-up.
        if self._shutdown:
            raise ShutdownError('Cannot schedule new futures after shutdown')
        # Capture the submitter's context now: the worker that eventually runs
        # the job has its own, unrelated context.
        run_context = contextvars.copy_context() if context is None else context
        future: asyncio.Future[Any] = asyncio.get_running_loop().create_future()

        async def call() -> Any:
            return await fn(*args, **kwargs)

        async def job() -> None:
            if future.cancelled():
                return  # The caller gave up before a worker picked this up.
            inner = asyncio.create_task(call(), context=run_context)

            def cancel_inner(done: asyncio.Future[Any]) -> None:
                # Propagate cancellation of the returned future to the call.
                if done.cancelled():
                    inner.cancel()

            future.add_done_callback(cancel_inner)
            try:
                result = await inner
            except asyncio.CancelledError:
                future.cancel()
                current = asyncio.current_task()
                if current is not None and current.cancelling():
                    raise  # The worker itself is being cancelled: stop it.
            except Exception as exc:
                if not future.done():
                    future.set_exception(exc)
            else:
                if not future.done():
                    future.set_result(result)

        self._jobs.put_nowait((job, future))
        return future

    def map(
        self,
        fn: TaskFunc,
        /,
        *iterables: Iterable[Any],
    ) -> AsyncIterator[Any]:
        """
        Apply an asynchronous function to each item in iterables.

        This is similar to the built-in `map` function, but operates
        asynchronously. All calls are submitted when iteration starts, and
        results are yielded in input order. Leaving the loop early (break,
        error, or aclose) cancels the calls whose results were not consumed.

        Args:
            fn: The asynchronous callable to apply.
            *iterables: Iterables yielding arguments for the function.

        Returns:
            An asynchronous iterator yielding the results.

        Raises:
            ShutdownError: If the executor has been shut down.
            ValueError: If the iterables have different lengths.
        """
        if self._shutdown:
            raise ShutdownError('Cannot schedule new futures after shutdown')
        # Materialise the argument tuples now so a length mismatch
        # (zip strict=True) is reported at call time.
        arg_tuples = list(zip(*iterables, strict=True))

        async def result_iterator() -> AsyncIterator[Any]:
            # submit() is a coroutine, so submission must happen inside the
            # async generator. Everything is submitted before the first result
            # is awaited so the calls run concurrently.
            futures: list[asyncio.Future[Any]] = []
            try:
                for args in arg_tuples:
                    futures.append(await self.submit(fn, *args))
                for future in futures:
                    yield await future
            finally:
                for future in futures:
                    if not future.done():
                        future.cancel()

        return result_iterator()

    async def shutdown(self, wait: bool = True, *, cancel_futures: bool = False) -> None:
        """
        Shut down the executor.

        New submissions are rejected immediately. Jobs that are already queued
        still run unless `cancel_futures` is True. Each worker exits once it
        reaches its stop sentinel at the end of the queue.

        Args:
            wait: If True, wait for the start-up task and all workers to finish.
            cancel_futures: If True, cancel the futures of queued jobs that
                have not started yet. Jobs that are already running finish.

        Raises:
            ExceptionGroup: If `wait` is True and the start-up task or any
                worker ended with an exception.
        """
        self._shutdown = True
        if not self._initialized:
            return
        if cancel_futures:
            while True:
                try:
                    item = self._jobs.get_nowait()
                except asyncio.QueueEmpty:
                    break
                if item is not None:
                    item[1].cancel()
        # One stop sentinel per worker. The queue is FIFO, so anything still
        # queued ahead of the sentinels runs first.
        for _ in range(self._max_workers):
            self._jobs.put_nowait(None)
        if not wait:
            return
        results: list[Any] = []
        if self._init_task is not None:
            # Workers only exist once start-up has finished, so wait for it first.
            results.extend(await asyncio.gather(self._init_task, return_exceptions=True))
        results.extend(await asyncio.gather(*self._tasks, return_exceptions=True))
        exceptions = [res for res in results if isinstance(res, Exception)]
        if exceptions:
            raise ExceptionGroup('Exceptions occurred during shutdown', exceptions)

    async def _initialize(self) -> None:
        """Start the executor on first use and wait until it is ready.

        Raises:
            InitializationError: If the initializer failed (on every call).
        """
        if self._init_task is None:
            self._initialized = True
            self._init_task = asyncio.create_task(
                self._start(), name=f'{self._task_name_prefix}-init'
            )
        # Shield the shared task so a cancelled submitter cannot cancel
        # start-up for everyone else.
        await asyncio.shield(self._init_task)

    async def _start(self) -> None:
        """Run the initializer once, then create the worker tasks."""
        if self._initializer is not None:
            try:
                await self._initializer(*self._initargs)
            except Exception as e:
                raise InitializationError(f'Initializer failed: {e}') from e
        if self._shutdown:
            # shutdown() ran while the initializer was awaited; no more jobs
            # can be submitted, so don't spawn workers that nobody will stop.
            return
        for i in range(self._max_workers):
            task_name = f'{self._task_name_prefix}-{i}'
            self._tasks.append(asyncio.create_task(self._worker(), name=task_name))

    async def _worker(self) -> None:
        """Worker task: runs queued jobs one at a time until a None sentinel."""
        while True:
            try:
                item = await self._jobs.get()
                if item is None:  # Stop sentinel put by shutdown()
                    return
                job, _ = item
                await job()
            except asyncio.CancelledError:
                # Task was cancelled, exit gracefully
                return
            except Exception as e:
                # Jobs report their own errors through their futures, so
                # anything reaching here is unexpected: report it and stop.
                print(f'Worker encountered an exception: {e}')
                raise WorkerError(f'Worker failed: {e}') from e
def _set_context_vars(context: contextvars.Context):
    """Copy every variable binding found in *context* into the active context."""
    for ctx_var in context:
        ctx_var.set(context[ctx_var])
async def my_task(task_id: int):
    """Demo job: print a counter from 1 to 10, pausing 0.1 s after each step."""
    count = 0
    while count < 10:
        count += 1
        print(f'Task {task_id}: Count {count}')
        await asyncio.sleep(0.1)
    print(f'Task {task_id}: Finished')
async def my_initializer():
    """Demo initializer: announce start, wait half a second, announce completion."""
    print('Initializing...')
    setup_delay = 0.5
    await asyncio.sleep(setup_delay)
    print('Initialization complete.')
async def main():
    """Demo: run 13 counting tasks on a 10-worker executor, then shut it down."""
    executor = Executor(max_workers=10, task_name_prefix='MyExecutor', initializer=my_initializer)
    # submit() is a coroutine in this version, so each submission is awaited in turn.
    futures = [await executor.submit(my_task, task_id) for task_id in range(1, 14)]
    await asyncio.gather(*futures)
    await executor.shutdown()


if __name__ == '__main__':
    asyncio.run(main())
import asyncio
import contextvars
from typing import Callable, Awaitable, Iterable, Any, AsyncIterator
class Executor:
def __init__(
self,
max_workers: int = 100,
task_name_prefix: str = '',
initializer: Callable[..., Awaitable[None]] | None = None,
initargs: tuple[Any, ...] = (),
) -> None:
if max_workers <= 0:
raise ValueError('max_workers must be greater than 0')
self._max_workers = max_workers
self._task_name_prefix = task_name_prefix or 'Executor'
self._initializer = initializer
self._initargs = initargs
self._jobs: asyncio.Queue[Callable[[], Awaitable[Any]]] = asyncio.Queue()
self._tasks: list[asyncio.Task] = []
self._shutdown = False
self._initialized = False
self._init_task: asyncio.Task | None = None # Store the init task
def submit(
self,
fn: Callable[..., Awaitable[Any]],
/,
*args: Any,
**kwargs: Any,
) -> asyncio.Future[Any]:
return self.submit_with_context(None, fn, *args, **kwargs)
def submit_with_context(
self,
context: contextvars.Context | None,
fn: Callable[..., Awaitable[Any]],
/,
*args: Any,
**kwargs: Any,
) -> asyncio.Future[Any]:
if self._shutdown:
raise RuntimeError('Cannot schedule new futures after shutdown')
self._lazy_init()
loop = asyncio.get_running_loop()
future = loop.create_future()
async def wrapped_fn():
if context:
context.run(_set_context_vars, contextvars.copy_context())
try:
result = await fn(*args, **kwargs)
if not future.done():
future.set_result(result)
except BaseException as e:
if not future.done():
future.set_exception(e)
self._jobs.put_nowait(wrapped_fn)
return future
def map(
self,
fn: Callable[..., Awaitable[Any]],
/,
*iterables: Iterable[Any],
) -> AsyncIterator[Any]:
futures = [self.submit(fn, *args) for args in zip(*iterables, strict=True)]
async def result_iterator():
try:
for future in futures:
yield await future
finally:
for future in futures:
if not future.done():
future.cancel()
return result_iterator()
async def shutdown(self, wait: bool = True, *, cancel_futures: bool = False) -> None:
self._shutdown = True
if not self._initialized:
return
if cancel_futures:
while not self._jobs.empty():
try:
self._jobs.get_nowait()
except asyncio.QueueEmpty:
pass
for task in self._tasks:
if not task.done(): # Check if task is done before canceling
task.cancel()
# Instead of self._jobs.shutdown(), we'll let the workers exit naturally
# by finishing the queue, or by being cancelled. We *cannot* use
# self._jobs.join() here, because the workers might be blocked on
# self._jobs.get(), and we'd deadlock.
if wait:
to_await = self._tasks[:] # Copy the list
if self._init_task:
to_await.append(self._init_task) # Add init_task to await
await asyncio.gather(*to_await, return_exceptions=True)
# No need to clear self._tasks, they've already been gathered.
def _lazy_init(self) -> None:
if self._initialized:
return
self._initialized = True
if self._initializer:
# Create a task for the initializer, but don't await it yet.
self._init_task = asyncio.create_task(self._initializer(*self._initargs))
for i in range(self._max_workers):
task_name = f'{self._task_name_prefix}-{i}'
task = asyncio.create_task(self._worker(), name=task_name)
self._tasks.append(task)
async def _worker(self) -> None:
while True:
try:
job = await self._jobs.get()
if job is None: # Shutdown signal
break
await job()
except asyncio.CancelledError:
# Task was cancelled, exit gracefully
return
except Exception as e:
print(f'Worker encountered an exception: {e}')
# Consider logging the exception properly here.
def _set_context_vars(context: contextvars.Context):
    """Copy every variable binding found in *context* into the active context."""
    for ctx_var in context:
        ctx_var.set(context[ctx_var])
async def my_task(task_id: int):
    """Demo job: print a counter from 1 to 10, pausing 3.1 s after each step."""
    count = 0
    while count < 10:
        count += 1
        print(f'Task {task_id}: Count {count}')
        await asyncio.sleep(3.1)  # Simulate some work
    print(f'Task {task_id}: Finished')
async def my_initializer():
    """Demo initializer: announce start, wait half a second, announce completion."""
    print('Initializing...')
    setup_delay = 0.5  # Simulated initialization time
    await asyncio.sleep(setup_delay)
    print('Initialization complete.')
async def main():
    """Demo: run 13 counting tasks on a 10-worker executor, then shut it down."""
    executor = Executor(max_workers=10, task_name_prefix='MyExecutor', initializer=my_initializer)
    # submit() is synchronous in this version and returns the future directly.
    futures = [executor.submit(my_task, task_id) for task_id in range(1, 14)]
    await asyncio.gather(*futures)  # Wait for every task to complete
    await executor.shutdown()


if __name__ == '__main__':
    asyncio.run(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment