Created September 1, 2021 01:57
Implementation of the Brabandere loss term, proposed in https://arxiv.org/abs/1708.02551
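For reference, the three terms as defined in the paper (restated here from the arXiv version, so worth double-checking against the paper itself; δ_v and δ_d are the attraction and repulsion margins, μ_c the mean embedding of cluster c with N_c elements):

```latex
\begin{align*}
L_{var}  &= \frac{1}{C}\sum_{c=1}^{C}\frac{1}{N_c}\sum_{i=1}^{N_c}
            \left[\lVert \mu_c - x_i \rVert - \delta_v\right]_+^2 \\
L_{dist} &= \frac{1}{C(C-1)}\sum_{c_A=1}^{C}\sum_{\substack{c_B=1 \\ c_B \neq c_A}}^{C}
            \left[2\delta_d - \lVert \mu_{c_A} - \mu_{c_B} \rVert\right]_+^2 \\
L_{reg}  &= \frac{1}{C}\sum_{c=1}^{C}\lVert \mu_c \rVert \\
L        &= \alpha \, L_{var} + \beta \, L_{dist} + \gamma \, L_{reg}
\end{align*}
```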
```python
import torch


def brabandere_loss(pred, alpha=1, beta=1, gamma=1e-3, dv=.05, dd=2):
    """
    Implementation of the loss proposed in:
    Semantic Instance Segmentation with a Discriminative Loss Function (https://arxiv.org/abs/1708.02551)
    in which the loss is a sum of three terms:
    1) An intra-cluster attraction force (L_var).
    2) An inter-cluster repelling force (L_dist).
    3) A regularization term that pulls the cluster means towards the origin (L_reg).
    This loss is unsupervised, so only batches of embeddings for each cluster are needed.
    :param pred: List of batches of embeddings (one tensor per cluster) from which the loss will be calculated.
    :param alpha: L_var (intra-cluster) loss weight.
    :param beta: L_dist (inter-cluster) loss weight.
    :param gamma: Regularization loss weight.
    :param dv: Margin of the intra-cluster term: embeddings closer than dv to their mean are not penalized.
    :param dd: Half-margin of the inter-cluster term: means further apart than 2 * dd are not penalized.
    :return: Scalar tensor with the weighted sum of the three terms.
    """
    C = len(pred)  # Number of clusters
    device = pred[0].device
    zero = torch.tensor(0., device=device)
    o_loss = torch.zeros(3, dtype=torch.float, device=device)  # Lvar, Ldist, Lreg
    coeffs = torch.tensor([alpha, beta, gamma], dtype=torch.float, device=device)
    for c, cluster in enumerate(pred):  # Each cluster
        Nc = len(cluster)
        mean_cluster_a = torch.mean(cluster, dim=0)
        # Lvar: sum only if the embedding is far (more than dv) from the cluster mean
        for emb in cluster:  # Each embedding
            o_loss[0] += 1 / (C * Nc) * torch.pow(
                torch.max(zero, torch.norm(mean_cluster_a - emb, 2) - dv), 2)
        # Ldist: push the means of different clusters at least 2 * dd apart.
        # Each unordered pair is visited once, so the term is doubled to match
        # the paper's sum over ordered pairs.
        for j in range(c + 1, C):  # The other clusters
            mean_cluster_b = torch.mean(pred[j], dim=0)
            o_loss[1] += 2 / (C * C - C) * torch.pow(
                torch.max(zero, 2 * dd - torch.norm(mean_cluster_a - mean_cluster_b, 2)), 2)
        o_loss[2] += 1 / C * torch.norm(mean_cluster_a, 2)  # Lreg
    return torch.dot(coeffs, o_loss)  # weighted loss
```
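To make the hinges and the per-term normalization concrete, here is a dependency-free toy sketch of the same three terms on 1-D embeddings (plain Python lists instead of tensors; `toy_brabandere` and `hinge` are illustrative names, not part of the gist):

```python
def hinge(x):
    """[x]_+ = max(x, 0)."""
    return max(x, 0.0)

def toy_brabandere(clusters, dv=0.05, dd=2.0):
    """Return (L_var, L_dist, L_reg) for a list of 1-D toy clusters."""
    C = len(clusters)
    means = [sum(c) / len(c) for c in clusters]
    # L_var: attract embeddings to their cluster mean, hinged at dv
    l_var = sum(sum(hinge(abs(m - e) - dv) ** 2 for e in c) / len(c)
                for c, m in zip(clusters, means)) / C
    # L_dist: repel pairs of means closer than 2 * dd (sum over ordered pairs)
    l_dist = (sum(hinge(2 * dd - abs(means[i] - means[j])) ** 2
                  for i in range(C) for j in range(C) if i != j) / (C * (C - 1))
              if C > 1 else 0.0)
    # L_reg: keep the means close to the origin
    l_reg = sum(abs(m) for m in means) / C
    return l_var, l_dist, l_reg

# Two tight, well-separated clusters: only the regularizer is non-zero.
print(toy_brabandere([[-5.0, -5.0], [5.0, 5.0]]))  # (0.0, 0.0, 5.0)
# Two clusters closer than 2 * dd: the repelling term activates.
print(toy_brabandere([[0.0], [1.0]]))  # (0.0, 9.0, 0.5)
```

With the clusters at -5 and 5 the means are 10 apart (beyond the 2·dd = 4 margin) and every embedding sits exactly on its mean, so only L_reg survives; moving the means to 0 and 1 puts them inside the margin and L_dist becomes (2·2 − 1)² = 9.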