from multiprocessing import Pool, Manager, cpu_count
from queue import Empty
from threading import Thread

from tqdm import tqdm
def parallelize(func: callable, iterable: list, processes: int = None) -> list:
    """Parallelize the execution of a function over a list.

    This method runs the given function over chunks of the given list in
    parallel across multiple processes. The return value is an ordered list
    with the result of the function on each chunk.

    Progress of the overall computation is shown via a tqdm progress bar. In
    order to send update messages, the given function must have the signature
    ``(chunk, queue)``, where ``chunk`` is a chunk of the given list and
    ``queue`` is an increment queue. To step the progress bar by ``n``, call

        queue.put(n)

    It makes sense to run a function through this method only if the result
    on the whole list can be reconstructed from the results on the chunks.
    """

    def update():
        # Drain increment messages from the queue and advance the progress
        # bar; a falsy sentinel value stops the loop.
        pbar = tqdm(total=len(iterable))
        while True:
            try:
                increment = queue.get(timeout=1)
                if not increment:
                    break
                pbar.update(increment)
            except Empty:
                pass
        pbar.close()

    if processes == 1:
        raise RuntimeError("Call the function directly!")

    processes = processes or cpu_count()
    chunk_size = len(iterable) // processes
    queue = Manager().Queue()
    pbar_thread = Thread(target=update)

    results = []
    with Pool(processes=processes) as pool:
        # Dispatch the first processes - 1 equally sized chunks; the last
        # chunk also absorbs any remainder of the division.
        for i in range(processes - 1):
            results.append(
                pool.apply_async(
                    func, args=(iterable[i * chunk_size : (i + 1) * chunk_size], queue)
                )
            )
        results.append(
            pool.apply_async(
                func, args=(iterable[(processes - 1) * chunk_size :], queue)
            )
        )
        pbar_thread.start()
        pool.close()
        pool.join()

    # Send a falsy sentinel to terminate the progress-bar thread.
    queue.put(False)
    pbar_thread.join()

    return [result.get() for result in results]


# ---- EXAMPLE ----

if __name__ == "__main__":
    import time

    def parallel_sum(arg, queue):
        a = 0
        for i in arg:
            a += i
            time.sleep(0.1)
            queue.put(1)
        return a

    n = 200
    result = parallelize(func=parallel_sum, iterable=list(range(n)))
    assert sum(result) == ((n * (n - 1)) >> 1)
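
    # A minimal additional sketch of the coarser update style described in the
    # docstring: the worker steps the progress bar once per chunk by putting
    # ``len(chunk)`` on the queue instead of 1 per element. The helper name
    # ``parallel_count_even`` is purely illustrative.
    def parallel_count_even(chunk, queue):
        count = sum(1 for i in chunk if i % 2 == 0)
        queue.put(len(chunk))  # advance the bar by the whole chunk at once
        return count

    result = parallelize(func=parallel_count_even, iterable=list(range(n)))
    assert sum(result) == n // 2  # range(200) contains 100 even numbers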