An implementation of HyperLSTM (https://arxiv.org/abs/1609.09106). A short usage sketch follows the source listing.
# -*- coding: utf-8 -*-

# Copyright (C) 2017 by Akira TAMAMORI
# Copyright (C) 2016 by hardmaru

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import tensorflow as tf
import numpy as np

# Orthogonal Initializer from
# https://github.com/OlavHN/bnlstm
def orthogonal(shape):
    flat_shape = (shape[0], np.prod(shape[1:]))
    a = np.random.normal(0.0, 1.0, flat_shape)
    u, _, v = np.linalg.svd(a, full_matrices=False)
    q = u if u.shape == flat_shape else v
    return q.reshape(shape)


def lstm_ortho_initializer(scale=1.0):
    def _initializer(shape, dtype=tf.float32, partition_info=None):
        size_x = shape[0]
        size_h = shape[1] // 4  # assumes lstm (integer division over 4 gates).
        t = np.zeros(shape)
        t[:, :size_h] = orthogonal([size_x, size_h]) * scale
        t[:, size_h:size_h * 2] = orthogonal([size_x, size_h]) * scale
        t[:, size_h * 2:size_h * 3] = orthogonal([size_x, size_h]) * scale
        t[:, size_h * 3:] = orthogonal([size_x, size_h]) * scale
        return tf.constant(t, dtype)
    return _initializer

def layer_norm_all(h, batch_size, base, num_units, scope="layer_norm",
                   reuse=False, gamma_start=1.0, epsilon=1e-3, use_bias=True):
    # Layer Norm (faster version, but not using defun)
    #
    # Performs layer norm on multiple bases at once (i.e., the i, j, f, o
    # gate blocks of an LSTM).
    #
    # Reshapes h to perform layer norm in parallel.
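    # For each sample b and gate block k, with mean and variance taken over
    # the num_units axis:
    #   LN(h)[b, k, :] = gamma[k, :] * (h[b, k, :] - mean) * rsqrt(var + eps)
    #                    (+ beta[k, :] when use_bias is True)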
    h_reshape = tf.reshape(h, [batch_size, base, num_units])
    mean = tf.reduce_mean(h_reshape, [2], keep_dims=True)
    var = tf.reduce_mean(tf.square(h_reshape - mean), [2], keep_dims=True)
    epsilon = tf.constant(epsilon)
    rstd = tf.rsqrt(var + epsilon)
    h_reshape = (h_reshape - mean) * rstd
    # reshape back to original
    h = tf.reshape(h_reshape, [batch_size, base * num_units])
    with tf.variable_scope(scope):
        if reuse is True:
            tf.get_variable_scope().reuse_variables()
        gamma = tf.get_variable(
            'ln_gamma', [4 * num_units],
            initializer=tf.constant_initializer(gamma_start))
        if use_bias:
            beta = tf.get_variable(
                'ln_beta', [4 * num_units],
                initializer=tf.constant_initializer(0.0))
    if use_bias:
        return gamma * h + beta
    return gamma * h

def layer_norm(x, num_units, scope="layer_norm", reuse=False, gamma_start=1.0,
               epsilon=1e-3, use_bias=True):
    axes = [1]
    mean = tf.reduce_mean(x, axes, keep_dims=True)
    x_shifted = x - mean
    var = tf.reduce_mean(tf.square(x_shifted), axes, keep_dims=True)
    inv_std = tf.rsqrt(var + epsilon)
    with tf.variable_scope(scope):
        if reuse is True:
            tf.get_variable_scope().reuse_variables()
        gamma = tf.get_variable(
            'ln_gamma', [num_units],
            initializer=tf.constant_initializer(gamma_start))
        if use_bias:
            beta = tf.get_variable(
                'ln_beta', [num_units],
                initializer=tf.constant_initializer(0.0))
    output = gamma * (x_shifted) * inv_std
    if use_bias:
        output = output + beta
    return output

def super_linear(x, output_size, scope=None, reuse=False,
                 init_w="ortho", weight_start=0.0, use_bias=True,
                 bias_start=0.0, input_size=None):
    # support function doing linear operation. uses ortho initializer defined
    # earlier.
    shape = x.get_shape().as_list()
    with tf.variable_scope(scope or "linear"):
        if reuse is True:
            tf.get_variable_scope().reuse_variables()
        w_init = None  # uniform
        if input_size is None:
            x_size = shape[1]
        else:
            x_size = input_size
        if init_w == "zeros":
            w_init = tf.constant_initializer(0.0)
        elif init_w == "constant":
            w_init = tf.constant_initializer(weight_start)
        elif init_w == "gaussian":
            w_init = tf.random_normal_initializer(stddev=weight_start)
        elif init_w == "ortho":
            w_init = lstm_ortho_initializer(1.0)
        w = tf.get_variable("super_linear_w",
                            [x_size, output_size],
                            tf.float32, initializer=w_init)
        if use_bias:
            b = tf.get_variable(
                "super_linear_b", [output_size], tf.float32,
                initializer=tf.constant_initializer(bias_start))
            return tf.matmul(x, w) + b
        return tf.matmul(x, w)

def hyper_norm(layer, hyper_output, embedding_size, num_units,
               scope="hyper", use_bias=True):
    '''
    HyperNetwork norm operator
    provides context-dependent weights
    layer: layer to apply operation on
    hyper_output: output of the hypernetwork cell at time t
    embedding_size: embedding size of the output vector (see paper)
    num_units: number of hidden units in main rnn
    '''
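    # A context-dependent scaling vector is computed from the hypernetwork
    # output and applied elementwise to `layer`:
    #   z      = W_hz * hyper_output + b_hz   ("zw", embedding_size units)
    #   alpha  = W_z * z                      ("alpha", num_units units)
    #   result = alpha * layer
    # This is the weight-scaling mechanism described in the HyperLSTM paper.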
    # recurrent batch norm init trick (https://arxiv.org/abs/1603.09025).
    init_gamma = 0.10  # small initial scale, following Cooijmans et al.
    with tf.variable_scope(scope):
        zw = super_linear(hyper_output, embedding_size, init_w="constant",
                          weight_start=0.00, use_bias=True,
                          bias_start=1.0, scope="zw")
        alpha = super_linear(zw, num_units, init_w="constant",
                             weight_start=init_gamma / embedding_size,
                             use_bias=False, scope="alpha")
        result = tf.multiply(alpha, layer)  # tf.mul is tf.multiply in TF >= 1.0
    return result

def hyper_bias(layer, hyper_output, embedding_size, num_units,
               scope="hyper"):
    '''
    HyperNetwork bias operator
    provides context-dependent bias
    layer: layer to apply operation on
    hyper_output: output of the hypernetwork cell at time t
    embedding_size: embedding size of the output vector (see paper)
    num_units: number of hidden units in main rnn
    '''
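    # A context-dependent bias vector is computed from the hypernetwork
    # output and added to `layer`:
    #   z      = W_hz * hyper_output   ("zb", embedding_size units)
    #   beta   = W_z * z               ("beta", num_units units)
    #   result = layer + beta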
    with tf.variable_scope(scope):
        zb = super_linear(hyper_output, embedding_size, init_w="gaussian",
                          weight_start=0.01, use_bias=False,
                          bias_start=0.0, scope="zb")
        beta = super_linear(zb, num_units, init_w="constant",
                            weight_start=0.00, use_bias=False, scope="beta")
    return layer + beta

class LSTMCell(tf.contrib.rnn.RNNCell):
    """
    Layer-Norm, with Ortho Initialization and
    Recurrent Dropout without Memory Loss.
    https://arxiv.org/abs/1607.06450 - Layer Norm
    https://arxiv.org/abs/1603.05118 - Recurrent Dropout without Memory Loss
    derived from
    https://github.com/OlavHN/bnlstm
    https://github.com/LeavesBreathe/tensorflow_with_latest_papers
    """
    def __init__(self, num_units, forget_bias=1.0, use_layer_norm=False,
                 use_recurrent_dropout=False, dropout_keep_prob=0.90):
        """Initialize the Layer Norm LSTM cell.
        Args:
          num_units: int, The number of units in the LSTM cell.
          forget_bias: float, The bias added to forget gates (default 1.0).
          use_layer_norm: boolean, whether to apply layer normalization
            (default False).
          use_recurrent_dropout: boolean, whether to use Recurrent Dropout
            (default False).
          dropout_keep_prob: float, dropout keep probability (default 0.90).
        """
        self.num_units = num_units
        self.forget_bias = forget_bias
        self.use_layer_norm = use_layer_norm
        self.use_recurrent_dropout = use_recurrent_dropout
        self.dropout_keep_prob = dropout_keep_prob

    @property
    def output_size(self):
        return self.num_units

    @property
    def state_size(self):
        return tf.contrib.rnn.LSTMStateTuple(self.num_units, self.num_units)

    def __call__(self, x, state, scope=None):
        with tf.variable_scope(scope or type(self).__name__):
            c, h = state

            batch_size = x.get_shape().as_list()[0]
            x_size = x.get_shape().as_list()[1]

            w_init = None  # uniform
            h_init = lstm_ortho_initializer()

            W_xh = tf.get_variable(
                'W_xh', [x_size, 4 * self.num_units], initializer=w_init)
            W_hh = tf.get_variable(
                'W_hh_i', [self.num_units, 4 * self.num_units],
                initializer=h_init)
            W_full = tf.concat([W_xh, W_hh], 0)
            bias = tf.get_variable(
                'bias', [4 * self.num_units],
                initializer=tf.constant_initializer(0.0))

            concat = tf.concat([x, h], 1)  # concat for speed.
            concat = tf.matmul(concat, W_full) + bias

            # new way of doing layer norm (faster)
            if self.use_layer_norm:
                concat = layer_norm_all(
                    concat, batch_size, 4, self.num_units, 'ln')

            # i = input_gate, j = new_input, f = forget_gate, o = output_gate
            i, j, f, o = tf.split(concat, 4, 1)

            if self.use_recurrent_dropout:
                g = tf.nn.dropout(tf.tanh(j), self.dropout_keep_prob)
            else:
                g = tf.tanh(j)

            new_c = c * tf.sigmoid(f + self.forget_bias) + tf.sigmoid(i) * g
            if self.use_layer_norm:
                new_h = tf.tanh(layer_norm(
                    new_c, self.num_units, 'ln_c')) * tf.sigmoid(o)
            else:
                new_h = tf.tanh(new_c) * tf.sigmoid(o)

            return new_h, tf.contrib.rnn.LSTMStateTuple(new_c, new_h)

class HyperLSTMCell(tf.contrib.rnn.RNNCell):
    '''
    HyperLSTM, with Ortho Initialization,
    Layer Norm and Recurrent Dropout without Memory Loss.
    https://arxiv.org/abs/1609.09106
    '''

    def __init__(self, num_units, forget_bias=1.0,
                 use_recurrent_dropout=False, dropout_keep_prob=0.90,
                 use_layer_norm=True,
                 hyper_num_units=128, hyper_embedding_size=16,
                 hyper_use_recurrent_dropout=False):
        '''Initialize the Layer Norm HyperLSTM cell.
        Args:
          num_units: int, The number of units in the LSTM cell.
          forget_bias: float, The bias added to forget gates (default 1.0).
          use_recurrent_dropout: boolean, whether to use Recurrent Dropout
            (default False).
          dropout_keep_prob: float, dropout keep probability (default 0.90).
          use_layer_norm: boolean (default True). Controls whether LayerNorm
            layers are used in the main LSTM and the HyperLSTM cell.
          hyper_num_units: int, number of units in the HyperLSTM cell
            (default is 128; recommend experimenting with 256 for larger
            tasks).
          hyper_embedding_size: int, size of the signals emitted from the
            HyperLSTM cell (default is 16; recommend trying larger values,
            but larger is not always better).
          hyper_use_recurrent_dropout: boolean (default False). Controls
            whether the HyperLSTM cell also uses recurrent dropout (not in
            the paper). Recommend turning this on only if hyper_num_units
            becomes very large (>= 512).
        '''
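        # The cell state packs the main LSTM state and the hypernetwork LSTM
        # state along the feature axis: the first num_units columns of (c, h)
        # belong to the main cell, the remaining hyper_num_units columns to
        # the hyper cell (see state_size and __call__ below).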
        self.num_units = num_units
        self.forget_bias = forget_bias
        self.use_recurrent_dropout = use_recurrent_dropout
        self.dropout_keep_prob = dropout_keep_prob
        self.use_layer_norm = use_layer_norm
        self.hyper_num_units = hyper_num_units
        self.hyper_embedding_size = hyper_embedding_size
        self.hyper_use_recurrent_dropout = hyper_use_recurrent_dropout

        self.total_num_units = self.num_units + self.hyper_num_units

        self.hyper_cell = LSTMCell(
            hyper_num_units,
            use_recurrent_dropout=hyper_use_recurrent_dropout,
            use_layer_norm=use_layer_norm,
            dropout_keep_prob=dropout_keep_prob)

    @property
    def output_size(self):
        return self.num_units

    @property
    def state_size(self):
        return tf.contrib.rnn.LSTMStateTuple(
            self.num_units + self.hyper_num_units,
            self.num_units + self.hyper_num_units)

    def __call__(self, x, state, timestep=0, scope=None):
        with tf.variable_scope(scope or type(self).__name__):
            total_c, total_h = state
            c = total_c[:, 0:self.num_units]
            h = total_h[:, 0:self.num_units]
            hyper_state = tf.contrib.rnn.LSTMStateTuple(
                total_c[:, self.num_units:],
                total_h[:, self.num_units:])

            w_init = None  # uniform
            h_init = lstm_ortho_initializer(1.0)

            x_size = x.get_shape().as_list()[1]
            embedding_size = self.hyper_embedding_size
            num_units = self.num_units
            batch_size = x.get_shape().as_list()[0]

            W_xh = tf.get_variable('W_xh',
                                   [x_size, 4 * num_units],
                                   initializer=w_init)
            W_hh = tf.get_variable('W_hh',
                                   [num_units, 4 * num_units],
                                   initializer=h_init)
            bias = tf.get_variable('bias',
                                   [4 * num_units],
                                   initializer=tf.constant_initializer(0.0))

            # concatenate the input and hidden states for hyperlstm input
            hyper_input = tf.concat([x, h], 1)
            hyper_output, hyper_new_state = self.hyper_cell(
                hyper_input, hyper_state)

            xh = tf.matmul(x, W_xh)
            hh = tf.matmul(h, W_hh)

            # split Wxh contributions
            ix, jx, fx, ox = tf.split(xh, 4, 1)
            ix = hyper_norm(ix, hyper_output, embedding_size,
                            num_units, 'hyper_ix')
            jx = hyper_norm(jx, hyper_output, embedding_size,
                            num_units, 'hyper_jx')
            fx = hyper_norm(fx, hyper_output, embedding_size,
                            num_units, 'hyper_fx')
            ox = hyper_norm(ox, hyper_output, embedding_size,
                            num_units, 'hyper_ox')

            # split Whh contributions
            ih, jh, fh, oh = tf.split(hh, 4, 1)
            ih = hyper_norm(ih, hyper_output, embedding_size,
                            num_units, 'hyper_ih')
            jh = hyper_norm(jh, hyper_output, embedding_size,
                            num_units, 'hyper_jh')
            fh = hyper_norm(fh, hyper_output, embedding_size,
                            num_units, 'hyper_fh')
            oh = hyper_norm(oh, hyper_output, embedding_size,
                            num_units, 'hyper_oh')

            # split bias
            ib, jb, fb, ob = tf.split(bias, 4, 0)  # bias is to be broadcast.
            ib = hyper_bias(ib, hyper_output, embedding_size,
                            num_units, 'hyper_ib')
            jb = hyper_bias(jb, hyper_output, embedding_size,
                            num_units, 'hyper_jb')
            fb = hyper_bias(fb, hyper_output, embedding_size,
                            num_units, 'hyper_fb')
            ob = hyper_bias(ob, hyper_output, embedding_size,
                            num_units, 'hyper_ob')

            # i = input_gate, j = new_input, f = forget_gate, o = output_gate
            i = ix + ih + ib
            j = jx + jh + jb
            f = fx + fh + fb
            o = ox + oh + ob

            if self.use_layer_norm:
                concat = tf.concat([i, j, f, o], 1)
                concat = layer_norm_all(
                    concat, batch_size, 4, num_units, 'ln_all')
                i, j, f, o = tf.split(concat, 4, 1)

            if self.use_recurrent_dropout:
                g = tf.nn.dropout(tf.tanh(j), self.dropout_keep_prob)
            else:
                g = tf.tanh(j)

            new_c = c * tf.sigmoid(f + self.forget_bias) + tf.sigmoid(i) * g
            if self.use_layer_norm:
                new_h = tf.tanh(layer_norm(
                    new_c, num_units, 'ln_c')) * tf.sigmoid(o)
            else:
                new_h = tf.tanh(new_c) * tf.sigmoid(o)

            hyper_c, hyper_h = hyper_new_state
            new_total_c = tf.concat([new_c, hyper_c], 1)
            new_total_h = tf.concat([new_h, hyper_h], 1)

            return new_h, tf.contrib.rnn.LSTMStateTuple(new_total_c,
                                                        new_total_h)
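Below is a minimal usage sketch, not part of the original gist. It assumes a TensorFlow 1.x environment (matching the tf.contrib imports above); the batch size, sequence length, and unit counts are illustrative values only.

if __name__ == "__main__":
    # The cells above read a static batch size from the input shape, so the
    # placeholder must have a fixed first dimension.
    batch_size, seq_len, input_dim = 16, 20, 32
    inputs = tf.placeholder(tf.float32, [batch_size, seq_len, input_dim])
    cell = HyperLSTMCell(num_units=128, hyper_num_units=64,
                         hyper_embedding_size=16, use_layer_norm=True)
    initial_state = cell.zero_state(batch_size, tf.float32)
    outputs, final_state = tf.nn.dynamic_rnn(
        cell, inputs, initial_state=initial_state, dtype=tf.float32)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        data = np.random.randn(batch_size, seq_len, input_dim)
        data = data.astype(np.float32)
        out = sess.run(outputs, feed_dict={inputs: data})
        print(out.shape)  # expected: (16, 20, 128)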