Last active
August 29, 2015 14:07
-
-
Save mccutchen/7100057aa91d167cc048 to your computer and use it in GitHub Desktop.
flexible batching of sequences in Python
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def gen_batches(xs, size):
    """Yield consecutive, non-overlapping batches of items from xs.

    Given an iterable xs and a batch size, yield batches from it as lists of
    length size, where the last batch might be smaller than the rest.

    Each batch is yielded as soon as it is full, so nothing beyond the batch
    is read from xs before the caller receives it.

    Args:
        xs: Any iterable, including generators and infinite iterators.
        size: Maximum number of items per batch; a positive integer.

    Yields:
        Lists of up to size items, in the order they appear in xs. An empty
        xs yields nothing.

    Raises:
        AssertionError: If size is not positive. The check is an ``assert``,
            so it is skipped when Python runs with ``-O``.

    >>> list(gen_batches(range(9), 3))
    [[0, 1, 2], [3, 4, 5], [6, 7, 8]]

    >>> list(gen_batches(range(11), 3))
    [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10]]

    Also works with iterables that don't have a known size:

    >>> list(gen_batches(iter(range(9)), 3))
    [[0, 1, 2], [3, 4, 5], [6, 7, 8]]

    >>> import itertools
    >>> xs = itertools.cycle('abcd')
    >>> list(itertools.islice(gen_batches(xs, 3), 3))
    [['a', 'b', 'c'], ['d', 'a', 'b'], ['c', 'd', 'a']]

    Taking one batch from a shared iterator consumes exactly that batch:

    >>> it = iter(range(10))
    >>> next(gen_batches(it, 3))
    [0, 1, 2]
    >>> next(it)
    3
    """
    assert size > 0
    acc = []
    for x in xs:
        acc.append(x)
        # Emit the batch the moment it is full instead of waiting for the
        # next item; waiting would hold back a finished batch and read one
        # item too many from the source.
        if len(acc) >= size:
            yield acc
            acc = []
    # Whatever is left over forms the final, shorter batch.
    if acc:
        yield acc
def gen_overlapping_batches(xs, size, overlap=0.0):
    """Yield batches of items from xs that may overlap their neighbours.

    Given an iterable xs and a batch size, yield batches from it as lists of
    length size, where the last batch might be smaller than the rest.

    If an overlap fraction is given, each batch will share that fraction of
    its elements with the previous and next batch. The number of shared
    elements is ``int(size * overlap)``, capped at ``size - 1`` so that every
    batch contributes at least one new element and never exceeds size.

    Each batch is yielded as soon as it is full, so nothing beyond the batch
    is read from xs before the caller receives it.

    Args:
        xs: Any iterable, including generators and infinite iterators.
        size: Maximum number of items per batch; a positive integer.
        overlap: Fraction of a batch shared with the next one, between 0
            and 1 inclusive.

    Yields:
        Lists of up to size items. A final batch is only yielded if it holds
        at least one element that no earlier batch contained.

    Raises:
        AssertionError: If size is not positive or overlap is outside
            [0, 1]. The checks are ``assert`` statements, so they are
            skipped when Python runs with ``-O``.

    For example, with no overlap:

    >>> list(gen_overlapping_batches(range(10), 4))
    [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]

    And with 25% overlap:

    >>> list(gen_overlapping_batches(range(10), 4, 0.25))
    [[0, 1, 2, 3], [3, 4, 5, 6], [6, 7, 8, 9]]

    Full overlap degrades to a window that slides one element at a time:

    >>> list(gen_overlapping_batches(range(4), 2, 1.0))
    [[0, 1], [1, 2], [2, 3]]

    This should even work well with infinitely long generators:

    >>> import itertools
    >>> xs = itertools.cycle('abcd')
    >>> list(itertools.islice(gen_overlapping_batches(xs, 4, 0.25), 3))
    [['a', 'b', 'c', 'd'], ['d', 'a', 'b', 'c'], ['c', 'd', 'a', 'b']]
    """
    assert size > 0
    assert 0 <= overlap <= 1
    # Cap the carry-over below size: carrying a whole batch forward would
    # leave no room for new elements and let batches grow beyond size.
    offset = min(int(size * overlap), size - 1)
    acc = []
    # True while acc holds at least one element not yet yielded; prevents
    # emitting a trailing batch made up solely of carried-over elements.
    has_new = False
    for x in xs:
        acc.append(x)
        has_new = True
        if len(acc) >= size:
            yield acc
            # Seed the next batch with the tail of the one just emitted.
            acc = acc[-offset:] if offset else []
            has_new = False
    if has_new:
        yield acc
# When executed as a script, run every doctest embedded in this module's
# docstrings.
if __name__ == '__main__':
    from doctest import testmod
    testmod()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment