PyTorch instance-wise weighted cross-entropy loss
import torch
import torch.nn as nn


def log_sum_exp(x):
    # Numerically stable log-sum-exp over the class dimension.
    # See implementation detail in
    # http://timvieira.github.io/blog/post/2014/02/11/exp-normalize-trick/
    # b is a shift factor; see the link.
    # x.size() = [N, C]
    b, _ = torch.max(x, dim=1, keepdim=True)
    y = b + torch.log(torch.exp(x - b).sum(dim=1, keepdim=True))
    # y.size() = [N, 1]. Squeeze to [N] and return.
    return y.squeeze(1)
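
# Newer PyTorch releases also provide torch.logsumexp; log_sum_exp(x) should
# match torch.logsumexp(x, dim=1) up to floating-point rounding.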

def class_select(logits, target):
    # In numpy, this would be logits[np.arange(batch_size), target].
    batch_size, num_classes = logits.size()
    # Build a boolean mask that is True at each row's target class, then
    # select the corresponding logits. masked_select returns the values in
    # row-major order, so the result has shape [N].
    one_hot_mask = (torch.arange(num_classes, device=logits.device)
                    .repeat(batch_size, 1)
                    .eq(target.unsqueeze(1)))
    return logits.masked_select(one_hot_mask)
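
# An equivalent one-liner using gather, avoiding the explicit mask:
#   logits.gather(1, target.unsqueeze(1)).squeeze(1)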

def cross_entropy_with_weights(logits, target, weights=None):
    assert logits.dim() == 2
    assert not target.requires_grad
    target = target.squeeze(1) if target.dim() == 2 else target
    assert target.dim() == 1
    # Per-instance cross-entropy: logsumexp(logits) - logits[target].
    loss = log_sum_exp(logits) - class_select(logits, target)
    if weights is not None:
        # loss.size() = [N]. Assert weights has the same shape.
        assert list(loss.size()) == list(weights.size())
        # Weight the loss per instance.
        loss = loss * weights
    return loss

class CrossEntropyLoss(nn.Module):
    """
    Cross-entropy with instance-wise weights. Set `aggregate` to None to
    obtain a loss vector of shape (batch_size,).
    """
    def __init__(self, aggregate='mean'):
        super(CrossEntropyLoss, self).__init__()
        assert aggregate in ['sum', 'mean', None]
        self.aggregate = aggregate

    def forward(self, input, target, weights=None):
        if self.aggregate == 'sum':
            return cross_entropy_with_weights(input, target, weights).sum()
        elif self.aggregate == 'mean':
            return cross_entropy_with_weights(input, target, weights).mean()
        elif self.aggregate is None:
            return cross_entropy_with_weights(input, target, weights)
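
A minimal usage sketch (the shapes and values below are illustrative, not part of the gist): with `aggregate=None` the module returns one loss value per instance, so per-example weights can up- or down-weight individual samples. The result should agree with the built-in `torch.nn.functional.cross_entropy(..., reduction='none')` reweighted the same way.

import torch
import torch.nn.functional as F

logits = torch.randn(4, 3)                    # batch of 4, 3 classes
target = torch.tensor([0, 2, 1, 2])           # ground-truth class indices
weights = torch.tensor([1.0, 0.5, 2.0, 1.0])  # per-instance weights

criterion = CrossEntropyLoss(aggregate=None)
per_instance = criterion(logits, target, weights)  # shape [4]

# Sanity check against the built-in loss with the same per-instance weights.
reference = F.cross_entropy(logits, target, reduction='none') * weights
assert torch.allclose(per_instance, reference)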