Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save silent-vim/4b47604c537b7bd779a7543c28117016 to your computer and use it in GitHub Desktop.
Save silent-vim/4b47604c537b7bd779a7543c28117016 to your computer and use it in GitHub Desktop.
A PyTorch attention layer for the torchMoji model
class Attention(Module):
    """
    Computes a weighted average of channels across timesteps (1 parameter pr. channel).

    A single learned vector scores every timestep. The scores are softmax-normalised
    over the valid (non-padded) timesteps of each sequence and used to pool the inputs
    into one fixed-size vector per sequence.
    """

    def __init__(self, attention_size, return_attention=False):
        """ Initialize the attention layer

        # Arguments:
            attention_size: Size of the attention vector; must equal the size of the
                last dimension of the `inputs` passed to `forward`.
            return_attention: If true, output will include the weight for each input token
                used for the prediction
        """
        super(Attention, self).__init__()
        self.return_attention = return_attention
        self.attention_size = attention_size
        self.attention_vector = Parameter(torch.FloatTensor(attention_size))
        # torch.FloatTensor(n) allocates *uninitialised* memory (possibly NaN/inf),
        # so the vector must be filled explicitly. std=0.05 follows torchMoji.
        self.attention_vector.data.normal_(std=0.05)

    def __repr__(self):
        s = '{name}({attention_size}, return attention={return_attention})'
        return s.format(name=self.__class__.__name__,
                        attention_size=self.attention_size,
                        return_attention=self.return_attention)

    def forward(self, inputs, input_lengths):
        """ Forward pass.

        # Arguments:
            inputs (torch.Tensor): Padded input sequences of shape
                (batch, max_len, attention_size).
            input_lengths (torch.LongTensor): Number of valid (non-padded) timesteps
                in each sequence, shape (batch,). Each length should be >= 1; a
                length of 0 yields NaN attention weights for that sequence.

        # Return:
            Tuple (representations, attentions): representations has shape
            (batch, attention_size); attentions has shape (batch, max_len) if
            self.return_attention, else it is None.
        """
        logits = inputs.matmul(self.attention_vector)  # (batch, max_len)

        # Mask out padded timesteps. The index range is built on the same device as
        # the logits (not on CUDA whenever CUDA merely exists), so the layer also
        # works on CPU on a machine that has a GPU.
        # See e.g. https://discuss.pytorch.org/t/self-attention-on-words-and-masking/5671/5
        max_len = logits.size(1)
        idxes = torch.arange(max_len, device=logits.device).unsqueeze(0)
        lengths = input_lengths.to(logits.device).unsqueeze(1)
        mask = idxes < lengths  # (batch, max_len); True on valid timesteps

        # Softmax over the valid timesteps only. Filling padded positions with -inf
        # first makes softmax subtract a per-sequence max of the *valid* logits. One
        # sequence with large scores (or large values in padding) can then no longer
        # underflow every other sequence's weights to 0 and produce NaN via 0/0.
        attentions = torch.softmax(logits.masked_fill(~mask, float('-inf')), dim=1)

        # Weighted sum over time yields one fixed-size vector per sequence.
        representations = (inputs * attentions.unsqueeze(-1)).sum(dim=1)
        return (representations, attentions if self.return_attention else None)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment