log_softmax and logsumexp with mask
import torch


def masked_log_softmax(input, mask, dim=1):
    """log_softmax along `dim`, treating entries where `mask` is 0 as padding.

    `mask` is a float tensor with the same shape as `input`; valid entries are 1.
    """
    masked_input = input * mask.float()
    # Subtract the per-slice max before exponentiating for numerical stability.
    max_input = torch.max(masked_input, dim=dim, keepdim=True)[0]
    exps = torch.exp(masked_input - max_input)
    masked_exps = exps * mask.float()
    masked_sums = masked_exps.sum(dim, keepdim=True)
    # A fully masked slice sums to 0; bump it to 1 to avoid dividing by zero.
    zeros = (masked_sums == 0)
    masked_sums += zeros.float()
    masked_exps += 1e-6  # avoid log(0); masked positions get a large negative value, not -inf.
    return torch.log(masked_exps / masked_sums)
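
A quick sanity check (my own example, not part of the original gist): with an all-ones mask the result should agree with torch.log_softmax up to the 1e-6 epsilon added above.

import torch

# Hypothetical batch: two rows of three logits; the last entry of row 1 is padding.
x = torch.tensor([[1.0, 2.0, 3.0], [0.5, -1.0, 0.0]])
mask = torch.tensor([[1.0, 1.0, 1.0], [1.0, 1.0, 0.0]])

out = masked_log_softmax(x, mask, dim=1)

# No padding at all: should match torch.log_softmax up to the epsilon.
full = masked_log_softmax(x, torch.ones_like(x), dim=1)
print(torch.allclose(full, torch.log_softmax(x, dim=1), atol=1e-5))  # True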
def masked_log_sum_exp(input, keepdim=False, mask=None):
    """Numerically stable logsumexp over the last dim of `input`.

    Reference: https://github.com/pytorch/pytorch/issues/2591

    Args:
        input: A tensor of any shape.
        keepdim: A boolean; whether to keep the reduced dimension.
        mask: An optional float tensor with the same shape as `input`.
            Valid entries are 1, padded entries are 0.

    Returns:
        The equivalent of log(sum(exp(input), dim=-1, keepdim=keepdim)),
        restricted to the unmasked entries.
    """
    if mask is not None:
        # Invert: from here on, 1 marks the padded entries.
        mask = 1. - mask
        # Push padded entries far down so they cannot win the max.
        max_offset = -1e7 * mask
    else:
        max_offset = 0.
    s, _ = torch.max(input + max_offset, dim=-1, keepdim=True)
    input_offset = input - s
    if mask is not None:
        # exp(-inf) == 0, so padded entries drop out of the sum.
        input_offset.masked_fill_(mask > 1e-6, -float('inf'))
    output = s + input_offset.exp().sum(dim=-1, keepdim=True).log()
    if not keepdim:
        output = output.squeeze(-1)
    return output
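
A small check (again my own example, with made-up tensors): per row, the masked result should match torch.logsumexp over only the valid entries.

import torch

x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 1.0, 1.0]])  # row 0: last entry is padding

out = masked_log_sum_exp(x, mask=mask)
print(torch.allclose(out[0], torch.logsumexp(x[0, :2], dim=0)))  # True
print(torch.allclose(out[1], torch.logsumexp(x[1], dim=0)))      # True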