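# NOTE: This gist shows a single method of a larger model class, written against
# TensorFlow 1.x (tf.contrib); the imports below are assumed from context.
import numpy as np
import tensorflow as tf

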
def build_graph(self):
    label = tf.one_hot(self.batch, 10*self._config.num_digits)
    self.label = tf.argmax(label, -1)
    num_digits = self._config.num_digits
    num_binary_messages = self._config.num_binary_messages

    # Speaker (A1): embeds the input digits and emits a sequence of discrete messages.
with tf.variable_scope("A1"): | |
weights = tf.get_variable("embeddings", shape=(10*num_digits, self._config.embedding_size), | |
dtype=tf.float32, initializer=tf.orthogonal_initializer) | |
inputs = tf.nn.embedding_lookup(weights, self.batch) | |
inputs = tf.reshape(inputs, [self._batch_size, num_digits * self._config.embedding_size]) | |
        hidden_size = getattr(self._config, 'a1_hidden_size', None) or self._config.hidden_size
        hidden = tf.contrib.layers.fully_connected(
            inputs, hidden_size, scope="hidden", activation_fn=tf.nn.tanh,
            # weights_initializer=tf.orthogonal_initializer,
        )
        # If we are using an EOS penalty, then we predict a third attribute per message:
        # whether to zero out the remaining messages. Once that is predicted, every message
        # thereafter will be masked.
        output_size = 3 if self._config.eos_penalty else 2
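        # The per-message logits are produced jointly as a single
        # [batch_size, num_binary_messages * output_size] tensor and reshaped per message below.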
        logits = tf.contrib.layers.fully_connected(
            hidden, output_size*num_binary_messages,
            activation_fn=None, scope="logits",
            ### This is commented out because with it included the last part of the weight matrix becomes 0.
            ### I can't explain that and it bears further inquiry.
            # weights_initializer=tf.orthogonal_initializer,
        )
        if self._config.eos_penalty:
            # We additionally predict a fourth attribute, a scalar meant to push up the EOS
            # probability. (It is not used further in this snippet.)
            scalar_eos = tf.contrib.layers.fully_connected(
                logits, 1, activation_fn=tf.sigmoid, scope="scalar_eos",
                weights_initializer=tf.orthogonal_initializer)
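
    # Relax each per-message categorical with a Gumbel-Softmax (RelaxedOneHotCategorical)
    # so the sampled messages stay differentiable; tau is a learned temperature.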
    a1_logits = tf.reshape(logits, [self._batch_size, num_binary_messages, output_size])
    tau = tf.get_variable("QTemperature", initializer=tf.constant_initializer(1.0), trainable=True, shape=())
    q_y = tf.contrib.distributions.RelaxedOneHotCategorical(tau, logits=a1_logits)
    y = q_y.sample()
    y_hard = tf.cast(tf.one_hot(tf.argmax(y, -1), output_size), y.dtype)
    if self._config.eos_penalty:
        self.pre_mask_argmax_messages = tf.argmax(y, -1)
        # Append an explicit EOS one-hot onto the back of each sequence so that the argmax
        # below always finds a match and doesn't return an incorrect index when no EOS was
        # predicted.
        one_hot = np.array([0]*(output_size - 1) + [1]).astype(np.float32)
        concat_one_hot = tf.expand_dims(tf.expand_dims(tf.convert_to_tensor(one_hot), 0), 0)
        concat_one_hot = tf.tile(concat_one_hot, tf.stack([tf.shape(y_hard)[0], 1, 1]))
        concat_y_hard = tf.concat([y_hard, concat_one_hot], 1)
        # Find the first message that predicts the EOS class (index output_size - 1) and zero
        # out everything from there.
        first_zeros = tf.argmax(tf.to_int32(tf.reduce_all(tf.equal(concat_y_hard, one_hot), 2)), 1)
        self.mask = mask = tf.to_float(tf.sequence_mask(first_zeros, num_binary_messages + 1))
        argmax_messages = ((tf.argmax(concat_y_hard, 2) + 1) * tf.to_int64(mask))[:, :num_binary_messages]
        y_hard = tf.one_hot(argmax_messages, output_size)
        self.argmax_messages = argmax_messages
    else:
        self.argmax_messages = tf.argmax(y, -1)
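
    # Straight-through estimator: the forward pass uses the hard one-hot messages while
    # gradients flow back through the relaxed sample y.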
    self.messages = messages = tf.stop_gradient(y_hard - y) + y
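
    # Receiver (A2): decodes the messages and predicts the original digits.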
with tf.variable_scope("A2"): | |
print(messages) | |
projection = tf.contrib.layers.fully_connected(messages, self._config.embedding_size, | |
activation_fn=None, scope="embeddings") | |
messages = tf.reshape(projection, [self._batch_size, self._config.embedding_size * num_binary_messages]) | |
        hidden_size = getattr(self._config, 'a2_hidden_size', None) or self._config.hidden_size
        hidden = tf.contrib.layers.fully_connected(
            messages, hidden_size, scope="hidden", activation_fn=tf.nn.tanh,
            # weights_initializer=tf.orthogonal_initializer,
        )
        a2_logits = tf.contrib.layers.fully_connected(
            hidden, 10*num_digits**2, activation_fn=None, scope="a2_logits",
            # weights_initializer=tf.orthogonal_initializer,
        )
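        # One (10 * num_digits)-way classification per digit position.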
        a2_logits = tf.reshape(a2_logits, [self._batch_size, num_digits, 10*num_digits])

    # TODO: Add the entropy penalty.
    # entropy = -1 * tf.reduce_sum(softmax * tf.log(softmax), 2)
    # entropies = [tf.reduce_sum(entropy, 1)]
    # L2 norms of each agent's weights, excluding the Gumbel-Softmax temperature variable
    # (matched case-insensitively so "QTemperature" is caught).
    a1_vars = [v for v in tf.trainable_variables() if 'A1' in v.name and 'temperature' not in v.name.lower()]
    a1_l2_norm = tf.add_n([tf.nn.l2_loss(v) for v in a1_vars])
    a2_vars = [v for v in tf.trainable_variables() if 'A2' in v.name and 'temperature' not in v.name.lower()]
    a2_l2_norm = tf.add_n([tf.nn.l2_loss(v) for v in a2_vars])
    weight_summaries = tf.summary.merge([
        tf.summary.scalar('a1_l2_norm', a1_l2_norm),
        tf.summary.scalar('a2_l2_norm', a2_l2_norm),
    ])
    # Per-digit cross-entropy between the receiver's logits and the one-hot labels.
    losses = [
        tf.nn.softmax_cross_entropy_with_logits(logits=a2_logits[:, i], labels=label[:, i])
        for i in range(num_digits)
    ]
    losses = [tf.reduce_mean(loss) for loss in losses]
    total_loss = tf.add_n(losses)
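
    # A possible continuation (an assumption, not shown in this gist): expose the loss and
    # summaries on self and build a training op, e.g.
    #   self.total_loss = total_loss
    #   self.weight_summaries = weight_summaries
    #   self.train_op = tf.train.AdamOptimizer(self._config.learning_rate).minimize(total_loss)
    # where `learning_rate` is a hypothetical config attribute.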