
@silent-vim
Created September 23, 2017 23:13
PyTorch attentional LSTM cell
import torch
import torch.nn as nn
import torch.nn.functional as F


class attentionalLSTMCell(nn.Module):
    def __init__(self, input_size, hidden_size, num_variants):
        super(attentionalLSTMCell, self).__init__()
        self.hidden_size = hidden_size
        self.input_size = input_size
        self.num_variants = num_variants
        # Each linear layer produces the 4 gate pre-activations for every variant at once.
        self.ih = nn.Linear(self.input_size, 4 * self.hidden_size * self.num_variants)
        self.hh = nn.Linear(self.hidden_size, 4 * self.hidden_size * self.num_variants)
        # Attention weights over the variants, computed from the cell state.
        self.hhh = nn.Linear(self.hidden_size, self.num_variants)

    def forward(self, input, hidden):
        hx, cx = hidden
        gates = self.ih(input) + self.hh(hx)              # (batch, 4 * hidden * variants)
        gates_weights = F.softmax(self.hhh(cx), dim=-1)   # (batch, variants)
        gates = gates.view(-1, 4 * self.hidden_size, self.num_variants)
        # Attention-weighted combination of the per-variant gates.
        gates = torch.bmm(gates, gates_weights.unsqueeze(2)).squeeze(2)
        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
        ingate = torch.sigmoid(ingate)
        forgetgate = torch.sigmoid(forgetgate)
        cellgate = torch.tanh(cellgate)
        outgate = torch.sigmoid(outgate)
        # Standard LSTM cell update using the attention-mixed gates.
        cy = (forgetgate * cx) + (ingate * cellgate)
        hy = outgate * torch.tanh(cy)
        return hy, cy
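
A minimal usage sketch (not part of the gist): the sizes below are hypothetical, and the cell is stepped manually over a sequence the same way a plain nn.LSTMCell would be.

# Hypothetical sizes for illustration only.
batch_size, input_size, hidden_size, num_variants = 8, 32, 64, 4

cell = attentionalLSTMCell(input_size, hidden_size, num_variants)

x = torch.randn(10, batch_size, input_size)   # (seq_len, batch, input_size)
hx = torch.zeros(batch_size, hidden_size)
cx = torch.zeros(batch_size, hidden_size)

outputs = []
for t in range(x.size(0)):
    hx, cx = cell(x[t], (hx, cx))
    outputs.append(hx)
outputs = torch.stack(outputs)                # (seq_len, batch, hidden_size)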