Created
November 16, 2024 09:04
-
-
Save luistung/88804b04f1cd5c8e40629e581562f75b to your computer and use it in GitHub Desktop.
This code implements a parallel task-streaming executor using a thread pool.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import concurrent.futures
import logging
import random
import time
from typing import Callable, Iterable, Iterator, Optional, Tuple, TypeVar
T = TypeVar('T')
R = TypeVar('R')

_logger = logging.getLogger(__name__)


def stream_parallel_tasks(
    task_fun: Callable[[T], R],
    iterable: Iterable[T],
    map_fun: Callable[[T], T],
    max_workers: int,
    task_queue_size: Optional[int] = None,
) -> Iterator[Tuple[int, Optional[R]]]:
    """
    Execute tasks in parallel on a thread pool and yield results as they complete.

    Items are pulled lazily from ``iterable`` and transformed by ``map_fun`` in
    the consumer's thread (inside this generator). They are then submitted to a
    ``ThreadPoolExecutor``. At most ``task_queue_size`` tasks are outstanding
    (submitted but not yet yielded) at any time, so ``iterable`` is never read
    far ahead of the consumer.

    If the generator is exited early, tasks that have not started yet are
    cancelled. Early exit means ``close()``, garbage collection after the
    consumer stops iterating, or an exception thrown into or raised inside the
    generator. Tasks that are already running are waited for before the
    generator finishes.

    Args:
        task_fun: Function to execute for each (mapped) item.
        iterable: Input sequence to process.
        map_fun: Function to transform input items before processing.
        max_workers: Maximum number of concurrent threads.
        task_queue_size: Maximum number of outstanding tasks. ``None`` or ``0``
            selects the default of ``max_workers * 2``.

    Yields:
        Tuples of ``(index, result)`` in order of task completion, where
        ``index`` is the item's position in ``iterable``. If ``task_fun``
        raises, the exception is logged and ``(index, None)`` is yielded.

    Raises:
        ValueError: If ``max_workers < 1``, ``map_fun`` is None, or
            ``task_queue_size`` is negative. Because this is a generator, the
            error surfaces on the first ``next()``.
        Exception: Anything raised by ``iterable`` or ``map_fun`` propagates
            to the consumer.

    Example:
        >>> def process(x): return x * 2
        >>> sorted(stream_parallel_tasks(process, range(5), lambda x: x, 2))
        [(0, 0), (1, 2), (2, 4), (3, 6), (4, 8)]
    """
    if max_workers < 1:
        raise ValueError("max_workers must be at least 1")
    if map_fun is None:
        raise ValueError("map_fun cannot be None")
    if task_queue_size is not None and task_queue_size < 0:
        # A negative capacity would never submit any task, so every input
        # item would be silently dropped.
        raise ValueError("task_queue_size must not be negative")
    # ``or`` keeps the historical behaviour of treating 0 like None.
    capacity = task_queue_size or max_workers * 2

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        indexed_args = enumerate(map(map_fun, iterable))
        pending = {}  # Future -> index of the input item it is processing
        exhausted = False
        try:
            while True:
                # Top up the window of outstanding tasks.
                while not exhausted and len(pending) < capacity:
                    try:
                        idx, arg = next(indexed_args)
                    except StopIteration:
                        exhausted = True
                    else:
                        pending[executor.submit(task_fun, arg)] = idx
                if not pending:
                    break  # input exhausted and every result delivered
                done, _ = concurrent.futures.wait(
                    pending, return_when=concurrent.futures.FIRST_COMPLETED
                )
                for future in done:
                    idx = pending.pop(future)
                    # Guard only result(): an exception thrown into the
                    # generator at ``yield`` must propagate, not be turned
                    # into a None result.
                    try:
                        result = future.result()
                    except Exception:
                        _logger.exception("task %d failed; yielding None", idx)
                        result = None
                    yield idx, result
        finally:
            # On early exit, drop queued-but-unstarted work so the executor's
            # shutdown only waits for tasks that are already running.
            for future in pending:
                future.cancel()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment