Created January 23, 2018 20:50
-
-
Save markloyman/a29f3cf929aa4ac110e68f9b357bff02 to your computer and use it in GitHub Desktop.
Triplet Loss
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def triplet_loss(_, y_pred, triplet_margin=1):
    """Hinge-style triplet loss computed from pre-computed pair distances.

    Written with the Keras loss signature ``(y_true, y_pred)``. The first
    argument is ignored: which input is the positive and which is the
    negative is already encoded in the column order of ``y_pred``.

    Args:
        _: Ignored ``y_true`` placeholder required by the Keras loss
            signature.
        y_pred: Tensor of shape ``(batch_size, 2)``. Column 0 is the
            distance between the reference and the positive embedding and
            column 1 is the distance between the reference and the negative
            embedding, as produced by the example network below.
        triplet_margin: Margin added to ``d_pos**2 - d_neg**2`` before the
            hinge. Defaults to 1.

    Returns:
        Tensor of shape ``(batch_size, 1)`` holding, per sample,
        ``max(0, triplet_margin + d_pos**2 - d_neg**2)``.

    Example of a triplet network whose output matches ``y_pred`` (assumes
    ``Input``, ``Lambda`` and ``Model`` from Keras and ``K`` as the Keras
    backend are imported)::

        def euclidean_distance(vects):
            x, y = vects
            return K.sqrt(K.maximum(
                K.sum(K.square(x - y), axis=1, keepdims=True), K.epsilon()))

        def dist_output_shape(shapes):
            shape1, _ = shapes
            return (shape1[0], 1)

        def build_triplet(embed_network, input_shape):
            img_ref = Input(shape=input_shape)
            img_pos = Input(shape=input_shape)
            img_neg = Input(shape=input_shape)
            embed_ref = embed_network(img_ref)
            embed_pos = embed_network(img_pos)
            embed_neg = embed_network(img_neg)
            distance_pos = Lambda(euclidean_distance,
                                  output_shape=dist_output_shape)(
                [embed_ref, embed_pos])
            distance_neg = Lambda(euclidean_distance,
                                  output_shape=dist_output_shape)(
                [embed_ref, embed_neg])
            output_layer = Lambda(
                lambda vects: K.concatenate(vects, axis=1))(
                [distance_pos, distance_neg])
            model = Model(inputs=[img_ref, img_pos, img_neg],
                          outputs=output_layer,
                          name='triplet-network')
            return model
    """
    margin = K.constant(triplet_margin)
    # Slicing with 0:1 / 1:2 (not 0 / 1) keeps each column 2-D, so the
    # result has shape (batch_size, 1). The previous K.dot against a
    # [1, -1] column vector produced the same shape.
    pos_sq = K.square(y_pred[:, 0:1])
    neg_sq = K.square(y_pred[:, 1:2])
    # Hinge: a triplet contributes zero loss once d_neg**2 exceeds
    # d_pos**2 by at least the margin.
    return K.maximum(K.constant(0), margin + pos_sq - neg_sq)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.