Created May 12, 2021 18:11
-
-
Save rreece/cfeda8ddb1dce3a2c1fdb7096d59adb9 to your computer and use it in GitHub Desktop.
Writes float16 data to a tfrecord as raw bytes and reads it back.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
Writes float16 data to a TFRecord file as raw bytes and reads it back.
Based on:
https://stackoverflow.com/questions/40184812/tensorflow-is-it-possible-to-store-tf-record-sequence-examples-as-float16
"""
import argparse | |
import numpy as np | |
import tensorflow as tf | |
def _write_tfrecord(data_np, path='data.tfrecord'):
    """Serialize an array into a TFRecord file as raw float16 bytes.

    The whole array is stored as a single ``tf.train.Example`` with one
    feature, ``'x'``, whose ``BytesList`` holds one byte string. The array's
    shape is NOT stored: the bytes are emitted in C order, so a reader gets
    back a flat sequence of values.

    Args:
        data_np: array-like of values to store. It is converted to
            little-endian float16 before serialization (a no-op for a
            float16 array on a little-endian host).
        path: destination TFRecord file. Defaults to ``'data.tfrecord'``.
    """
    # tf.io.decode_raw interprets bytes as little-endian by default, while
    # ndarray.tobytes() emits the array's own dtype and byte order. Pin both
    # here ('<f2' = little-endian float16) so the reader, which decodes
    # tf.float16, recovers the same values on any host.
    payload = np.asarray(data_np, dtype='<f2').tobytes()
    # One feature, 'x', holding the raw bytes as a single BytesList entry.
    feature = {'x': tf.train.Feature(bytes_list=tf.train.BytesList(value=[payload]))}
    example = tf.train.Example(features=tf.train.Features(feature=feature))
    with tf.io.TFRecordWriter(path) as writer:
        writer.write(example.SerializeToString())
def _parse_tfrecord(example_proto):
    """Decode one serialized ``tf.train.Example`` into float16 tensors.

    Args:
        example_proto: a serialized Example, as yielded by a
            ``tf.data.TFRecordDataset``.

    Returns:
        dict mapping each feature name (only ``'x'`` here) to the tensor
        obtained by reinterpreting that feature's raw bytes as ``tf.float16``.
    """
    # Each feature is stored as a single byte string (scalar tf.string).
    spec = {'x': tf.io.FixedLenFeature((), tf.string)}
    raw = tf.io.parse_single_example(example_proto, spec)
    decoded = {}
    for name in spec:
        # Reinterpret the byte string as float16 values.
        decoded[name] = tf.io.decode_raw(raw[name], tf.float16)
    return decoded
def input_fn(path='data.tfrecord'):
    """Build a dataset that reads a TFRecord file and decodes each record.

    Args:
        path: a filename, or a list of filenames, to read. Defaults to
            ``'data.tfrecord'``, the file written by ``_write_tfrecord``.

    Returns:
        ``tf.data.Dataset`` yielding one dict of decoded tensors per record,
        as produced by ``_parse_tfrecord``.
    """
    # Keep the original list-of-filenames form for a single path; pass
    # anything else (e.g. a list of paths) straight through.
    filenames = [path] if isinstance(path, str) else path
    dataset = tf.data.TFRecordDataset(filenames)
    # Parse each serialized Example into its decoded feature dict.
    return dataset.map(_parse_tfrecord)
def main(argv=None):
    """Round-trip a random float16 vector through a TFRecord and report it.

    Args:
        argv: optional list of command-line arguments; when None, argparse
            reads them from the command line.

    Returns:
        True if at least one record was read back and every record has the
        same dtype and values as the written array, otherwise False.
    """
    parser = argparse.ArgumentParser(
        description='Write float16 data to a TFRecord as raw bytes and read it back.')
    parser.add_argument('--size', type=int, default=10,
                        help='number of float16 values to write (default: 10)')
    args = parser.parse_args(argv)
    if args.size < 0:
        parser.error('--size must be non-negative')

    # generate the data
    x_np = np.array(np.random.rand(args.size), dtype=np.float16)
    _write_tfrecord(x_np)
    ds = input_fn()

    results = []
    for i_batch, batch in enumerate(ds):
        print('DEBUG: i_batch = %i' % i_batch, flush=True)
        x = batch['x'].numpy()
        print('DEBUG: x = ', x, flush=True)
        print('DEBUG: allclose = ', np.allclose(x, x_np), flush=True)
        # The bytes are copied verbatim, so the round trip must be exact.
        # allclose alone tolerates small differences and ignores dtype.
        exact = x.dtype == x_np.dtype and np.array_equal(x, x_np)
        print('DEBUG: exact = ', exact, flush=True)
        results.append(exact)
    # Reading zero records means nothing was verified; treat it as failure.
    return bool(results) and all(results)
if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    main()
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment