Last active
June 23, 2019 23:27
-
-
Save elgehelge/6c7cd65bd08b71b898fb54eb13ed6f98 to your computer and use it in GitHub Desktop.
Minimal example - tensorflow with external generator
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
# Data source | |
def data_generator(start, end): | |
for x, y in zip(range(start, end), range(start, end)): | |
print(x, y) | |
yield x, y | |
# TF dataset | |
def input_fn(data_getter): | |
dataset = (tf.data.Dataset.from_generator( | |
generator=lambda: data_getter, | |
output_types=(tf.float32), | |
) | |
.repeat() | |
.make_one_shot_iterator().get_next() | |
) | |
return dataset[0], dataset[1] | |
def model_fn(features, labels, mode): | |
var = tf.Variable(0, dtype=tf.float32) | |
prediction = features + var | |
loss = prediction - labels | |
loss.set_shape([]) | |
return tf.estimator.EstimatorSpec( | |
mode=mode, | |
predictions=prediction, | |
loss=loss, | |
train_op=tf.contrib.layers.optimize_loss( | |
loss=loss, | |
global_step=tf.train.get_global_step(), | |
optimizer=tf.train.AdamOptimizer, | |
learning_rate=0.01, | |
), | |
) | |
def run(): | |
tf.logging.set_verbosity(tf.logging.DEBUG) | |
# NB! External data source as generator (this is what we should avoid!) | |
train_data_gen = data_generator(start=0, end=5) | |
eval_data_gen = data_generator(start=100, end=105) | |
estimator = tf.estimator.Estimator(model_fn=model_fn) | |
train_spec = tf.estimator.TrainSpec( | |
input_fn=lambda: input_fn(train_data_gen)) | |
eval_spec = tf.estimator.EvalSpec( | |
input_fn=lambda: input_fn(eval_data_gen), start_delay_secs=0) | |
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec) | |
if __name__ == '__main__': | |
run() |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
This looks like a decent minimal example. Could you elaborate on your comment below. Thank you.
# NB! External data source as generator (this is what we should avoid!)