Created November 2, 2016 08:27
Whoo. Ineffective sorting algorithm!
import tensorflow as tf
import numpy as np

vocabulary_size = 250 # number of distinct values the model can handle
batch_size = 256
length_vector = 10 # length of the sequences we try to sort
training_epochs = 10000
learning_rate = 0.05
momentum = 0.9
embedding_dim = 50
memory_dim = 1024

sess = tf.InteractiveSession()
seq_length = length_vector
vocab_size = vocabulary_size + 1
enc_inp = [tf.placeholder(tf.int32, shape=(None,), name="inp%i" % t) for t in range(seq_length)]
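# The seq2seq API used here is time-major: enc_inp is a list of seq_length
# placeholders, one per time step, and each placeholder holds one batch of
# symbols. E.g. with seq_length=3, feeding the two sequences [5, 1, 9] and
# [2, 7, 4] means enc_inp[0] gets [5, 2], enc_inp[1] gets [1, 7], and
# enc_inp[2] gets [9, 4].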
# Names used as keys in the decoders dictionary
names = ["original", "sorted", "reversed", "sorted_reversed"]
# Create placeholders for the labels of each function we want our decoders to learn
labels_original = [tf.placeholder(tf.int32, shape=(None,), name="labels_original%i" % t) for t in range(seq_length)]
labels_sorted = [tf.placeholder(tf.int32, shape=(None,), name="labels_sorted%i" % t) for t in range(seq_length)]
labels_reversed = [tf.placeholder(tf.int32, shape=(None,), name="labels_reversed%i" % t) for t in range(seq_length)]
labels_sorted_reversed = [tf.placeholder(tf.int32, shape=(None,), name="labels_sorted_reversed%i" % t) for t in range(seq_length)]
# Use the same weights for each of the labels
weights = [tf.ones_like(labels_t, dtype=tf.float32) for labels_t in labels_original]
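# All-ones weights mean no time step is masked or down-weighted in the loss;
# since every sequence here has the same fixed length, there is no padding
# that would need a zero weight.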
# Decoder input: let's not make it easy for our algorithm.
# We feed the decoder its own previous output instead of the ground truth,
# even during the training phase (see feed_previous=True below).
# Note that training time can be reduced by not doing this.
dec_inp = [tf.zeros_like(enc_inp[0], dtype=np.int32)] * length_vector
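# With feed_previous=True (set in the seq2seq call below), only the first
# decoder input is actually used, as a zero-valued "GO" symbol; every later
# decoder input is replaced by the embedding of the decoder's own previous
# output, during training as well as at test time.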
# Start with a fresh memory every time
prev_mem = tf.zeros((batch_size, memory_dim))
# There is no particular reason I chose a GRU over other cell types
cell = tf.nn.rnn_cell.GRUCell(memory_dim)
# Set the decoder inputs for each of the decoders
decoder_inputs_dictionary = {}
for name in names:
    decoder_inputs_dictionary[name] = dec_inp
# Set the number of symbols for each decoder
decoder_symbols_dictionary = {}
for name in names:
    decoder_symbols_dictionary[name] = vocab_size
# This is where the magic happens!
# With one encoder we get multiple (in this case 4) decoders
(outputs_dict, state_dict) = tf.nn.seq2seq.one2many_rnn_seq2seq(
    enc_inp, decoder_inputs_dictionary, cell, vocab_size,
    decoder_symbols_dictionary, embedding_size=embedding_dim, feed_previous=True)
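# outputs_dict maps each name in `names` to that decoder's outputs: a list
# of seq_length logit tensors, each of which should have shape
# (batch_size, vocab_size). state_dict holds each decoder's final state.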
# For printing and the loss function I take the outputs out of their dictionary
original_outputs = outputs_dict["original"]
sorted_outputs = outputs_dict["sorted"]
reversed_outputs = outputs_dict["reversed"]
sorted_reversed_outputs = outputs_dict["sorted_reversed"]
# Make the total loss the sum of the four decoders' sequence losses
loss = tf.nn.seq2seq.sequence_loss(original_outputs, labels_original, weights, vocab_size) + \
       tf.nn.seq2seq.sequence_loss(sorted_outputs, labels_sorted, weights, vocab_size) + \
       tf.nn.seq2seq.sequence_loss(reversed_outputs, labels_reversed, weights, vocab_size) + \
       tf.nn.seq2seq.sequence_loss(sorted_reversed_outputs, labels_sorted_reversed, weights, vocab_size)
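# In this (pre-1.0) API, sequence_loss is the weighted cross-entropy of a
# sequence of logits, averaged over time steps and over the batch; with the
# all-ones weights above every step counts equally. Summing the four losses
# trains the shared encoder on all four tasks simultaneously.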
# Optimize the loss with a momentum optimizer
optimizer = tf.train.MomentumOptimizer(learning_rate, momentum)
train_op = optimizer.minimize(loss)
# Init everything
sess.run(tf.initialize_all_variables())
def train_batch(batch_size):
    # Create random arrays as input
    X = [np.random.choice(vocab_size - 1, size=(seq_length,), replace=False)
         for _ in range(batch_size)]
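    # replace=False draws seq_length distinct values from 0..vocab_size-2,
    # so each training sequence contains no duplicate numbers.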
    # output: same as input
    Yoriginal = [a.copy() for a in X]
    # output: reversed input
    Yreversed = []
    for a in Yoriginal:
        Yreversed.append(a[::-1])
    # output: sorted
    Ysorted = [a.copy() for a in X]
    for a in Ysorted:
        a.sort()
    # output: reversed sorted
    Ysorted_reversed = []
    for a in Ysorted:
        Ysorted_reversed.append(a[::-1])
    # Print the inputs and expected outputs
    # for debugging and because it is interesting
    print("Input: " + str(X[0]))
    print("Expected same output: " + str(Yoriginal[0]))
    print("Expected reversed output: " + str(Yreversed[0]))
    print("Expected sorted: " + str(Ysorted[0]))
    print("Expected reversed sorted: " + str(Ysorted_reversed[0]))
    # Transpose the arrays to get shape seq_len * batch_size
    X = np.array(X).T
    Yoriginal = np.array(Yoriginal).T
    Yreversed = np.array(Yreversed).T
    Ysorted = np.array(Ysorted).T
    Ysorted_reversed = np.array(Ysorted_reversed).T
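    # After the transpose, row t holds the whole batch at time step t,
    # matching the time-major placeholders above: e.g. two sequences
    # [5, 1, 9] and [2, 7, 4] become rows [5, 2], [1, 7], [9, 4].
    assert X.shape == (seq_length, batch_size)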
    # Put the sequences in the feed dictionary
    feed_dict = {enc_inp[t]: X[t] for t in range(seq_length)}
    feed_dict.update({labels_original[t]: Yoriginal[t] for t in range(seq_length)})
    feed_dict.update({labels_reversed[t]: Yreversed[t] for t in range(seq_length)})
    feed_dict.update({labels_sorted[t]: Ysorted[t] for t in range(seq_length)})
    feed_dict.update({labels_sorted_reversed[t]: Ysorted_reversed[t] for t in range(seq_length)})
    # Run the session and hopefully get the right lists as output
    _, loss_t, out1, out2, out3, out4 = sess.run(
        [train_op, loss, original_outputs, sorted_outputs,
         reversed_outputs, sorted_reversed_outputs], feed_dict)
    to_print = [out1, out2, out3, out4]
    # Print the output for fun and for debugging
    for index, o in enumerate(to_print):
        first_sample = []
        for x in o:
            first_sample.append(np.argmax(x[0]))
        print(names[index] + " result: " + str(first_sample))
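    # Each entry of to_print is a list of seq_length logit arrays; each
    # array should have shape (batch_size, vocab_size), so np.argmax(x[0])
    # decodes the predicted symbol at that time step for the first sample.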
    return loss_t
# Run the training for training_epochs epochs
for t in range(training_epochs):
    loss_t = train_batch(batch_size)
    print("Epoch " + str(t) + ". Loss: " + str(loss_t))