Custom loss function with negative sampling
import torch
from torch import nn


class NegativeSamplingLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input_vectors, output_vectors, noise_vectors):
        batch_size, embed_size = input_vectors.shape

        # Input vectors should be a batch of column vectors
        input_vectors = input_vectors.view(batch_size, embed_size, 1)

        # Output vectors should be a batch of row vectors
        output_vectors = output_vectors.view(batch_size, 1, embed_size)

        # bmm = batch matrix multiplication
        # log-sigmoid loss for the correct (target) context words
        out_loss = torch.bmm(output_vectors, input_vectors).sigmoid().log()
        out_loss = out_loss.squeeze()

        # log-sigmoid loss for the negated noise (negative-sample) vectors
        noise_loss = torch.bmm(noise_vectors.neg(), input_vectors).sigmoid().log()
        noise_loss = noise_loss.squeeze().sum(1)  # sum the losses over the noise samples

        # negate the summed correct and noisy log-sigmoid losses
        # and return the average loss over the batch
        return -(out_loss + noise_loss).mean()
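For context, here is a minimal usage sketch. The tensor shapes (a batch of 4 word pairs, embedding size 10, 5 noise samples per example) and the random embeddings are illustrative assumptions, not part of the original gist; in practice the three inputs would come from the input and output embedding layers of a skip-gram model.

# Minimal usage sketch with made-up shapes: batch of 4, embedding size 10,
# 5 negative (noise) samples per example.
import torch

loss_fn = NegativeSamplingLoss()

batch_size, embed_size, n_noise = 4, 10, 5
input_vectors = torch.randn(batch_size, embed_size, requires_grad=True)   # center-word embeddings
output_vectors = torch.randn(batch_size, embed_size, requires_grad=True)  # true context-word embeddings
noise_vectors = torch.randn(batch_size, n_noise, embed_size)              # sampled negative-word embeddings

loss = loss_fn(input_vectors, output_vectors, noise_vectors)
print(loss.item())  # a single scalar: the mean loss over the batch
loss.backward()     # gradients flow back to the standalone embedding tensors above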