Created
November 25, 2021 23:23
-
-
Save ShairozS/2e72e3552eaeec767997129261213127 to your computer and use it in GitHub Desktop.
Contrastive pair loss and triplet loss
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import numpy as np | |
import torch.nn.functional as F | |
def form_triplets(inA, inB):
    '''
    Form (anchor, positive, negative) triplets from two tensors of embeddings.

    It is assumed that the embeddings at corresponding batch positions are similar
    and all other batch positions are dissimilar,
    i.e. inA[i] ~ inB[i] and inA[i] !~ inB[j] for all i != j

    Args:
        inA (torch.Tensor): First batch of embeddings, shape (B, N)
        inB (torch.Tensor): Second batch of embeddings; row i is the positive for inA[i]

    Returns:
        tuple: (anchors, positives, negatives), each with B*(B-1) rows. For every
        ordered pair (i, j) with i != j there is one triplet (inA[j], inB[j], inB[i]).
    '''
    b, _ = inA.shape  # unpacking also enforces a 2-D input
    # Flattened (B*B,) mask that is True everywhere except the diagonal (i == j) slots.
    # Built directly on inA's device instead of via a Python list on the CPU.
    off_diagonal = ~torch.eye(b, dtype=torch.bool, device=inA.device).flatten()
    # For flat index k = i*b + j: tiling with repeat() yields row j,
    # repeat_interleave() yields row i.
    anchors = inA.repeat(b, 1)[off_diagonal]
    positives = inB.repeat(b, 1)[off_diagonal]
    negatives = inB.repeat_interleave(b, dim=0)[off_diagonal]
    return (anchors, positives, negatives)
def form_pairs(inA, inB):
    '''
    Form all B*B pairs from two tensors of embeddings, with similarity labels.

    It is assumed that the embeddings at corresponding batch positions are similar
    and all other batch positions are dissimilar,
    i.e. inA[i] ~ inB[i] and inA[i] !~ inB[j] for all i != j

    Args:
        inA (torch.Tensor): First batch of embeddings, shape (B, N)
        inB (torch.Tensor): Second batch of embeddings; row i is similar to inA[i]

    Returns:
        tuple: (first, second, labels), each with B*B entries. Flat row k = i*B + j
        pairs inA[j] with inB[i]; labels[k] is 1 when i == j and 0 otherwise
        (int64, on the same device as inA).
    '''
    b, _ = inA.shape  # unpacking also enforces a 2-D input
    # Flattened identity matrix: ones exactly at the i == j (similar) positions.
    # Built directly on inA's device instead of via a Python list on the CPU.
    labels = torch.eye(b, dtype=torch.long, device=inA.device).flatten()
    # repeat() tiles the whole batch; repeat_interleave() repeats each row b times in place
    return (inA.repeat(b, 1), inB.repeat_interleave(b, dim=0), labels)
class ContrastiveLoss(torch.nn.Module):
    """
    Contrastive loss function.
    Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
    args:
        distance (function): A function that returns the distance between two tensors - should be a valid metric over R; default= squared L2 distance summed over dim 1
        margin (scalar): The margin value between positive and negative class ; default=1.0
        mode (str): Which loss forward() computes - one of 'pairs', 'triplets' or 'ntxent'; default='pairs'
        batch_size (int, optional): Stored on the instance; not used by the pair or triplet losses
        temperature (scalar): Stored on the instance; not used by the pair or triplet losses; default=0.5
    """
    def __init__(self,
                 distance = lambda x,y: torch.pow(x-y, 2).sum(1),
                 margin=1.0,
                 mode='pairs',
                 batch_size=None,
                 temperature=0.5):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin
        self.distance = distance
        self.mode = mode
        # Kept so the constructor arguments are not silently discarded
        self.batch_size = batch_size
        self.temperature = temperature

    def forward(self, x, y):
        '''
        Dispatch to the loss selected by self.mode
        Args:
            x (torch.Tensor) : The first batch of embeddings (B, N)
            y (torch.Tensor) : The second batch of embeddings (B, N)
        Raises:
            ValueError: if self.mode is not 'pairs', 'triplets' or 'ntxent'
        '''
        if self.mode == 'pairs':
            return self.forward_pairs(x, y)
        if self.mode == 'triplets':
            return self.forward_triplets(x, y)
        if self.mode == 'ntxent':
            return self.forward_ntxent(x, y)
        # Fail loudly instead of silently returning None for an unknown mode
        raise ValueError("Unknown mode " + repr(self.mode) + "; expected 'pairs', 'triplets' or 'ntxent'")

    def forward_ntxent(self, x, y):
        '''
        Placeholder for the 'ntxent' mode, which this class does not implement.
        Subclasses may override this method to provide it.
        '''
        raise NotImplementedError("mode 'ntxent' is not implemented by ContrastiveLoss")

    def forward_triplets(self, x, y):
        '''
        Return the triplet margin loss over all triplets mined by form_triplets(x, y),
        using self.distance as the distance function and self.margin as the margin
        '''
        a, p, n = form_triplets(x, y)
        return torch.nn.functional.triplet_margin_with_distance_loss(
            a, p, n, distance_function=self.distance, margin=self.margin)

    def forward_pairs(self, x, y, label=None):
        '''
        Return the contrastive loss between two similar or dissimilar outputs
        Args:
            x (torch.Tensor) : The first input tensor (B, N)
            y (torch.Tensor) : The second input tensor (B,N)
            label (torch.Tensor, optional) : A tensor with elements either 0 or 1 indicating dissimilar or similar (B, 1).
                When omitted, all B*B pairs are mined with form_pairs, treating x[i] / y[i] as the only similar pairs.
        '''
        assert x.shape==y.shape, str(x.shape) + "does not match input 2: " + str(y.shape)
        labels_given = label is not None
        if not labels_given:
            x, y, label = form_pairs(x, y)
        distance = self.distance(x, y)
        if labels_given:
            # Align a (B, 1) label with the distance tensor so the products below stay
            # element-wise instead of broadcasting to (B, B); the cast lets bool labels work
            label = label.reshape(distance.shape).to(distance.dtype)
        # When the label is 1 (similar) - the loss is the distance between the embeddings
        # When the label is 0 (dissimilar) - the loss is the margin minus the distance, floored at 0
        loss_contrastive = torch.mean(label * distance +
                                      (1 - label) * torch.clamp(self.margin - distance, min=0.0))
        return loss_contrastive
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment