a stupid anyio + multiprocess demo
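Two files. The first drives a set of worker processes with a `SharedMemory` buffer and a `Condition` variable: the parent fills the buffer with random bytes, notifies one waiting worker, and reads the worker's sum back from an output queue. The second wraps a manager `Queue` in a `QueueProxy` so anyio code can put, get, and `async for` over it (and await `Pool` results) without blocking the event loop.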
import signal
from typing import (
    Any,
    Callable,
    Final,
    Optional,
    Tuple,
    TypeVar,
    TypeVarTuple,
    Unpack,
    cast,
)

import multiprocess as mp
import numpy as np
from loguru import logger
from multiprocess.context import BaseContext, Process
from multiprocess.managers import SharedMemoryManager, SyncManager
from multiprocess.queues import Queue
from multiprocess.shared_memory import SharedMemory
from multiprocess.synchronize import Condition

BUF_SIZE: Final = 16

T_Retval = TypeVar("T_Retval")
PosArgsT = TypeVarTuple("PosArgsT")


def create_process(
    target: Callable[[Unpack[PosArgsT]], T_Retval],
    args: Tuple[Unpack[PosArgsT]],
    name: str | None = None,
    daemon: bool | None = None,
):
    """A typed wrapper around `Process` so `target` and `args` are checked together."""
    return Process(target=target, args=args, name=name, daemon=daemon)


_exit_handler: Optional[Callable[[int, Any], None]] = None


# https://superfastpython.com/multiprocessing-condition-variable-in-python/
def task(id: int, cv: Condition, sm: SharedMemory, sq: Queue, oq: Queue):
    global _exit_handler
    logger.info("Starting task {}", id)
    assert _exit_handler is None

    def exit_handler(_sig_num: int = 0, _frame: Any = None):
        logger.info("Task {} done", id)
        sm.close()
        # returning from a SIGTERM handler would swallow the signal and leave
        # the worker stuck in cv.wait() forever, so exit explicitly
        if _sig_num != 0:
            raise SystemExit(0)

    _exit_handler = exit_handler
    signal.signal(signal.SIGTERM, _exit_handler)
    _buf = np.empty(BUF_SIZE, dtype=np.uint8)
    # report through the sync queue that this worker is ready
    sq.put({"id": id})
    try:
        while True:
            with cv:
                cv.wait()
                assert sm.buf is not None
                temp = np.frombuffer(sm.buf, dtype=np.uint8, count=BUF_SIZE, offset=0)
                # copy the shared memory into the local buffer, then drop the
                # view so sm.close() won't hit BufferError on an exported buffer
                _buf[:] = temp[:]
                del temp
            s = int(np.sum(_buf))
            oq.put({"id": id, "sum": s})
    finally:
        exit_handler()


def main():
    sync_man = SyncManager()
    mem_man = SharedMemoryManager()
    sync_man.start()
    mem_man.start()
    ctx = cast(BaseContext, sync_man._ctx)  # pylint: disable=protected-access
    oq = cast(Queue, sync_man.Queue())  # output queue # type: ignore
    sm = mem_man.SharedMemory(BUF_SIZE)
    cv = Condition(ctx=ctx)
    count: int = mp.cpu_count()
    sq = cast(Queue, sync_man.Queue(count))  # sync queue # type: ignore
    ps = [create_process(target=task, args=(i, cv, sm, sq, oq)) for i in range(count)]
    for p in ps:
        p.start()
    # wait for every worker to check in before publishing any work
    for _ in range(len(ps)):
        init = sq.get()
        logger.info("Process {} started", init["id"])
    for _ in range(24):
        assert sm.buf is not None
        buf = np.frombuffer(sm.buf, dtype=np.uint8, count=BUF_SIZE, offset=0)
        buf[:] = np.random.randint(0, 255, BUF_SIZE, dtype=np.uint8)
        logger.info("buf={}", buf)
        # notify() wakes exactly one waiting worker, which sums the buffer
        # and reports through the output queue
        with cv:
            cv.notify()
        logger.info("Result {}", oq.get())
        del buf  # drop the view so sm.close() below can succeed
    logger.info("Main process done")
    for p in ps:
        p.terminate()
        p.join()
        p.close()
    try:
        sm.close()
        sm.unlink()
    except BufferError:
        # a still-exported view would make close() fail; the manager cleans
        # the segment up on shutdown anyway
        pass
    sync_man.shutdown()
    mem_man.shutdown()


if __name__ == "__main__":
    main()
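The first file's trick is that `np.frombuffer(sm.buf, ...)` is a zero-copy view into the shared segment: every process that maps the segment sees writes immediately, and the worker only pays for the explicit `_buf[:] = temp[:]` copy. The catch is that such a view counts as an exported buffer and must be dropped before `SharedMemory.close()`. A minimal single-process sketch of that round trip (using the stdlib `multiprocessing.shared_memory`, assumed interchangeable with `multiprocess.shared_memory` here):

import numpy as np
from multiprocessing.shared_memory import SharedMemory

shm = SharedMemory(create=True, size=16)
view = np.frombuffer(shm.buf, dtype=np.uint8, count=16, offset=0)
view[:] = np.arange(16, dtype=np.uint8)   # writes land directly in the segment
assert int(view.sum()) == sum(range(16))  # readers see the same bytes, zero-copy
del view      # release the exported buffer, or close() raises BufferError
shm.close()
shm.unlink()

The second file wraps a manager queue so anyio code can await it:
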
from typing import (
    Any,
    Callable,
    Generic,
    Optional,
    Protocol,
    TypeVar,
    TypeVarTuple,
    Unpack,
    cast,
)

import anyio
import multiprocess as mp
from loguru import logger
from multiprocess.context import BaseContext
from multiprocess.managers import BaseManager, SyncManager
from multiprocess.pool import Pool
from multiprocess.queues import Empty, Full, Queue

# I don't know what to do now
T = TypeVar("T")


# https://superfastpython.com/multiprocessing-condition-variable-in-python/
class QueueProxy(Generic[T]):
    """
    An anyio-friendly, type-safe queue for multiprocessing.

    This class provides an asynchronous interface to a multiprocessing Queue,
    allowing it to be used with anyio without blocking the event loop.

    Note
    ----
    Relies on the default pickling (`__getstate__`/`__setstate__`), so
    instances can be passed to worker processes as-is.
    """

    _q: Queue
    _ctx: BaseContext

    def __init__(self, queue: Queue, ctx: BaseContext):
        """
        Initialize the QueueProxy.

        :param queue: The multiprocessing Queue to wrap.
        :param ctx: The multiprocessing context used to create the Queue.
        """
        self._q = queue
        self._ctx = ctx

    @staticmethod
    def from_manager(manager: BaseManager, size: int = 0) -> "QueueProxy[T]":
        """
        Create a new QueueProxy from a multiprocessing Manager.

        :param manager: The Manager to create the Queue from.
        :param size: The maximum size of the Queue (default 0 for unbounded).
        :return: A new QueueProxy instance.
        """
        ctx = manager._ctx  # pylint: disable=protected-access
        return QueueProxy[T](manager.Queue(size), ctx=ctx)  # type: ignore

    def put(self, item: T, block: bool = True, timeout: Optional[float] = None):
        """
        Put an item into the queue, synchronously.

        This method wraps Queue.put() and has the same behavior.
        """
        self._q.put(item, block=block, timeout=timeout)

    async def async_put(self, item: T):
        """
        Put an item into the queue asynchronously.

        If the queue is full, this method polls until a free slot is available,
        yielding to the event loop between attempts instead of blocking it.
        """
        while True:
            try:
                return self._q.put_nowait(item)
            except Full:
                await anyio.sleep(0)

    def get(self, block: bool = True, timeout: Optional[float] = None) -> T:
        """
        Get an item from the queue, synchronously.

        This method wraps Queue.get() and has the same behavior.
        """
        return cast(T, self._q.get(block=block, timeout=timeout))

    async def async_get(self) -> T:
        """
        Get an item from the queue asynchronously.

        If the queue is empty, this method polls until an item is available,
        yielding to the event loop between attempts instead of blocking it.
        """
        while True:
            try:
                return cast(T, self._q.get_nowait())
            except Empty:
                await anyio.sleep(0)

    def put_nowait(self, item: T):
        """Put an item into the queue if a free slot is immediately available."""
        self._q.put_nowait(item)

    def get_nowait(self) -> T:
        """Get an item from the queue if one is immediately available."""
        return cast(T, self._q.get_nowait())

    @property
    def queue(self) -> Queue:
        """
        Get the underlying multiprocessing Queue.

        Use this property when passing the Queue to a multiprocessing Process.
        """
        return self._q

    async def __aiter__(self):
        """
        Asynchronous iterator interface to get items from the queue.

        This allows using the queue with `async for` without blocking the event loop.
        """
        while True:
            try:
                el = self._q.get_nowait()
                yield cast(T, el)
            except Empty:
                # https://superfastpython.com/what-is-asyncio-sleep-zero/
                # yield control to the event loop
                await anyio.sleep(0)

    async def __anext__(self) -> T:
        return await self.async_get()


class ApplyResultLike(Protocol[T]):
    def ready(self) -> bool: ...
    def successful(self) -> bool: ...
    def get(self, timeout: Optional[float] = None) -> T: ...
    def wait(self, timeout: Optional[float] = None) -> None: ...


async def await_result(result: ApplyResultLike[T]) -> T:
    """
    Wrap an ApplyResult into an awaitable by polling `ready()`.
    """
    while not result.ready():
        await anyio.sleep(0)
    return result.get()
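
# Usage sketch (hypothetical, not part of the original gist): awaiting a
# pool result from async code without blocking the event loop:
#
#   async def fetch(pool: Pool) -> int:
#       ar = pool.apply_async(int, ("42",))
#       return await await_result(ar)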


_a: Optional[int] = None
"""
this variable is expected to be unique in each process
"""


def init():
    global _a
    assert _a is None, "inited"
    _a = 1
    logger.info("init {}", _a)


def task(i: int, q: QueueProxy[int]):
    global _a
    assert _a is not None
    _a += i
    q.put(i, timeout=3)
    return _a


TIMEOUT = 5
QUEUE_SIZE = 24
T_Retval = TypeVar("T_Retval")
PosArgsT = TypeVarTuple("PosArgsT")


def safe_apply_async(
    pool: Pool,
    func: Callable[[Unpack[PosArgsT]], T_Retval],
    *args: Unpack[PosArgsT],
    callback: Optional[Callable[[T_Retval], None]] = None,
    error_callback: Optional[Callable[[Any], None]] = None,
) -> ApplyResultLike[T_Retval]:
    """
    A type-safe wrapper around `Pool.apply_async()`.
    """
    return pool.apply_async(
        func=func,
        args=args,
        callback=callback,
        error_callback=error_callback,
    )


def main():
    count: int = mp.cpu_count()
    logger.info("cpu count is {}", count)
    man = SyncManager()
    man.start()
    q = QueueProxy[int].from_manager(man, QUEUE_SIZE)
    p = Pool(processes=count, initializer=init)
    for _ in range(1_000):
        # https://superfastpython.com/multiprocessing-pool-asyncresult/
        # fire and forget: once the queue is full each task blocks in q.put(),
        # so the consumer below sets the pace
        safe_apply_async(p, task, 1, q)

    async def consumer():
        await q.async_put(10)
        first = await q.async_get()
        logger.info("first={}", first)
        acc = 0
        with anyio.move_on_after(TIMEOUT) as cancel_scope:
            async for i in q:
                acc += i
                # push the deadline back on every item: the scope only
                # cancels after TIMEOUT seconds of silence
                cancel_scope.deadline = anyio.current_time() + TIMEOUT
        logger.info("acc={}", acc)

    anyio.run(consumer)
    # workers may still be blocked in q.put() on a full queue, so terminate
    # rather than close() + join(), then drop the manager
    p.terminate()
    p.join()
    man.shutdown()


if __name__ == "__main__":
    main()
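
A note on the polling: `anyio.sleep(0)` keeps the event loop responsive but spins a core while the queue is idle. A gentler alternative (a sketch only; `async_get_threaded` is not part of the gist, and it assumes blocking in anyio's default thread pool is acceptable) hands the blocking `get()` to a worker thread via `anyio.to_thread.run_sync`:

import anyio
import anyio.to_thread

async def async_get_threaded(q: "QueueProxy[int]", timeout: float = 5.0) -> int:
    # the thread cannot be cancelled from the event loop, so bound the
    # blocking call itself with a timeout instead
    return await anyio.to_thread.run_sync(lambda: q.get(timeout=timeout))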