Created
January 18, 2023 21:26
-
-
Save tsvikas/d89bcec3f915921e7fa2ea1b58361de6 to your computer and use it in GitHub Desktop.
Tools like multiprocessing's Pool.map() / Pool.starmap()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import multiprocessing as mp
import queue
import signal
from functools import partial
from typing import Callable, Iterable


def get_queue() -> queue.Queue:
    """
    Return a queue that can be shared between processes.

    The queue is a proxy obtained from a freshly created
    `multiprocessing.Manager`. SIGINT (Ctrl-C) is set to "ignore" in the
    calling process while the manager is created, the intent being that the
    manager's helper process ignores SIGINT and the interrupt is handled
    exclusively by the master process. The caller's previous SIGINT handler
    is put back before returning, even if creating the manager fails.

    :return: a manager-backed queue proxy exposing the `queue.Queue` methods.
    """
    orig_handler = signal.signal(signal.SIGINT, signal.SIG_IGN)
    try:
        return mp.Manager().Queue()
    finally:
        # `signal.signal` reports None when the previous handler was not
        # installed from Python; such a handler cannot be re-installed here.
        if orig_handler is not None:
            signal.signal(signal.SIGINT, orig_handler)
def starmap( | |
func: Callable, iterable: Iterable, processes: int | None = None, **mp_kwargs | |
) -> list: | |
""" | |
Improved version of multiprocessing.starmap | |
Asynchronously apply `func` to each element in `iterable`. | |
The elements of the `iterable` are expected to be iterables as well | |
and will be unpacked as arguments. | |
In addition, each worker in the pool is set to ignore keyboard interrupts, | |
to prevent cluttering the screen with irrelevant exceptions data. | |
This function silently terminates the workers on KeyboardInterrupt. | |
:param func: function to map onto. | |
:param iterable: iterable. each element `args` is an iterable that will | |
be unpacked when running func | |
:param processes: number of processes to use. default is 0 = do not use | |
the multiprocessing module at all. | |
:param mp_kwargs: keywords to pass to `multiprocessing.starmap_async` | |
:return: the result of `[func(*args) for args in iterable]` | |
""" | |
if processes is None: | |
return [func(*args) for args in iterable] | |
else: | |
processes = (processes - 1) % mp.cpu_count() + 1 | |
orig_handler = signal.signal(signal.SIGINT, signal.SIG_IGN) | |
with mp.Pool(processes) as pool: | |
signal.signal(signal.SIGINT, orig_handler) | |
# Note: will do pool.terminate on exit | |
return pool.starmap(func, iterable, **mp_kwargs) | |
def starmap_updatable( | |
func: Callable, | |
iterable: Iterable, | |
processes: int | None = None, | |
update_cb: Callable[[int], None] | None = None, | |
**mp_kwargs, | |
) -> list: | |
""" | |
Update-supporting version of multiprocessing.starmap | |
Asynchronously apply `func` to each element in `iterable`. | |
The elements of the `iterable` are expected to be iterables as well | |
and will be unpacked as arguments. | |
`func` should accept an `update_queue` parameter, of type `queue.Queue`. | |
Each time that `func` will call `update_queue.put(i)`, this function | |
will call `update_cb(i)`. | |
In addition, each worker in the pool is set to ignore keyboard interrupts, | |
to prevent cluttering the screen with irrelevant exceptions data. | |
This function silently terminates the workers on KeyboardInterrupt. | |
Usage: | |
import tqdm | |
with tqdm.tqdm(total=the_sum_of_all_updates) as pbar: | |
result = starmap_updatable( | |
slow_func, args_list, processes=10, update_cb=pbar.update) | |
:param func: function to map onto. | |
:param iterable: iterable. each element `args` is an iterable that will | |
be unpacked when running func | |
:param processes: number of processes to use. default is 0 = do not use | |
the multiprocessing module at all. | |
:param update_cb: function to call for each update. for example: | |
:param mp_kwargs: keywords to pass to `multiprocessing.starmap_async` | |
:return: the result of `[func(*args) for args in iterable]` | |
""" | |
if update_cb is None: | |
return starmap(func=func, iterable=iterable, processes=processes, **mp_kwargs) | |
q = get_queue() | |
new_func = partial(func, update_queue=q) | |
if processes is None: | |
res = [] | |
for args in iterable: | |
res += [new_func(*args)] | |
while not q.empty(): | |
i = q.get(timeout=0.1) | |
update_cb(i) | |
return res | |
else: | |
cpu_count = mp.cpu_count() | |
if processes > cpu_count: | |
raise ValueError( | |
f"requested number of processes {processes} is larger than cpu count {cpu_count}" | |
) | |
elif processes > 0: | |
pass | |
elif processes > -cpu_count: | |
processes += cpu_count | |
else: | |
raise ValueError( | |
f"requested number of processes {processes} is not between {-cpu_count} to {cpu_count}" | |
) | |
orig_handler = signal.signal(signal.SIGINT, signal.SIG_IGN) | |
with mp.Pool(processes) as pool: | |
signal.signal(signal.SIGINT, orig_handler) | |
# Note: will do pool.terminate on exit | |
res = pool.starmap_async(new_func, iterable, **mp_kwargs) | |
while (not res.ready()) or (not q.empty()): | |
if not q.empty(): | |
i = q.get(timeout=0.1) | |
update_cb(i) | |
return res.get() |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment