Structured Self-Attention in PyTorch (Lin et al. 2017)
# Implementation of Structured Self-Attention mechanism
# from Lin et al. 2017 (https://arxiv.org/pdf/1703.03130.pdf)
# Anton Melnikov
from typing import Tuple

import torch
import torch.nn as nn


class StructuredAttention(nn.Module):
    def __init__(self, *, input_dim: int, hidden_dim: int, attention_hops: int):
        super().__init__()
        # W_s1: (hidden_dim, input_dim), W_s2: (attention_hops, hidden_dim)
        self.w1 = nn.Parameter(torch.randn(size=(hidden_dim, input_dim)))
        self.w2 = nn.Parameter(torch.randn(attention_hops, hidden_dim))

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # hidden_states: (batch, seq_len, input_dim)
        X = self.w1 @ hidden_states.transpose(2, 1)  # (batch, hidden_dim, seq_len)
        X = torch.tanh(X)
        # annotation matrix A: (batch, attention_hops, seq_len)
        a = torch.softmax(self.w2 @ X, dim=-1)
        # sentence embedding M: (batch, attention_hops, input_dim)
        m = a @ hidden_states
        return m, a


def get_attention_penalty(attention_matrix: torch.Tensor) -> torch.Tensor:
    # Frobenius norm of (A @ A^T - I), penalizing redundancy across attention hops
    identity = torch.eye(attention_matrix.shape[1])
    p = attention_matrix @ attention_matrix.transpose(2, 1) - identity
    p = torch.norm(p)
    return p
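
A minimal usage sketch (not part of the original gist; the batch size, sequence length, and layer dimensions below are illustrative):

# Usage sketch: illustrative shapes, assuming the module consumes per-token
# hidden states (e.g. BiLSTM outputs) of shape (batch, seq_len, input_dim).
batch_size, seq_len, input_dim = 4, 10, 600

attention = StructuredAttention(input_dim=input_dim, hidden_dim=350, attention_hops=30)
hidden_states = torch.randn(batch_size, seq_len, input_dim)

m, a = attention(hidden_states)
# m: (4, 30, 600) sentence embedding M; a: (4, 30, 10) annotation matrix A
penalty = get_attention_penalty(a)  # scalar; add (scaled) to the training loss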