Created July 30, 2016 14:03
Save siddMahen/e45174fbf60a4df174af4a5d95a293f1 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
import sys | |
import os | |
def read_and_decode(filename_queue):
    """Read the next line from the files named by *filename_queue*.

    Despite its name, no decoding happens here: the function builds a
    ``tf.TextLineReader``, reads from the queue, and hands back the value
    half of the reader's ``(key, value)`` result unchanged. The key is
    discarded.
    """
    read_result = tf.TextLineReader().read(filename_queue)
    line_value = read_result[1]
    return line_value
def inputs(filenames, batch_size, num_epochs):
    """Build a queue pipeline that yields shuffled batches of text lines.

    Args:
        filenames: paths of the text files to read; the file order is
            shuffled by ``tf.train.string_input_producer``.
        batch_size: number of lines per batch.
        num_epochs: forwarded to ``string_input_producer`` as the number of
            passes over *filenames*.

    Returns:
        Whatever ``tf.train.shuffle_batch`` returns for the single line
        tensor. ``allow_smaller_final_batch=True`` permits the last batch
        to hold fewer than *batch_size* lines.
    """
    # Shuffle floor: minimum lines left queued after a dequeue. This is kept
    # deliberately small for this demo.
    min_after_dequeue = 10
    # Room for the shuffle floor plus three full batches.
    capacity = 3 * batch_size + min_after_dequeue

    with tf.name_scope('input'):
        file_queue = tf.train.string_input_producer(
            filenames, num_epochs=num_epochs, shuffle=True)
        single_line = read_and_decode(file_queue)
        return tf.train.shuffle_batch(
            [single_line],
            batch_size=batch_size,
            capacity=capacity,
            min_after_dequeue=min_after_dequeue,
            allow_smaller_final_batch=True,
        )
def train(run_name, filenames): | |
with tf.Graph().as_default(): | |
lines = inputs(filenames, batch_size=5, num_epochs=1) | |
v = tf.Variable(1.0) | |
init_op = tf.initialize_all_variables() | |
init_again = tf.initialize_local_variables() | |
sess = tf.Session() | |
saver = tf.train.Saver() | |
prev_step = 0 | |
ckpt = tf.train.get_checkpoint_state('.') | |
if ckpt and ckpt.model_checkpoint_path: | |
# Check if the run name matches ours | |
ending = ckpt.model_checkpoint_path.split('/')[-1].split('-') | |
alt_name = ending[1] | |
if alt_name == run_name: | |
prev_step = int(ending[2]) | |
saver.restore(sess, ckpt.model_checkpoint_path) | |
else: | |
sess.run(init_op) | |
else: | |
sess.run(init_op) | |
sess.run(init_again) | |
coord = tf.train.Coordinator() | |
ckpt_path = os.path.join('.', "model-" + run_name) | |
threads = tf.train.start_queue_runners(sess=sess, coord=coord) | |
try: | |
step = prev_step | |
while not coord.should_stop(): | |
l = sess.run(lines) | |
for line in l: | |
print(line) | |
save_path = saver.save(sess, ckpt_path, global_step=step) | |
print('Model saved to %s' % save_path) | |
step += 1 | |
except tf.errors.OutOfRangeError: | |
print("Done training!") | |
save_path = saver.save(sess, ckpt_path, global_step=step) | |
print('Model saved to %s' % save_path) | |
finally: | |
coord.request_stop() | |
coord.join(threads) | |
sess.close() | |
if __name__ == '__main__':
    # Usage: python train.py model_name input1.txt input2.txt
    if len(sys.argv) < 3:
        # A usage message instead of a bare IndexError traceback. sys.exit
        # with a string prints it to stderr and exits with status 1.
        sys.exit('Usage: python %s model_name input1.txt [input2.txt ...]'
                 % sys.argv[0])
    run_name = sys.argv[1]
    filenames = sys.argv[2:]
    train(run_name, filenames)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.