Created
February 8, 2017 20:46
-
-
Save anilshanbhag/99d1db7f59e7a92a7d1e6429a2c54f95 to your computer and use it in GitHub Desktop.
Preloaded MNIST
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import argparse | |
import sys | |
import time | |
import tensorflow as tf | |
from tensorflow.examples.tutorials.mnist import input_data | |
from tensorflow.examples.tutorials.mnist import mnist | |
import numpy as np | |
import threading | |
from nvl_transformer import NVLTransformer, load_nvl_module | |
# Basic model parameters as external flags.
# Set to the argparse namespace returned by parser.parse_known_args() in the
# __main__ block; run_training() reads its attributes as a module global.
FLAGS = None
def run_training():
    """Train a linear "digit 0 vs. rest" classifier on MNIST with logistic loss.

    The whole MNIST training split is copied once into two non-trainable
    graph variables and streamed from there through a
    ``slice_input_producer`` / ``batch`` queue pipeline. Labels are mapped to
    +1.0 for the digit 0 and -1.0 for every other digit, and the weight
    vector ``W`` (no bias term) is fitted by plain gradient descent on
    ``mean(log(1 + exp(-y * <x, W>)))``.

    If ``FLAGS.nvl`` is set, the NVL module is loaded, the transformer is
    applied to ``tf``, and the gradient tensor consumed by the
    ``ApplyGradientDescent`` update is passed to ``convert_to_nvl``.

    Reads the module-level ``FLAGS`` (``train_dir``, ``log_dir``,
    ``num_epochs``, ``batch_size``, ``nvl``). Prints the learned weights and
    the wall-clock time of the training loop; returns None.

    Raises:
      Any exception raised by a training worker's ``sess.run`` (other than
      the queue's clean end-of-input ``OutOfRangeError``) is re-raised when
      the coordinator is joined.
    """
    # Number of Python threads issuing sess.run(train_op) concurrently.
    num_threads = 1

    data_sets = input_data.read_data_sets(FLAGS.train_dir)
    train_images = data_sets.train.images
    train_labels = data_sets.train.labels
    # Assumes flattened images (read_data_sets' default), i.e. a 2-D array.
    num_examples, num_features = train_images.shape

    # +1.0 for the digit 0, -1.0 for every other digit.
    labels_binary = np.where(train_labels == 0, 1.0, -1.0).astype(np.float32)

    with tf.Graph().as_default():
        with tf.name_scope('input'):
            # Placeholders used only to seed the variables below: the data
            # is fed once at initialisation rather than on every step.
            images_initializer = tf.placeholder(
                dtype=train_images.dtype, shape=train_images.shape)
            labels_initializer = tf.placeholder(
                dtype=tf.float32, shape=train_labels.shape)
            # collections=[] keeps these out of the global-variables
            # collection, so they are initialised explicitly (with a feed)
            # below instead of by global_variables_initializer().
            input_images = tf.Variable(
                images_initializer, trainable=False, collections=[])
            input_labels = tf.Variable(
                labels_initializer, trainable=False, collections=[])
            # num_epochs adds a local epoch-counter variable, hence the
            # local_variables_initializer() run below.
            image, label = tf.train.slice_input_producer(
                [input_images, input_labels], num_epochs=FLAGS.num_epochs)
            X, Y = tf.train.batch([image, label], batch_size=FLAGS.batch_size)

        # Weight vector: one entry per input feature (pixel).
        W = tf.Variable(tf.random_normal([num_features]), name="weights")

        # NOTE(review): the saver is never used, but constructing it adds
        # save/restore ops; kept so the graph seen by the NVL pass is as before.
        tf.train.Saver()

        # Logistic loss for +/-1 labels: mean(log(1 + exp(-y * <x, W>))).
        # NOTE(review): tf.nn.softplus(-Y_X_W) is the numerically stable
        # form, but it would change the gradient subgraph handed to NVL.
        X_W = tf.reduce_sum(tf.multiply(X, W), reduction_indices=1)
        Y_X_W = tf.multiply(Y, X_W)
        cost = tf.reduce_mean(
            tf.log(tf.add(tf.exp(tf.negative(Y_X_W)),
                          tf.constant(1, dtype=tf.float32))))

        # NOTE(review): FLAGS.learning_rate (default 0.01) is not consulted;
        # the step size stays hard-coded at 0.02 to preserve existing results.
        train_op = tf.train.GradientDescentOptimizer(0.02).minimize(cost)

        if FLAGS.nvl:
            # Load the NVL module into TensorFlow and apply the transformer.
            nvl_module = load_nvl_module(tf)
            transformer = NVLTransformer(nvl_module)
            transformer.transform(tf)
            # train_op is a NoOp whose control input is the
            # ApplyGradientDescent op; that op's inputs are
            # (var, alpha, delta), so inputs[2] is the gradient tensor.
            comp_tensor = train_op.control_inputs[0].inputs[2]
            print("Converting ", comp_tensor)
            # The returned tensor was never used afterwards; presumably the
            # conversion rewrites the graph in place -- confirm.
            transformer.convert_to_nvl(comp_tensor)

        init_op = tf.global_variables_initializer()

        with tf.Session() as sess:
            sess.run(init_op)
            sess.run(tf.local_variables_initializer())
            sess.run(input_images.initializer,
                     feed_dict={images_initializer: train_images})
            sess.run(input_labels.initializer,
                     feed_dict={labels_initializer: labels_binary})

            # tf.summary.FileWriter supersedes the deprecated
            # tf.train.SummaryWriter (removed in TF 1.0).
            summary_writer = tf.summary.FileWriter(FLAGS.log_dir, sess.graph)

            # Start input enqueue threads.
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            # tf.train.batch only emits full batches (the final partial batch
            # is discarded), so round down to avoid running past the input.
            steps_per_epoch = num_examples // FLAGS.batch_size
            steps_per_thread = (FLAGS.num_epochs // num_threads) * steps_per_epoch

            def thread_body():
                # Report failures to the coordinator rather than letting them
                # vanish with the thread; coord.join() below re-raises them.
                with coord.stop_on_exception():
                    for _ in range(steps_per_thread):
                        if coord.should_stop():
                            break
                        sess.run(train_op)

            try:
                start = time.time()
                workers = [threading.Thread(target=thread_body)
                           for _ in range(num_threads)]
                for worker in workers:
                    worker.start()
                for worker in workers:
                    worker.join()
                elapsed = time.time() - start

                print(sess.run(W))  # learned weight vector, one entry per pixel
                print("Time taken : ", elapsed, " seconds")
            finally:
                summary_writer.close()
                coord.request_stop()
                # Wait for the queue-runner threads; re-raises any exception
                # that was reported to the coordinator.
                coord.join(threads)
def main(_):
    """Entry point handed to ``tf.app.run``.

    The positional argument (the leftover argv list that ``tf.app.run``
    forwards) is ignored; training options come from the module-level
    ``FLAGS`` instead.
    """
    del _  # unused: options are read from FLAGS
    run_training()
def str_to_bool(value):
    """Parse a command-line boolean such as ``true``/``false``/``1``/``0``.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is True for
    every non-empty string -- including ``"False"`` -- so such a flag can
    never be switched off from the command line. This parser fixes that.

    Args:
      value: the raw command-line string (a bool is passed through as-is).

    Returns:
      True or False.

    Raises:
      argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('1', 'true', 't', 'yes', 'y'):
        return True
    # The empty string stays False, matching the old bool('') behavior.
    if normalized in ('', '0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError(
        'Boolean value expected, got %r.' % (value,))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # NOTE(review): run_training() does not read learning_rate; its
    # optimizer step size is hard-coded.
    parser.add_argument(
        '--learning_rate',
        type=float,
        default=0.01,
        help='Initial learning rate.'
    )
    parser.add_argument(
        '--num_epochs',
        type=int,
        default=4,
        help='Number of epochs to run trainer.'
    )
    parser.add_argument(
        '--batch_size',
        type=int,
        default=250,
        help='Batch size. Must divide evenly into the dataset sizes.'
    )
    parser.add_argument(
        '--train_dir',
        type=str,
        default='/tmp/data',
        help='Directory to put the training data.'
    )
    parser.add_argument(
        '--log_dir',
        type=str,
        default='preloaded/',
        help='Directory to put the summary logs.'
    )
    # NOTE(review): run_training() does not read fake_data.
    parser.add_argument(
        '--fake_data',
        default=False,
        help='If true, uses fake data for unit testing.',
        action='store_true'
    )
    # Accepts a bare `--nvl` (-> True) as well as `--nvl true/false/1/0`;
    # type=bool previously turned `--nvl False` into True.
    parser.add_argument(
        '--nvl',
        type=str_to_bool,
        nargs='?',
        const=True,
        default=False,
        help='If true, uses NVL transformer.',
    )
    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment