Created
March 28, 2017 01:10
-
-
Save DjangoPeng/07d8bfb8f0c412318e4fb6473ec2f56b to your computer and use it in GitHub Desktop.
[TensorFlow][issue #8687] OutOfRangeError reproduction code
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import argparse | |
import os.path | |
import sys | |
import time | |
import tensorflow as tf | |
from tensorflow.examples.tutorials.mnist import mnist | |
# Parsed command-line flags. Stays None until the script entry point
# assigns the argparse result to it.
FLAGS = None

# Names of the TFRecord files looked up inside FLAGS.train_dir; these
# match the names written by convert_to_records.
TRAIN_FILE = 'train.tfrecords'
VALIDATION_FILE = 'validation.tfrecords'
def read_and_decode(filename_queue):
    """Read one serialized example from the queue and decode it.

    Args:
      filename_queue: queue of TFRecord file names, handed to
        `tf.TFRecordReader.read`.

    Returns:
      An `(image, label)` pair of tensors: the image as a float32 vector
      of `mnist.IMAGE_PIXELS` values rescaled to [-0.5, 0.5], and the
      label cast to int32.
    """
    record_reader = tf.TFRecordReader()
    _, serialized_example = record_reader.read(filename_queue)

    # No defaults are given, so both keys are required in every record.
    feature_spec = {
        'image_raw': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([], tf.int64),
    }
    parsed = tf.parse_single_example(serialized_example,
                                     features=feature_spec)

    # The raw bytes decode to a flat uint8 vector; pin its static length.
    raw_pixels = tf.decode_raw(parsed['image_raw'], tf.uint8)
    raw_pixels.set_shape([mnist.IMAGE_PIXELS])

    # Rescale from [0, 255] bytes to [-0.5, 0.5] floats.
    image = tf.cast(raw_pixels, tf.float32) * (1. / 255) - 0.5

    # The label is parsed as an int64 scalar; narrow it to int32.
    label = tf.cast(parsed['label'], tf.int32)
    return image, label
def inputs(train, batch_size, num_epochs):
    """Build the queue-based input pipeline and return shuffled batches.

    Args:
      train: truthy to read TRAIN_FILE, falsy to read VALIDATION_FILE
        (both looked up under FLAGS.train_dir).
      batch_size: number of examples per batch.
      num_epochs: epoch limit passed to `tf.train.string_input_producer`.
        Any falsy value (0 or None) is passed on as None, which per the
        TensorFlow documentation means the file list is cycled without
        limit.

    Returns:
      An `(images, sparse_labels)` tuple as produced by
      `tf.train.shuffle_batch` from the decoded image and label.

    Note that the queue runners created here must be started (see
    `tf.train.start_queue_runners`) before the returned tensors are
    evaluated.
    """
    epoch_limit = num_epochs if num_epochs else None
    record_name = TRAIN_FILE if train else VALIDATION_FILE
    record_path = os.path.join(FLAGS.train_dir, record_name)

    with tf.name_scope('input'):
        filename_queue = tf.train.string_input_producer(
            [record_path], num_epochs=epoch_limit)

        # A single filename queue is shared even though batching below
        # reads with multiple threads.
        example, example_label = read_and_decode(filename_queue)

        # Keeping at least this many examples queued after each dequeue
        # ensures a minimum amount of shuffling.
        min_buffered = 1000
        images, sparse_labels = tf.train.shuffle_batch(
            [example, example_label],
            batch_size=batch_size,
            num_threads=2,
            capacity=min_buffered + 3 * batch_size,
            min_after_dequeue=min_buffered)

        return images, sparse_labels
def run_training():
    """Train MNIST for a number of steps.

    Builds the input pipeline, model, loss and training op in a fresh
    graph, then runs up to FLAGS.max_train_steps training steps. The loop
    ends early if the input queue raises `tf.errors.OutOfRangeError`
    (the epoch limit was reached) or if the coordinator asks to stop.
    The queue-runner threads and the session are always shut down on
    exit.

    Reads FLAGS.batch_size, FLAGS.num_epochs, FLAGS.hidden1,
    FLAGS.hidden2, FLAGS.learning_rate and FLAGS.max_train_steps.
    """
    # Tell TensorFlow that the model will be built into the default Graph.
    with tf.Graph().as_default():
        # Input images and labels.
        images, labels = inputs(train=True, batch_size=FLAGS.batch_size,
                                num_epochs=FLAGS.num_epochs)

        # Build a Graph that computes predictions from the inference model.
        logits = mnist.inference(images, FLAGS.hidden1, FLAGS.hidden2)

        # Add to the Graph the loss calculation.
        loss = mnist.loss(logits, labels)

        # Add to the Graph operations that train the model.
        train_op = mnist.training(loss, FLAGS.learning_rate)

        # Local variables must be initialized too: the epoch counter of
        # the input producer is one of them.
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())

        # Create a session for running operations in the Graph.
        sess = tf.Session()

        # Initialize the variables (the trained variables and the
        # epoch counter).
        sess.run(init_op)

        # Start the threads that fill the input queues.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord, sess=sess)

        try:
            # `range` rather than `xrange`: the latter does not exist on
            # Python 3, which this file otherwise supports via its
            # __future__ imports.
            for step in range(FLAGS.max_train_steps):
                start_time = time.time()
                _, loss_value = sess.run([train_op, loss])
                duration = time.time() - start_time

                # Print an overview fairly often.
                if step % 100 == 0:
                    print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value,
                                                               duration))
                if coord.should_stop():
                    break
        except tf.errors.OutOfRangeError:
            # Raised by sess.run once the input queue is exhausted.
            print('Done training for %d epochs, %d steps.' % (FLAGS.num_epochs, step))
        finally:
            # Ask the queue-runner threads to stop, wait for them, then
            # release the session.
            coord.request_stop()
            coord.join(threads)
            sess.close()
def main(_):
    """Entry point handed to `tf.app.run`; its argument is ignored."""
    run_training()
if __name__ == '__main__':
    # One (flag, type, default, help) row per command-line option, listed
    # in the order the options are registered with the parser.
    flag_specs = (
        ('--max_train_steps', int, 500, 'Max train steps.'),
        ('--learning_rate', float, 0.01, 'Initial learning rate.'),
        ('--num_epochs', int, 2, 'Number of epochs to run trainer.'),
        ('--hidden1', int, 128, 'Number of units in hidden layer 1.'),
        ('--hidden2', int, 32, 'Number of units in hidden layer 2.'),
        ('--batch_size', int, 100, 'Batch size.'),
        # FIXME (original author's marker): this default is a hard-coded
        # HDFS location and has to be adjusted to the actual data directory.
        ('--train_dir', str, 'hdfs://10.0.0.1/tfrecords/mnist-data',
         'Directory with the training data.'),
    )

    parser = argparse.ArgumentParser()
    for flag_name, flag_type, flag_default, flag_help in flag_specs:
        parser.add_argument(
            flag_name,
            type=flag_type,
            default=flag_default,
            help=flag_help
        )

    # Options not recognised here are forwarded to tf.app.run untouched.
    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment