Example using TensorFlow Estimator and the tf.data Dataset API on MNIST data.
"""Script to illustrate usage of tf.estimator.Estimator in TF v2.3.0"""
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
# Show debugging output
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.DEBUG)
# Set default flags for the output directories
FLAGS = tf.compat.v1.app.flags.FLAGS
tf.compat.v1.app.flags.DEFINE_string(
    name='model_dir', default='./mnist_training',
    help='Output directory for model and training stats.')
tf.compat.v1.app.flags.DEFINE_string(
    name='data_dir', default='./mnist_data',
    help='Directory to download the data to.')
class images_labels:
    def __init__(self, images, labels):
        self.images = images
        self.labels = labels

class Mnist_data:
    def __init__(self, X_train, y_train, X_test, y_test):
        self.train = images_labels(X_train, y_train)
        self.test = images_labels(X_test, y_test)
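
# These small containers mirror the images/labels layout of the old
# tensorflow.examples.tutorials.mnist dataset object. A namedtuple would
# work equally well (a sketch, not part of the original script):
#   import collections
#   images_labels = collections.namedtuple('images_labels', ['images', 'labels'])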
# Define and run experiment ###############################
def run_experiment(argv=None):
    """Run the training experiment."""
    print("run_experiment")
    # Define model parameters
    params = {
        'learning_rate': 0.002,
        'n_classes': 10,
        'train_steps': 50000,
        'min_eval_frequency': 100
    }
    # Set the run_config and the directory to save the model and stats
    run_config = tf.estimator.RunConfig()
    run_config = run_config.replace(model_dir=FLAGS.model_dir)
    # You can change a subset of the run_config properties, e.g.:
    #run_config = run_config.replace(
    #    save_checkpoints_steps=params['min_eval_frequency'])
    # Define the mnist classifier
    estimator = get_estimator(run_config, params)
    # Set up the data loaders (the old tutorials loader used FLAGS.data_dir):
    #mnist = mnist_data.read_data_sets(FLAGS.data_dir, one_hot=False)
    (X_train, l_train), (X_test, l_test) = tf.keras.datasets.mnist.load_data()
    # Flatten the images and normalize pixel values to [0, 1]
    X_train = X_train.reshape((X_train.shape[0], -1)).astype(np.float32) / 255.0
    X_test = X_test.reshape((X_test.shape[0], -1)).astype(np.float32) / 255.0
    # One-hot encode the integer labels
    y_train = np.zeros((l_train.shape[0], l_train.max() + 1), dtype=np.float32)
    y_train[np.arange(l_train.shape[0]), l_train] = 1
    y_test = np.zeros((l_test.shape[0], l_test.max() + 1), dtype=np.float32)
    y_test[np.arange(l_test.shape[0]), l_test] = 1
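    # An equivalent one-hot encoding using a Keras utility (a sketch;
    # to_categorical returns float32 arrays by default):
    #   y_train = tf.keras.utils.to_categorical(l_train)
    #   y_test = tf.keras.utils.to_categorical(l_test)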
    # Wrap the arrays in the mnist data class
    mnist = Mnist_data(X_train, y_train, X_test, y_test)
    train_input_fn, train_input_hook = get_train_inputs(
        batch_size=1, mnist_data=mnist)
    eval_input_fn, eval_input_hook = get_test_inputs(
        batch_size=1, mnist_data=mnist)
    # Define the training and evaluation specs
    train_spec = tf.estimator.TrainSpec(
        input_fn=train_input_fn,          # First-class function
        max_steps=params['train_steps'],  # Minibatch steps
        hooks=[train_input_hook],         # Hooks for training
    )
    eval_spec = tf.estimator.EvalSpec(
        input_fn=eval_input_fn,   # First-class function
        steps=None,               # Use the evaluation feeder until it is empty
        hooks=[eval_input_hook],  # Hooks for evaluation
    )
    tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
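
# The train_and_evaluate call above interleaves training and evaluation. A
# roughly equivalent manual sequence (a sketch) would be:
#   estimator.train(input_fn=train_input_fn, hooks=[train_input_hook],
#                   max_steps=params['train_steps'])
#   estimator.evaluate(input_fn=eval_input_fn, hooks=[eval_input_hook])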
# Define model ############################################
def get_estimator(run_config, params):
    """Return the model as a TensorFlow Estimator object.

    Args:
        run_config (RunConfig): Configuration for the Estimator run.
        params (dict): Hyperparameters.
    """
    return tf.estimator.Estimator(
        model_fn=model_fn,  # First-class function
        params=params,      # Hyperparameter dict
        config=run_config   # RunConfig
    )
def model_fn(features, labels, mode, params):
    """Model function used in the estimator.

    Args:
        features (Tensor): Input features to the model.
        labels (Tensor): Labels tensor for training and evaluation.
        mode (ModeKeys): Specifies if training, evaluation or prediction.
        params (dict): Hyperparameters.
    Returns:
        (EstimatorSpec): Model to be run by Estimator.
    """
    print("model_fn")
    # Single linear layer. The logits are left unnormalized: the loss below
    # applies the softmax itself, so applying tf.nn.softmax here would
    # double-apply it.
    W = tf.Variable(tf.zeros([784, params['n_classes']]))
    b = tf.Variable(tf.zeros([params['n_classes']]))
    logits = tf.matmul(features, W) + b
    is_training = mode == tf.estimator.ModeKeys.TRAIN
    # A deeper architecture could be plugged in here instead:
    #logits = architecture(features, is_training=is_training)
    predictions = tf.argmax(input=logits, axis=-1)
    # Loss, training and eval operations are not needed during inference.
    loss = None
    train_op = None
    eval_metric_ops = {}
    if mode != tf.estimator.ModeKeys.PREDICT:
        loss = tf.compat.v1.losses.softmax_cross_entropy(
            onehot_labels=labels,
            logits=logits
        )
        train_op = get_train_op_fn(loss, params)
        eval_metric_ops = get_eval_metric_ops(labels, predictions)
    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops
    )
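
# For serving, the EstimatorSpec above could also attach export_outputs,
# e.g. (a sketch; the output key name is an assumption):
#   export_outputs={'classes': tf.estimator.export.PredictOutput(
#       {'classes': predictions})}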
def get_train_op_fn(loss, params):
    """Get the training Op.

    Args:
        loss (Tensor): Scalar Tensor that represents the loss function.
        params (dict): Hyperparameters (needs to have `learning_rate`)
    Returns:
        Training Op
    """
    print("get_train_op_fn")
    optimizer = tf.compat.v1.train.AdamOptimizer(
        learning_rate=params["learning_rate"],
        name="Adam"
    )
    return optimizer.minimize(
        loss=loss, global_step=tf.compat.v1.train.get_global_step()
    )
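
# Any optimizer from tf.compat.v1.train can be swapped in here, e.g. plain
# gradient descent (a sketch):
#   optimizer = tf.compat.v1.train.GradientDescentOptimizer(
#       learning_rate=params["learning_rate"])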
def get_eval_metric_ops(labels, predictions):
    """Return a dict of the evaluation Ops.

    Args:
        labels (Tensor): One-hot labels tensor for training and evaluation.
        predictions (Tensor): Predicted class indices Tensor.
    Returns:
        Dict of metric results keyed by name.
    """
    print("get_eval_metric_ops")
    return {
        'Accuracy': tf.compat.v1.metrics.accuracy(
            labels=tf.argmax(labels, axis=-1),  # one-hot -> class indices
            predictions=predictions,
            name='accuracy')
    }
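
# Further entries from tf.compat.v1.metrics could be added alongside
# accuracy, e.g. (a sketch; the metric name is an assumption):
#   'MeanPrediction': tf.compat.v1.metrics.mean(predictions, name='mean_pred')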
#def architecture(inputs, is_training, scope='MnistConvNet'):
#    """Return the output operation following the network architecture.
#
#    Args:
#        inputs (Tensor): Input Tensor
#        num_classes (int): Number of classes
#        is_training (bool): True iff in training mode
#        scope (str): Name of the scope of the architecture
#
#    Returns:
#        Logits output Op for the network.
#    """
#    print("architecture")
#    with tf.compat.v1.variable_scope(scope):
#        with slim.arg_scope(
#                [slim.conv2d, slim.fully_connected],
#                weights_initializer=tf.compat.v1.keras.initializers.VarianceScaling(
#                    scale=1.0, mode="fan_avg", distribution="uniform")):
#            net = slim.conv2d(inputs, 20, [5, 5], padding='VALID',
#                              scope='conv1')
#            net = slim.max_pool2d(net, 2, stride=2, scope='pool2')
#            net = slim.conv2d(net, 40, [5, 5], padding='VALID',
#                              scope='conv3')
#            net = slim.max_pool2d(net, 2, stride=2, scope='pool4')
#            net = tf.reshape(net, [-1, 4 * 4 * 40])
#            net = slim.fully_connected(net, 256, scope='fn5')
#            net = slim.dropout(net, is_training=is_training,
#                               scope='dropout5')
#            net = slim.fully_connected(net, 256, scope='fn6')
#            net = slim.dropout(net, is_training=is_training,
#                               scope='dropout6')
#            net = slim.fully_connected(net, 10, scope='output',
#                                       activation_fn=None)
#            return net
# Define data loaders #####################################
class IteratorInitializerHook(tf.estimator.SessionRunHook):
    """Hook to initialise the data iterator after the Session is created."""

    def __init__(self):
        super(IteratorInitializerHook, self).__init__()
        self.iterator_initializer_func = None

    def after_create_session(self, session, coord):
        """Initialise the iterator after the session has been created."""
        self.iterator_initializer_func(session)
# Define the training inputs
def get_train_inputs(batch_size, mnist_data):
    """Return the input function to get the training data.

    Args:
        batch_size (int): Batch size of the training iterator that is
            returned by the input function.
        mnist_data (Object): Object holding the loaded mnist data.
    Returns:
        (Input function, IteratorInitializerHook):
            - Function that returns (features, labels) when called.
            - Hook to initialise the input iterator.
    """
    iterator_initializer_hook = IteratorInitializerHook()

    def train_inputs():
        """Returns the training set as Operations.

        Returns:
            (features, labels) Operations that iterate over the dataset
            on every evaluation
        """
        print("train_inputs")
        with tf.compat.v1.name_scope('Training_data'):
            # Get the Mnist data (reshape to [-1, 28, 28, 1] instead if
            # using the convolutional architecture above)
            images = mnist_data.train.images.reshape([-1, 784])
            labels = mnist_data.train.labels
            # Define placeholders
            images_placeholder = tf.compat.v1.placeholder(
                images.dtype, images.shape)
            labels_placeholder = tf.compat.v1.placeholder(
                labels.dtype, labels.shape)
            # Build the dataset iterator
            dataset = tf.data.Dataset.from_tensor_slices(
                (images_placeholder, labels_placeholder)
            )
            dataset = dataset.repeat(None)  # Infinite iterations
            dataset = dataset.shuffle(buffer_size=10000)
            dataset = dataset.batch(batch_size)
            iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
            next_example, next_label = iterator.get_next()
            # Set the run hook to initialize the iterator
            iterator_initializer_hook.iterator_initializer_func = \
                lambda sess: sess.run(
                    iterator.initializer,
                    feed_dict={images_placeholder: images,
                               labels_placeholder: labels})
            # Return batched (features, labels)
            return next_example, next_label
    # Return function and hook
    return train_inputs, iterator_initializer_hook
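
# With in-memory arrays this small, an input_fn can also just return the
# tf.data.Dataset directly and skip the placeholder/hook machinery, since
# Estimator accepts dataset-returning input functions (a sketch):
#   def train_inputs():
#       return tf.data.Dataset.from_tensor_slices(
#           (mnist_data.train.images.reshape([-1, 784]),
#            mnist_data.train.labels)).shuffle(10000).repeat().batch(batch_size)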
def get_test_inputs(batch_size, mnist_data):
    """Return the input function to get the test data.

    Args:
        batch_size (int): Batch size of the test iterator that is
            returned by the input function.
        mnist_data (Object): Object holding the loaded mnist data.
    Returns:
        (Input function, IteratorInitializerHook):
            - Function that returns (features, labels) when called.
            - Hook to initialise the input iterator.
    """
    iterator_initializer_hook = IteratorInitializerHook()

    def test_inputs():
        """Returns the test set as Operations.

        Returns:
            (features, labels) Operations that iterate over the dataset
            on every evaluation
        """
        with tf.compat.v1.name_scope('Test_data'):
            # Get the Mnist data (reshape to [-1, 28, 28, 1] instead if
            # using the convolutional architecture above)
            images = mnist_data.test.images.reshape([-1, 784])
            labels = mnist_data.test.labels
            # Define placeholders
            images_placeholder = tf.compat.v1.placeholder(
                images.dtype, images.shape)
            labels_placeholder = tf.compat.v1.placeholder(
                labels.dtype, labels.shape)
            # Build the dataset iterator (no repeat/shuffle: one ordered
            # pass, so evaluation sees each example exactly once)
            dataset = tf.data.Dataset.from_tensor_slices(
                (images_placeholder, labels_placeholder))
            dataset = dataset.batch(batch_size)
            iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
            next_example, next_label = iterator.get_next()
            # Set the run hook to initialize the iterator
            iterator_initializer_hook.iterator_initializer_func = \
                lambda sess: sess.run(
                    iterator.initializer,
                    feed_dict={images_placeholder: images,
                               labels_placeholder: labels})
            return next_example, next_label
    # Return function and hook
    return test_inputs, iterator_initializer_hook
# Run script ##############################################
if __name__ == "__main__":
tf.compat.v1.app.run(
main=run_experiment
)
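
# Example invocation (assuming the script is saved as mnist_estimator.py;
# both flags fall back to the defaults defined at the top):
#   python mnist_estimator.py --model_dir=./mnist_training --data_dir=./mnist_data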