Skip to content

Instantly share code, notes, and snippets.

@ashunigion
Created June 11, 2019 03:48
Show Gist options
  • Save ashunigion/17d11acb4555761f3b40271194dd67e9 to your computer and use it in GitHub Desktop.
Save ashunigion/17d11acb4555761f3b40271194dd67e9 to your computer and use it in GitHub Desktop.
custom loss function with Negative Sampling
class NegativeSamplingLoss(nn.Module):
def __init__(self):
super().__init__()
def forward(self, input_vectors, output_vectors, noise_vectors):
batch_size, embed_size = input_vectors.shape
# Input vectors should be a batch of column vectors
input_vectors = input_vectors.view(batch_size, embed_size, 1)
# Output vectors should be a batch of row vectors
output_vectors = output_vectors.view(batch_size, 1, embed_size)
# bmm = batch matrix multiplication
# correct log-sigmoid loss
out_loss = torch.bmm(output_vectors, input_vectors).sigmoid().log()
out_loss = out_loss.squeeze()
# incorrect log-sigmoid loss
noise_loss = torch.bmm(noise_vectors.neg(), input_vectors).sigmoid().log()
noise_loss = noise_loss.squeeze().sum(1) # sum the losses over the sample of noise vectors
# negate and sum correct and noisy log-sigmoid losses
# return average batch loss
return -(out_loss + noise_loss).mean()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment