Skip to content

Instantly share code, notes, and snippets.

@tsvikas
Created January 18, 2023 21:26
Show Gist options
  • Save tsvikas/d89bcec3f915921e7fa2ea1b58361de6 to your computer and use it in GitHub Desktop.
Save tsvikas/d89bcec3f915921e7fa2ea1b58361de6 to your computer and use it in GitHub Desktop.
Tools like multiprocessing.Pool.map()
import multiprocessing as mp
import queue
import signal
from functools import partial
from typing import Callable, Iterable
def get_queue() -> queue.Queue:
    """
    Return a multiprocessing (manager-backed) queue.

    The sub-process created to serve the queue must ignore SIGINT (CTRL-C), since
    the interrupt should be handled exclusively by the master process. SIGINT is
    therefore ignored while the manager is started. The caller's previous SIGINT
    handler is restored afterwards, even if starting the manager fails.

    :return: a proxy to a queue living in a ``multiprocessing.Manager`` process.
    """
    orig_handler = signal.signal(signal.SIGINT, signal.SIG_IGN)
    try:
        q = mp.Manager().Queue()
    finally:
        # Never leave the master process ignoring CTRL-C, even on failure.
        signal.signal(signal.SIGINT, orig_handler)
    return q
def starmap(
func: Callable, iterable: Iterable, processes: int | None = None, **mp_kwargs
) -> list:
"""
Improved version of multiprocessing.starmap
Asynchronously apply `func` to each element in `iterable`.
The elements of the `iterable` are expected to be iterables as well
and will be unpacked as arguments.
In addition, each worker in the pool is set to ignore keyboard interrupts,
to prevent cluttering the screen with irrelevant exceptions data.
This function silently terminates the workers on KeyboardInterrupt.
:param func: function to map onto.
:param iterable: iterable. each element `args` is an iterable that will
be unpacked when running func
:param processes: number of processes to use. default is 0 = do not use
the multiprocessing module at all.
:param mp_kwargs: keywords to pass to `multiprocessing.starmap_async`
:return: the result of `[func(*args) for args in iterable]`
"""
if processes is None:
return [func(*args) for args in iterable]
else:
processes = (processes - 1) % mp.cpu_count() + 1
orig_handler = signal.signal(signal.SIGINT, signal.SIG_IGN)
with mp.Pool(processes) as pool:
signal.signal(signal.SIGINT, orig_handler)
# Note: will do pool.terminate on exit
return pool.starmap(func, iterable, **mp_kwargs)
def starmap_updatable(
func: Callable,
iterable: Iterable,
processes: int | None = None,
update_cb: Callable[[int], None] | None = None,
**mp_kwargs,
) -> list:
"""
Update-supporting version of multiprocessing.starmap
Asynchronously apply `func` to each element in `iterable`.
The elements of the `iterable` are expected to be iterables as well
and will be unpacked as arguments.
`func` should accept an `update_queue` parameter, of type `queue.Queue`.
Each time that `func` will call `update_queue.put(i)`, this function
will call `update_cb(i)`.
In addition, each worker in the pool is set to ignore keyboard interrupts,
to prevent cluttering the screen with irrelevant exceptions data.
This function silently terminates the workers on KeyboardInterrupt.
Usage:
import tqdm
with tqdm.tqdm(total=the_sum_of_all_updates) as pbar:
result = starmap_updatable(
slow_func, args_list, processes=10, update_cb=pbar.update)
:param func: function to map onto.
:param iterable: iterable. each element `args` is an iterable that will
be unpacked when running func
:param processes: number of processes to use. default is 0 = do not use
the multiprocessing module at all.
:param update_cb: function to call for each update. for example:
:param mp_kwargs: keywords to pass to `multiprocessing.starmap_async`
:return: the result of `[func(*args) for args in iterable]`
"""
if update_cb is None:
return starmap(func=func, iterable=iterable, processes=processes, **mp_kwargs)
q = get_queue()
new_func = partial(func, update_queue=q)
if processes is None:
res = []
for args in iterable:
res += [new_func(*args)]
while not q.empty():
i = q.get(timeout=0.1)
update_cb(i)
return res
else:
cpu_count = mp.cpu_count()
if processes > cpu_count:
raise ValueError(
f"requested number of processes {processes} is larger than cpu count {cpu_count}"
)
elif processes > 0:
pass
elif processes > -cpu_count:
processes += cpu_count
else:
raise ValueError(
f"requested number of processes {processes} is not between {-cpu_count} to {cpu_count}"
)
orig_handler = signal.signal(signal.SIGINT, signal.SIG_IGN)
with mp.Pool(processes) as pool:
signal.signal(signal.SIGINT, orig_handler)
# Note: will do pool.terminate on exit
res = pool.starmap_async(new_func, iterable, **mp_kwargs)
while (not res.ready()) or (not q.empty()):
if not q.empty():
i = q.get(timeout=0.1)
update_cb(i)
return res.get()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment