Created
January 1, 2020 04:30
-
-
Save dschwertfeger/1593c71bbec6114441890e66d4b73b02 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf
def _bytestring_feature(list_of_bytestrings):
    """Wrap a list of byte strings in a ``tf.train.Feature``.

    The values are placed in a ``tf.train.BytesList``, which becomes the
    ``bytes_list`` field of the returned feature.
    """
    bytes_list = tf.train.BytesList(value=list_of_bytestrings)
    return tf.train.Feature(bytes_list=bytes_list)
def _int_feature(list_of_ints):  # int64
    """Wrap a list of integers in a ``tf.train.Feature``.

    The values are placed in a ``tf.train.Int64List``, which becomes the
    ``int64_list`` field of the returned feature.
    """
    int64_list = tf.train.Int64List(value=list_of_ints)
    return tf.train.Feature(int64_list=int64_list)
def _float_feature(list_of_floats):  # float32
    """Wrap a list of floats in a ``tf.train.Feature``.

    The values are placed in a ``tf.train.FloatList``, which becomes the
    ``float_list`` field of the returned feature.
    """
    float_list = tf.train.FloatList(value=list_of_floats)
    return tf.train.Feature(float_list=float_list)
def to_tfrecord(audio, label):
    """Encode one (audio, label) pair as a ``tf.train.Example``.

    An Example is a flexible message type holding key-value pairs, where
    each key maps to a Feature message. The Example built here contains
    two features:

    * ``'audio'``: a FloatList with the decoded audio data
      (``audio`` is a list of floats).
    * ``'label'``: an Int64List with the corresponding label's index
      (the single index is wrapped in a list).

    Returns the unserialized ``tf.train.Example``; the caller is
    responsible for serializing it.
    """
    features = tf.train.Features(
        feature={
            'audio': _float_feature(audio),
            'label': _int_feature([label]),
        }
    )
    return tf.train.Example(features=features)
if __name__ == "__main__":
    # Assume a dataset of [audio, label] pairs.
    # NOTE(review): load_dataset is not defined in the visible part of this
    # file -- presumably a placeholder for the reader's own loader; confirm.
    dataset = load_dataset()

    # The writer is used as a context manager so the output file is closed
    # once all records have been written.
    with tf.io.TFRecordWriter('train.tfrecord') as out:
        for audio, label in dataset:
            # Encode the pair as an Example, then serialize it to bytes
            serialized = to_tfrecord(audio, label).SerializeToString()
            # Append the serialized Example as one record in the file
            out.write(serialized)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment