Last active
October 5, 2020 17:19
-
-
Save zhihongliuus/e1e6240329d954a15521de410bd4054b to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import multiprocessing
import queue
import sys
import traceback
from collections import namedtuple
"""Wrapper for accessing an object that lives in a separate process.

Because of the GIL, a single Python interpreter executes Python bytecode on
only one CPU core at a time. When a task keeps the CPU busy for a long time,
it is better to move it into another process so that it can run on another
CPU core.
"""
# One reply from the worker for one request: ``exc_return`` holds the value a
# successful request produced; ``exc_info`` holds (exception type, exception
# instance, list of formatted traceback lines) when the request raised.
# Both fields default to None.
TaskResult = namedtuple('TaskResult', ['exc_return', 'exc_info'], defaults=[None, None])


def _capture_exc_info():
    """Return ``(type, value, formatted_lines)`` for the exception being handled."""
    exc_type, exc_value, exc_tb = sys.exc_info()
    # format_exception (unlike format_tb) also emits the final "ExcType: message"
    # line, so the text sent to the parent keeps the original error message.
    return exc_type, exc_value, traceback.format_exception(exc_type, exc_value, exc_tb)


class Consumer(multiprocessing.Process):
    """Worker process that owns one instance of *cls* and serves requests for it.

    Requests arrive on *task_queue* as ``(name, args, kwargs)`` tuples and each
    one gets exactly one :class:`TaskResult` on *result_queue*. A request whose
    name is ``None`` is the stop request: the worker replies with a bare
    ``None`` and its ``run`` loop ends.

    When *result_queue* is a multiprocessing queue, return values and exceptions
    have to be picklable to reach the parent.
    """

    def __init__(self, cls, task_queue, result_queue, args=(), kwargs=None):
        """Store the construction recipe; the instance itself is built in ``run``.

        :param cls: class to instantiate inside the worker.
        :param task_queue: queue of ``(name, args, kwargs)`` requests; must
            support ``task_done()`` (e.g. ``multiprocessing.JoinableQueue``).
        :param result_queue: queue that receives one reply per request.
        :param args: positional arguments for ``cls``.
        :param kwargs: keyword arguments for ``cls``; ``None`` means none.
        """
        super().__init__()
        self.task_queue = task_queue
        self.result_queue = result_queue
        self.cls = cls
        self.args = args
        # Avoid a shared mutable default; normalise to a fresh dict.
        self.kwargs = {} if kwargs is None else kwargs

    def run(self):
        """Build the instance, then serve requests until the stop request arrives."""
        try:
            inst = self.cls(*self.args, **self.kwargs)
        except BaseException:
            # Keep serving and answer every request with the construction
            # error: if the worker died here, a parent waiting on result_queue
            # would never receive a reply.
            inst, init_error = None, _capture_exc_info()
        else:
            init_error = None

        while True:
            name, args, kwargs = self.task_queue.get()
            if name is None:
                # Stop request: acknowledge with None, then leave the loop.
                self.result_queue.put(None)
                self.task_queue.task_done()
                break
            if init_error is not None:
                tr = TaskResult(None, init_error)
            else:
                tr = self._execute(inst, name, args, kwargs)
            self.task_queue.task_done()
            self.result_queue.put(tr)

    @staticmethod
    def _execute(inst, name, args, kwargs):
        """Read attribute *name* of *inst*, calling it if callable; wrap the outcome."""
        try:
            member = getattr(inst, name)
            answer = member(*args, **kwargs) if callable(member) else member
        except BaseException:
            # Everything is reported back (as the original bare except did) so
            # that each request still receives exactly one reply.
            return TaskResult(None, _capture_exc_info())
        return TaskResult(answer, None)
class _RemoteTraceback(Exception):
    """Carries a worker-side formatted traceback as the cause of a re-raised error."""

    def __init__(self, tb):
        super().__init__(tb)
        self.tb = tb

    def __str__(self):
        return self.tb


class MpWrapper:
    """Proxy for an instance of *cls* that lives in a separate worker process.

    Use as a context manager: ``__enter__`` starts the worker and ``__exit__``
    asks it to stop and joins it. Inside the ``with`` block, ``proxy.name(...)``
    runs ``name`` on the remote instance and returns its result, and
    ``proxy.attr`` fetches an attribute that is not a method of *cls*. Requests
    and replies travel through multiprocessing queues, so arguments, return
    values and exceptions must be picklable.

    Raises AssertionError when the worker process is found dead (explicitly
    raised, so the check also works under ``python -O``).
    """

    # Seconds between liveness checks while waiting for a reply.
    _poll_interval = 0.1

    def __init__(self, cls, args=(), kwargs=None):
        """Prepare (but do not start) the worker.

        :param cls: class to instantiate inside the worker process.
        :param args: positional arguments for ``cls``.
        :param kwargs: keyword arguments for ``cls``; ``None`` means none.
        """
        if kwargs is None:
            kwargs = {}
        self.tasks = multiprocessing.JoinableQueue()
        self.results = multiprocessing.Queue()
        self.cls = cls
        self.consumer = Consumer(cls, self.tasks, self.results, args, kwargs)

    def __enter__(self):
        self.consumer.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Ask the worker to stop, wait for its acknowledgement, and join it.

        A dead worker raises AssertionError, unless the ``with`` body is already
        propagating an exception; that exception is then left unmasked.
        """
        try:
            self._check_alive()
            self.tasks.put((None, None, None))
            self._get_result()  # the worker acknowledges the stop request with None
            self.tasks.join()
        except AssertionError:
            if exc_type is None:
                raise
        finally:
            self.consumer.join()

    def __getattr__(self, item):
        """Forward *item* to the remote instance.

        If *item* is callable on *cls*, return a function that runs it remotely.
        Otherwise fetch the attribute from the remote instance right away.
        """
        # __getattr__ only runs when normal lookup fails. If one of our own
        # fields is missing, __init__ never ran (e.g. copy/unpickle), and
        # touching self.cls / self.consumer below would recurse forever.
        if item in ("tasks", "results", "cls", "consumer"):
            raise AttributeError(item)
        if callable(getattr(self.cls, item, None)):
            def wrapper(*args, **kwargs):
                return self._call_remote(item, args, kwargs)
            return wrapper
        return self._call_remote(item, None, None)

    def _check_alive(self):
        """Raise AssertionError if the worker process is not running."""
        if not self.consumer.is_alive():
            raise AssertionError("worker process is not alive")

    def _get_result(self):
        """Wait for the next reply; raise AssertionError instead of hanging if the worker dies."""
        while self.consumer.is_alive():
            try:
                return self.results.get(timeout=self._poll_interval)
            except queue.Empty:
                pass
        # The worker is gone, but a reply it flushed just before exiting may
        # still be readable.
        try:
            return self.results.get(timeout=self._poll_interval)
        except queue.Empty:
            raise AssertionError("worker process exited without replying") from None

    def _call_remote(self, item, args, kwargs):
        """Send one request and return its value, re-raising any remote error."""
        self._check_alive()
        self.tasks.put((item, args, kwargs))
        tr = self._get_result()
        if tr.exc_info:
            self._raise_remote(tr.exc_info)
        return tr.exc_return

    @staticmethod
    def _raise_remote(exc_info):
        """Re-raise a worker-side error described by ``(type, value, formatted_lines)``."""
        exc_type, exc_value, exc_format_tb = exc_info
        remote_tb = "".join(exc_format_tb)
        if isinstance(exc_value, BaseException):
            # Re-raise the real exception so its message and attributes are kept.
            # Rebuilding it as exc_type(text) fails for types whose constructor
            # needs specific arguments.
            raise exc_value from _RemoteTraceback(remote_tb)
        raise exc_type(remote_tb)
class MpTest:
    """Small demo class used to exercise MpWrapper from the script entry point."""

    def __init__(self, text=""):
        # Printed by call_test when no truthy `something` is supplied.
        self.io = text

    def call_test(self, po, something=None):
        """Print ``po`` and ``something`` if ``something`` is truthy, else ``self.io``."""
        if not something:
            print(self.io)
            return
        print(po, something)
def _demo():
    """Run MpTest behind MpWrapper: one successful remote call, then a failing one.

    MpTest defines no ``fail_test``, so the second statement raises the
    AttributeError reported back by the worker process.
    """
    with MpWrapper(MpTest) as proxy:
        proxy.call_test("hi,", something="hello, world")
        proxy.fail_test("Test")


if __name__ == "__main__":
    _demo()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment