Created
May 8, 2017 09:01
-
-
Save psycharo-zz/3f84c24c4666725ee3dbf5f55cd14aa0 to your computer and use it in GitHub Desktop.
custom multi-threading runner for tensorflow
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import threading | |
import numpy as np | |
import tensorflow as tf | |
class FeedingRunner(object):
    """Takes care of feeding/dequeueing data into the queue.

    A pool of Python threads pulls values from a generator and enqueues
    them into a FIFOQueue through placeholders; dequeued tensors are
    exposed via `inputs` with static shapes and stable names re-attached.
    Based on tf.train.QueueRunner.
    """

    def __init__(self, generator, dtypes, shapes, names, num_threads,
                 queue_capacity):
        """
        Args:
            generator: iterable yielding, per step, one value for each
                placeholder (in the same order as dtypes/shapes/names)
            dtypes: a list of types of the inputs
            shapes: a list of shapes of the inputs
            names: a list of names of the inputs
            num_threads: how many feeding threads to create
            queue_capacity: number of samples to keep in the examples queue
        """
        assert len(dtypes) == len(shapes) == len(names)
        self._generator = generator
        self._num_threads = num_threads
        self._queue = tf.FIFOQueue(queue_capacity, dtypes)
        self._dtypes = dtypes
        self._shapes = shapes
        self._names = names
        self._placeholders = [tf.placeholder(dtype, shape)
                              for dtype, shape in zip(dtypes, shapes)]
        self._enqueue_op = self._queue.enqueue(self._placeholders)
        self._dequeue_op = self._queue.dequeue()
        # dequeue returns a list only when the queue has several components;
        # normalize to a list so the loop below handles both cases.
        # FIX: isinstance instead of an exact type comparison.
        if not isinstance(self._dequeue_op, list):
            self._dequeue_op = [self._dequeue_op]
        self._cancel_op = self._queue.close(cancel_pending_enqueues=True)
        # Dequeue loses static shape info; restore it and name the tensors.
        self._inputs = []
        for i, value in enumerate(self._dequeue_op):
            value.set_shape(self._shapes[i])
            self._inputs.append(tf.identity(value, self._names[i]))

    def _run(self, sess, coord):
        """Runs the cycle that feeds data into the queue."""
        try:
            for values in self._generator:
                if coord and coord.should_stop():
                    break
                feed_dict = {key: value
                             for key, value in zip(self._placeholders, values)}
                sess.run(self._enqueue_op, feed_dict)
        except Exception as e:
            if coord:
                coord.request_stop(e)
            else:
                # FIX: without a coordinator the original silently swallowed
                # the error and the thread died with no trace; log it.
                tf.logging.error('Exception in feeding thread: %s', str(e))

    def _close_on_stop(self, sess, cancel_op, coord):
        """Close the queue when the Coordinator requests stop.

        Args:
            sess: A Session.
            cancel_op: The Operation to run (queue close with cancellation).
            coord: Coordinator.
        """
        coord.wait_for_stop()
        try:
            sess.run(cancel_op)
        except Exception as e:
            # Closing an already-closed queue raises; safe to ignore here.
            tf.logging.vlog(1, 'Ignored exception: %s', str(e))

    def create_threads(self, sess, coord=None, daemon=False, start=False):
        """Creates (and optionally starts) the feeding threads.

        Args:
            sess: session to run the enqueue ops in.
            coord: optional Coordinator; when given, an extra thread closes
                the queue on stop and every thread is registered with it.
            daemon: whether the threads should be daemon threads.
            start: whether to start the threads immediately.

        Returns:
            The list of created threads.
        """
        threads = [threading.Thread(target=self._run, args=(sess, coord))
                   for _ in range(self._num_threads)]
        if coord:
            threads.append(threading.Thread(target=self._close_on_stop,
                                            args=(sess, self._cancel_op, coord)))
        for t in threads:
            if coord:
                coord.register_thread(t)
            if daemon:
                t.daemon = True
            if start:
                t.start()
        return threads

    @property
    def queue(self):
        return self._queue

    @property
    def inputs(self):
        # Dequeued tensors to build the rest of the graph on.
        return self._inputs
class RandomDataIterator(object):
    """Iterator for uniform-random sampling from the dataset.

    Each `next()` draws `batch_size` random indices once and applies every
    reader to its own filename list at those indices, so the parallel
    filename lists stay aligned within a batch.
    """

    def __init__(self, filenames, readers, batch_size, replace=False):
        """
        Args:
            filenames: list of tuples/lists of filenames, one per reader;
                the lists are indexed in lockstep
            readers: list of (threadsafe) functions taking filename and
                returning data
            batch_size: int > 0
            replace: bool, whether to sample with replacement
        """
        # FIX: the original `assert cond1, cond2` used the second condition
        # as the assertion *message*, so the filenames/readers length match
        # was never actually checked.
        assert len(filenames) >= 1 and len(filenames) == len(readers)
        self.filenames = filenames
        self.readers = readers
        self.batch_size = batch_size
        self.replace = replace

    def __iter__(self):
        return self

    def __next__(self):
        # One shared draw of indices keeps all readers' outputs aligned.
        idxs = np.random.choice(len(self.filenames[0]), self.batch_size,
                                replace=self.replace)
        batch = []
        for fid, reader in enumerate(self.readers):
            batch.append([reader(self.filenames[fid][idx])
                          for idx in idxs])
        return batch
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment