Skip to content

Instantly share code, notes, and snippets.

@dpressel
Last active June 8, 2020 17:18
Show Gist options
  • Save dpressel/38cf85b4d3ed7991b4750202294f5220 to your computer and use it in GitHub Desktop.
Save dpressel/38cf85b4d3ed7991b4750202294f5220 to your computer and use it in GitHub Desktop.
Highway Connections in TensorFlow
# A stack of residual (skip) connections: each layer adds a ReLU projection back onto its input.
def skip_conns(inputs, wsz_all, n):
    """Stack ``n`` residual layers of the form ``x + relu(x @ W_p + b_p)``.

    Each layer lives in its own variable scope ``skip-<k>`` and creates:
      * ``W_p``: a square weight of shape ``[wsz_all, wsz_all]``. No explicit
        initializer is passed, so the enclosing scope's default applies.
      * ``B_p``: a bias of shape ``[1, wsz_all]``, initialised to 0.0.

    Args:
        inputs: tensor fed to the first layer. It is presumably
            ``[batch, wsz_all]``, because it is matmul'd with a
            ``[wsz_all, wsz_all]`` weight. TODO confirm against callers.
        wsz_all: feature width, used as both input and output size of every layer.
        n: number of residual layers. When ``n <= 0`` the loop does not run
            and ``inputs`` is returned as-is.

    Returns:
        The output tensor of the last residual layer.
    """
    hidden = inputs
    for layer in range(n):
        with tf.variable_scope("skip-%d" % layer):
            weight = tf.get_variable("W_p", [wsz_all, wsz_all])
            bias = tf.get_variable("B_p", [1, wsz_all], initializer=tf.constant_initializer(0.0))
            pre_activation = tf.matmul(hidden, weight) + bias
            # Residual connection: keep the identity path and add the ReLU update.
            hidden = hidden + tf.nn.relu(pre_activation, "relu")
    return hidden
# A stack of highway connections: a learned sigmoid gate mixes a ReLU projection with the unchanged input.
def highway_conns(inputs, wsz_all, n, transform_bias=-2.0):
    """Stack ``n`` highway layers on top of ``inputs``.

    Each layer, in variable scope ``highway-<k>``, computes::

        H = relu(x @ W_p + B_p)          # candidate projection
        T = sigmoid(x @ W_t + B_t)       # transform gate
        y = T * H + (1 - T) * x          # carry the remainder of x through

    ``W_p`` and ``W_t`` are ``[wsz_all, wsz_all]`` and are created without an
    explicit initializer, so the enclosing scope's default applies.
    ``B_p`` and ``B_t`` are ``[1, wsz_all]``. ``B_p`` starts at 0.0 and
    ``B_t`` starts at ``transform_bias``.

    Args:
        inputs: tensor fed to the first layer. It is presumably
            ``[batch, wsz_all]``, because it is matmul'd with square
            ``[wsz_all, wsz_all]`` weights. TODO confirm against callers.
        wsz_all: feature width, used as both input and output size of every layer.
        n: number of highway layers. When ``n <= 0`` the loop does not run
            and ``inputs`` is returned as-is.
        transform_bias: initial value of the transform-gate bias ``B_t``.
            A negative value makes ``sigmoid`` start below 0.5, so the layers
            initially favour carrying ``x`` through. The default of -2.0
            preserves the previous hard-coded behaviour.

    Returns:
        The output tensor of the last highway layer.
    """
    for i in range(n):
        with tf.variable_scope("highway-%d" % i):
            W_p = tf.get_variable("W_p", [wsz_all, wsz_all])
            b_p = tf.get_variable("B_p", [1, wsz_all], initializer=tf.constant_initializer(0.0))
            proj = tf.nn.relu(tf.matmul(inputs, W_p) + b_p, "relu-proj")
            W_t = tf.get_variable("W_t", [wsz_all, wsz_all])
            # Gate bias is configurable; negative values bias the layer toward the carry path.
            b_t = tf.get_variable("B_t", [1, wsz_all], initializer=tf.constant_initializer(transform_bias))
            transform = tf.nn.sigmoid(tf.matmul(inputs, W_t) + b_t, "sigmoid-transform")
            # Convex mix of the projection (weight T) and the untouched input (weight 1 - T).
            inputs = tf.multiply(transform, proj) + tf.multiply(inputs, 1 - transform)
    return inputs
@dpressel
Copy link
Author

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment