Created
April 1, 2018 09:08
-
-
Save louiskirsch/30d9dca2ebf60c303f2edadc64f96c01 to your computer and use it in GitHub Desktop.
A TensorFlow ring buffer implementation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class RingBuffer:
    """
    A fixed-capacity ring buffer whose storage lives in TensorFlow variables.

    Each item is a tuple with one entry per component; component ``i`` is kept
    in a variable of shape ``[size] + shape_i``. Once the buffer is full, new
    insertions wrap around and overwrite the oldest items.
    """

    def __init__(self, scope_name, components, size):
        """
        Create a new ring buffer of size `size`.
        Each item in the ring buffer is a tuple of variables of size `len(components)`.
        :param scope_name: A scope name for the newly created variables
        :param components: Defines the type of items in the buffer. An iterable of tuples (name: str, shape: Iterable, dtype)
        :param size: The maximum size of the buffer
        :raises ValueError: if `size` is smaller than 1
        """
        if size < 1:
            # A zero-sized buffer would make the `% self.size` in `insert` divide by zero.
            raise ValueError('RingBuffer size must be at least 1, got {}'.format(size))
        self.size = size
        # The buffer is storage, not model parameters: keep every variable out of
        # tf.trainable_variables() so optimizers, regularizers and "copy all
        # trainable variables" helpers never pick it up.
        with tf.variable_scope(scope_name, initializer=tf.zeros_initializer()):
            self.components = [
                tf.get_variable(name, [size] + list(shape), dtype, trainable=False)
                for name, shape, dtype in components
            ]
            # Slot index that the next inserted item is written to.
            self.offset = tf.get_variable('offset', shape=[], dtype=tf.int32, trainable=False)
            # Number of slots that hold inserted data (saturates at `size`); lets
            # `sample` skip never-written, zero-initialised slots.
            # NOTE: earlier versions of this class did not create this variable,
            # so checkpoints saved by them will not contain it.
            self.filled = tf.get_variable('filled', shape=[], dtype=tf.int32, trainable=False)

    def insert(self, tensors):
        """
        Append a batch of items, overwriting the oldest ones once the buffer is full.
        :param tensors: One tensor per component, in constructor order, all sharing
            the same leading (batch) dimension. The batch dimension may be unknown
            statically; in that case a batch larger than `size` fails at run time.
        :return: A grouped op that performs the insertion when run.
        :raises ValueError: if the number of tensors differs from the number of
            components, their static leading dimensions disagree, or the batch is
            statically known to be larger than the buffer.
        """
        tensors = list(tensors)
        if len(tensors) != len(self.components):
            raise ValueError('Expected {} tensors (one per component), got {}'.format(
                len(self.components), len(tensors)))

        static_counts = {tensor.shape.as_list()[0] for tensor in tensors} - {None}
        if len(static_counts) > 1:
            raise ValueError('All tensors must share the same leading dimension, got {}'.format(
                sorted(static_counts)))
        if static_counts:
            elem_count = static_counts.pop()
            if elem_count > self.size:
                raise ValueError('Cannot insert {} items into a ring buffer of size {}'.format(
                    elem_count, self.size))
        else:
            # Batch size is only known at graph run time.
            elem_count = tf.shape(tensors[0])[0]

        # Read the offset once so every slice below uses the same value.
        offset = self.offset.read_value()
        # The first items fill the tail slots [offset, tail_end) ...
        tail_end = tf.minimum(self.size, offset + elem_count)
        tail_count = tail_end - offset
        # ... and any remainder wraps around to the head slots [0, head_end).
        # Because elem_count <= size, the head and tail ranges never overlap.
        head_end = elem_count - tail_count

        ops = []
        for tensor, component in zip(tensors, self.components):
            ops.append(component[offset:tail_end].assign(tensor[:tail_count]))
            ops.append(component[:head_end].assign(tensor[tail_count:]))
        # Advance the bookkeeping only after every write has been issued.
        with tf.control_dependencies(ops):
            ops.append(self.offset.assign((offset + elem_count) % self.size))
            ops.append(self.filled.assign(tf.minimum(self.size, self.filled + elem_count)))
        return tf.group(*ops)

    def sample(self, count):
        """
        Draw up to `count` distinct items uniformly at random, without replacement.
        Only slots that have been written are eligible. If fewer than `count` items
        have been inserted, all of them are returned in random order.
        :param count: Maximum number of items to return.
        :return: A list with one tensor per component, gathered at the same indices.
        """
        # Slots are filled contiguously from 0 until the first wrap-around, so the
        # written slots are always [0, filled).
        indices = tf.random_shuffle(tf.range(self.filled))[:count]
        return [tf.gather(component, indices) for component in self.components]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment