Created
August 18, 2018 11:23
-
-
Save peter0749/622655e2b284fe268b832cb42cced423 to your computer and use it in GitHub Desktop.
Keras layer and loss function implementing "Learning a Deep Listwise Context Model for Ranking Refinement" (Ai et al., SIGIR 2018)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def att_loss(y_true, y_pred):
    """Attention-based listwise loss.

    Both the targets and the predictions are first turned into attention
    distributions (a softmax restricted to strictly-positive scores), then
    compared with element-wise binary crossentropy averaged over the last axis.
    """
    def to_attention(scores):
        # exp() only where the score is > 0; non-positive entries contribute 0,
        # so padded / irrelevant items get zero attention mass.
        exp_pos = tf.where(scores > 0, K.exp(scores), K.zeros_like(scores))
        # Normalize to a distribution; epsilon guards the all-zero row.
        return exp_pos / (K.sum(exp_pos, axis=-1, keepdims=True) + K.epsilon())

    true_att = to_attention(y_true)
    pred_att = to_attention(y_pred)
    return K.mean(K.binary_crossentropy(true_att, pred_att), axis=-1)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class feature_abstraction_module(object):
    """Sequence-to-vector feature extractor.

    Wraps a time-distributed GRU that collapses each inner sequence into a
    single feature vector (``return_sequences=False``).
    """

    def __init__(self, units=None):
        """Build the inner GRU.

        Args:
            units: Output width of the GRU. Defaults to the module-level
                global ``beta`` for backward compatibility with the original
                hard-coded constant.
        """
        # Fall back to the legacy global `beta` when no width is given.
        gru_units = beta if units is None else units
        self.inner_gru = TimeDistributed(
            CuDNNGRU(gru_units,
                     return_sequences=False,
                     recurrent_regularizer=l2(0.01)),
            name='feature_extractor')

    def transform(self, x):
        """Apply the time-distributed GRU to `x` and return its features."""
        return self.inner_gru(x)
class ReRankLayer(Layer):
    """Keras layer producing a scalar re-ranking score from two feature vectors.

    Given a state vector ``sn`` and an item vector ``ot`` (both of width
    ``alpha``), computes ``V_phi . (ot . tanh(W_phi . sn + bias_phi))`` —
    a bilinear attention-style score with ``k`` hidden components — and
    returns one score per batch element.
    """

    def __init__(self, k, **kwargs):
        # k: number of hidden components in the scoring tensor product.
        self.k = k
        super(ReRankLayer, self).__init__(**kwargs)

    def build(self, input_shape):  # [sn, ot]
        # Expects exactly two inputs whose last dimensions match.
        assert type(input_shape) is list
        assert len(input_shape) == 2
        assert input_shape[0][-1] == input_shape[1][-1]
        # Create a trainable weight variable for this layer.
        # alpha is the shared feature width of both inputs.
        self.alpha = input_shape[0][-1]
        self.W_phi = self.add_weight(name='W_phi',
                                     shape=(self.alpha, self.k, self.alpha),
                                     initializer='glorot_normal',
                                     trainable=True)
        self.bias_phi = self.add_weight(name='bias_phi',
                                        shape=(self.alpha, self.k),
                                        initializer='zeros',
                                        trainable=True)
        self.V_phi = self.add_weight(name='V_phi',
                                     shape=(self.k,),
                                     initializer='glorot_uniform',
                                     trainable=True)
        super(ReRankLayer, self).build(input_shape)  # Be sure to call this at the end

    def call(self, input_layers):
        sn, ot = input_layers
        # shape of sn: (batch_size, alpha)
        # shape of ot: (batch_size, alpha)
        # shape of W_phi: (alpha, k, alpha)
        # shape of bias_phi: (alpha, k)
        # shape of V_phi: (k,)
        batch_size = K.shape(sn)[0]
        # Tile the shared weights across the batch so batch_dot can contract
        # them against per-example vectors.
        batch_weights = K.tile(K.expand_dims(self.W_phi, axis=0), [batch_size, 1, 1, 1])  # shape: (batch_size, alpha, k, alpha)
        batch_bias = K.tile(K.expand_dims(self.bias_phi, axis=0), [batch_size, 1, 1])  # shape: (batch_size, alpha, k)
        batch_V = K.tile(K.expand_dims(self.V_phi, axis=0), [batch_size, 1])  # shape: (batch_size , k)
        # Flatten (alpha, k) so the contraction with sn over the trailing
        # alpha axis is a single batched matrix-vector product, then restore
        # the (alpha, k) layout and add the bias.
        linear = K.reshape(K.batch_dot(K.reshape(batch_weights, (batch_size, self.alpha*self.k, self.alpha)), sn, axes=[2, 1]), (batch_size, self.alpha, self.k)) + batch_bias  # (batch_size, alpha, k)
        nonlinear = K.tanh(linear)  # (batch_size, alpha, k)
        # Contract ot against the alpha axis of the tanh activations.
        pre_att = K.batch_dot(ot, nonlinear, axes=[1, 1])  # (batch_size, k)
        # Final projection onto V_phi yields one scalar score per example.
        output = K.expand_dims(K.batch_dot(batch_V, pre_att, axes=[1, 1]), axis=-1)  # (batch_size, 1)
        return output

    def compute_output_shape(self, input_shape):
        # One scalar per batch element.
        assert type(input_shape) is list
        assert len(input_shape) == 2
        return (input_shape[0][0], 1)
Author
peter0749
commented
Aug 21, 2018
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment