Skip to content

Instantly share code, notes, and snippets.

@zhihongliuus
Last active October 5, 2020 17:19
Show Gist options
  • Save zhihongliuus/e1e6240329d954a15521de410bd4054b to your computer and use it in GitHub Desktop.
Save zhihongliuus/e1e6240329d954a15521de410bd4054b to your computer and use it in GitHub Desktop.
import multiprocessing
import queue
import sys
import traceback
from collections import namedtuple
from multiprocessing.reduction import ForkingPickler
"""Implement a wrapper to access instance which is running in separated process.
The reason to implement it because due to GIL, a python interpreter can only run on
one CPU core. If one task is running with heavy CPU usage for long time, it is
better to move it to other process so that it can be running on other CPU core"""
# Reply sent by the worker for each request:
#   exc_return -- the value returned by the call, or the attribute value read
#   exc_info   -- None on success, otherwise (exc_type, exc_value, tb_lines)
#                 where tb_lines is the output of traceback.format_tb()
TaskResult = namedtuple('TaskResult', ['exc_return', 'exc_info'], defaults=[None, None])


class Consumer(multiprocessing.Process):
    """Worker process that owns one ``cls`` instance and serves requests for it.

    The instance is built inside the worker by ``run()`` as
    ``cls(*args, **kwargs)``. Requests arrive on ``task_queue`` as
    ``(name, args, kwargs)`` tuples, and one ``TaskResult`` per request is put
    on ``result_queue``:

    * ``args is None``: the attribute ``name`` is read and its value returned.
    * otherwise: if the attribute is callable, it is called with
      ``*args, **kwargs`` and the return value is sent back; if it is not
      callable, its value is sent back.

    A request whose ``name`` is None is the shutdown sentinel. The worker
    answers it with a bare ``None`` and exits.

    Args:
        cls: class to instantiate inside the worker.
        task_queue: joinable queue of requests; ``task_done()`` is called
            exactly once per item taken from it.
        result_queue: queue that receives one reply per request.
        args: positional arguments for ``cls``.
        kwargs: keyword arguments for ``cls``; None means "no keyword arguments".
    """

    def __init__(self, cls, task_queue, result_queue, args=(), kwargs=None):
        super().__init__()
        self.task_queue = task_queue
        self.result_queue = result_queue
        self.cls = cls
        self.args = args
        # Fresh dict per instance instead of a shared mutable default.
        self.kwargs = {} if kwargs is None else kwargs

    def run(self):
        """Serve requests until the shutdown sentinel arrives."""
        try:
            inst = self.cls(*self.args, **self.kwargs)
        except BaseException:
            # Keep serving so every request still gets a reply (the constructor
            # error) instead of the worker dying with requests unanswered.
            inst, init_error = None, self._capture_exception()
        else:
            init_error = None

        while True:
            name, args, kwargs = self.task_queue.get()
            if name is None:
                self.result_queue.put(None)
                self.task_queue.task_done()
                break
            if init_error is not None:
                result = init_error
            else:
                result = self._execute(inst, name, args, kwargs)
            self.task_queue.task_done()
            self.result_queue.put(self._sendable(name, result))

    def _execute(self, inst, name, args, kwargs):
        """Perform one request against ``inst`` and wrap the outcome in a TaskResult."""
        try:
            attr = getattr(inst, name)
            # args of None marks a plain attribute read: never call the value,
            # even if it happens to be callable (``attr(*None)`` would fail).
            if args is not None and callable(attr):
                answer = attr(*args, **(kwargs or {}))
            else:
                answer = attr
        except BaseException:
            # Deliberately broad: this is the worker's top level, and any
            # escaping exception would kill the worker with the request unanswered.
            return self._capture_exception()
        return TaskResult(answer, None)

    @staticmethod
    def _capture_exception():
        """Build an error TaskResult from the exception currently being handled."""
        exc_type, exc_value, exc_tb = sys.exc_info()
        # Traceback objects cannot be pickled, so send the formatted frames instead.
        return TaskResult(None, (exc_type, exc_value, traceback.format_tb(exc_tb)))

    @staticmethod
    def _sendable(name, result):
        """Return ``result`` if it pickles, otherwise a RuntimeError TaskResult.

        A multiprocessing queue pickles items in a background feeder thread.
        A pickling failure there is only reported by that thread, and the
        requester never receives a reply. Pickling once up front (at the cost
        of serialising each reply twice) turns such failures into an error
        reply instead.
        """
        try:
            ForkingPickler.dumps(result)
        except Exception as err:
            tb_lines = result.exc_info[2] if result.exc_info else []
            message = f"reply for {name!r} could not be pickled: {err!r}"
            return TaskResult(None, (RuntimeError, RuntimeError(message), tb_lines))
        return result
class _RemoteTraceback(Exception):
    """Carries the traceback text that was formatted in the worker process.

    Set as ``__cause__`` of exceptions re-raised by MpWrapper so that the
    remote stack appears in the local traceback output.
    """


class MpWrapper:
    """Proxy for an instance of ``cls`` that lives in a separate worker process.

    Use it as a context manager: ``__enter__`` starts the worker and
    ``__exit__`` shuts it down. Inside the ``with`` block, attribute access is
    forwarded to the worker. Names that are callable on ``cls`` come back as
    functions that run the method remotely. Any other name is read from the
    remote instance. Requests and replies travel through multiprocessing
    queues, so arguments and results must be picklable.

    Args:
        cls: class to instantiate inside the worker.
        args: positional arguments for ``cls``.
        kwargs: keyword arguments for ``cls``; None means "no keyword arguments".

    Raises:
        RuntimeError: from forwarded accesses when the worker is not running
            or exits before replying.
    """

    # This proxy's own instance attributes. __getattr__ only sees them when
    # they are missing (half-built or copied instance), and forwarding them
    # would recurse into __getattr__ forever.
    _OWN_ATTRS = frozenset({"tasks", "results", "cls", "consumer"})
    # Seconds between worker liveness checks while waiting for a reply.
    _POLL_INTERVAL = 0.1

    def __init__(self, cls, args=(), kwargs=None):
        kwargs = {} if kwargs is None else kwargs
        self.tasks = multiprocessing.JoinableQueue()
        self.results = multiprocessing.Queue()
        self.cls = cls
        self.consumer = Consumer(cls, self.tasks, self.results, args, kwargs)

    def __enter__(self):
        """Start the worker process and return this proxy."""
        self.consumer.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Stop the worker; never suppresses an exception from the ``with`` body."""
        if self.consumer.is_alive():
            self.tasks.put((None, None, None))
            try:
                self._wait_result()  # the worker acknowledges the sentinel
            except RuntimeError:
                pass  # worker died while shutting down; just reap it below
            else:
                self.tasks.join()
        self.consumer.join()

    def __getattr__(self, item):
        if item in self._OWN_ATTRS:
            raise AttributeError(item)
        if callable(getattr(self.cls, item, None)):
            def wrapper(*args, **kwargs):
                return self._request(item, args, kwargs)
            return wrapper
        # Not a method of cls: read the attribute from the remote instance now.
        return self._request(item, None, None)

    def _request(self, item, args, kwargs):
        """Send one ``(item, args, kwargs)`` request and return the worker's answer."""
        self._check_alive()
        self.tasks.put((item, args, kwargs))
        return self._unwrap(self._wait_result())

    def _check_alive(self):
        """Raise RuntimeError if the worker process is not running."""
        if not self.consumer.is_alive():
            raise RuntimeError(
                f"worker process is not running (exitcode={self.consumer.exitcode})")

    def _wait_result(self):
        """Block until the worker posts a reply, failing fast if it has died."""
        while True:
            try:
                return self.results.get(timeout=self._POLL_INTERVAL)
            except queue.Empty:
                if not self.consumer.is_alive():
                    break
        # The worker may have posted its reply just before exiting; take one
        # last look before declaring the reply lost.
        try:
            return self.results.get_nowait()
        except queue.Empty:
            raise RuntimeError(
                f"worker process exited (exitcode={self.consumer.exitcode}) "
                "before sending a result") from None

    @staticmethod
    def _unwrap(tr):
        """Return ``tr.exc_return``, or re-raise the error recorded in ``tr.exc_info``."""
        if not tr.exc_info:
            return tr.exc_return
        exc_type, exc_value, exc_format_tb = tr.exc_info
        if isinstance(exc_value, BaseException):
            # Re-raise the worker's own exception object. This keeps its type,
            # message and attributes for callers. The worker sends its stack
            # only as formatted text, so chain that text as the cause.
            remote_tb = _RemoteTraceback(
                "Traceback from worker process (most recent call last):\n"
                + "".join(exc_format_tb))
            raise exc_value from remote_tb
        raise exc_type("".join(exc_format_tb))
class MpTest:
    """Tiny demo class driven through MpWrapper by the ``__main__`` block."""

    def __init__(self, text=""):
        # Text printed by call_test() when ``something`` is falsy.
        self.io = text

    def call_test(self, po, something=None):
        """Print ``po`` and ``something``, or fall back to the stored text.

        When ``something`` is falsy (including the default None), only
        ``self.io`` is printed and ``po`` is ignored. Returns None.
        """
        fields = (po, something) if something else (self.io,)
        print(*fields)
if __name__ == "__main__":
    # Demo: run a real method inside the worker process, then request a name
    # MpTest does not define. The worker's AttributeError is re-raised here,
    # so the script ends with that traceback (the with-block's __exit__ still
    # shuts the worker down).
    with MpWrapper(MpTest) as proxy:
        proxy.call_test("hi,", something="hello, world")
        proxy.fail_test("Test")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment