Skip to content

Instantly share code, notes, and snippets.

@markloyman
Created January 23, 2018 20:50
Show Gist options
  • Save markloyman/a29f3cf929aa4ac110e68f9b357bff02 to your computer and use it in GitHub Desktop.
Save markloyman/a29f3cf929aa4ac110e68f9b357bff02 to your computer and use it in GitHub Desktop.
Triplet Loss
def triplet_loss(_, y_pred, triplet_margin=1):
    """Hinge triplet loss computed from precomputed anchor distances.

    Uses the Keras loss signature ``(y_true, y_pred)``. ``y_true`` is ignored:
    the training signal comes entirely from the distances the network outputs.

    Args:
        _: ``y_true`` placeholder required by the Keras loss API; unused.
        y_pred: Assumed shape ``(batch_size, 2)``. Column 0 is the
            anchor-positive distance and column 1 the anchor-negative
            distance, matching the concatenation order in the example below.
        triplet_margin: Margin added to the difference of squared distances.

    Returns:
        Tensor of shape ``(batch_size, 1)`` holding
        ``max(0, triplet_margin + d_pos**2 - d_neg**2)`` per sample.

    Example of a triplet network producing a compatible ``y_pred``::

        def euclidean_distance(vects):
            x, y = vects
            return K.sqrt(K.maximum(K.sum(K.square(x - y), axis=1, keepdims=True),
                                    K.epsilon()))

        def dist_output_shape(shapes):
            shape1, _ = shapes
            return (shape1[0], 1)

        def build_triplet(embed_network, input_shape):
            img_ref = Input(shape=input_shape)
            img_pos = Input(shape=input_shape)
            img_neg = Input(shape=input_shape)
            embed_ref = embed_network(img_ref)
            embed_pos = embed_network(img_pos)
            embed_neg = embed_network(img_neg)
            distance_pos = Lambda(euclidean_distance,
                                  output_shape=dist_output_shape)([embed_ref, embed_pos])
            distance_neg = Lambda(euclidean_distance,
                                  output_shape=dist_output_shape)([embed_ref, embed_neg])
            output_layer = Lambda(lambda vects: K.concatenate(vects, axis=1))(
                [distance_pos, distance_neg])
            model = Model(inputs=[img_ref, img_pos, img_neg],
                          outputs=output_layer,
                          name='triplet-network')
            return model
    """
    margin = K.constant(triplet_margin)
    # Column weights [+1, -1]: a single (batch, 2) x (2, 1) matmul yields
    # d_pos**2 - d_neg**2 for every row while keeping a (batch_size, 1) shape.
    subtraction = K.constant([1, -1], shape=(2, 1))
    diff = K.dot(K.square(y_pred), subtraction)
    # Hinge: the loss is zero once the squared negative distance exceeds the
    # squared positive distance by at least the margin.
    loss = K.maximum(K.constant(0), margin + diff)
    return loss
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment