Last active
June 13, 2018 06:50
-
-
Save alesee/1656d61f8c1ce54300ab5bafd59f6e7b to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch.nn as nn | |
import torch.nn.functional as F | |
import torch | |
class SimpleRNN(nn.Module):
    """
    Neural network module with an embedding layer, an RNN module and an output layer.

    Arguments:
        input_size(int) -- number of entries in the embedding dictionary
        embz_size(int) -- the size of each embedding vector
        hidden_size(int) -- the number of features in the hidden state
        output_size(int) -- the number of output features to be predicted
        num_layers(int, optional) -- number of stacked recurrent layers. Default=1

    Inputs: input_sequence
        input of shape (seq_length, batch_size) -- tensor of integer indices into
        the embedding dictionary (sequence-first layout)

    Returns: output
        output of shape (batch_size, output_size) -- sigmoid of the output layer
        applied to h_t from the last RNN layer at the last time-step t.
    """

    def __init__(self, input_size, embz_size, hidden_size, output_size, num_layers=1):
        super().__init__()
        self.embz_size, self.hidden_size = embz_size, hidden_size
        self.output_size, self.num_layers = output_size, num_layers
        self.embedding_layer = nn.Embedding(input_size, embz_size)
        # num_layers must be forwarded: forward() builds an initial hidden state
        # with self.num_layers layers, and nn.RNN rejects a mismatched h_0.
        self.rnn = nn.RNN(embz_size, hidden_size, num_layers=num_layers)
        self.output_layer = nn.Linear(hidden_size, output_size)

    def forward(self, input_sequence):
        # Sequence-first layout: dim 0 is time, dim 1 is batch.
        batch_size = input_sequence.size(1)
        input_tensor = self.embedding_layer(input_sequence)
        # Zero initial hidden state on the same device/dtype as the embeddings,
        # so the module keeps working after .to(device) or dtype conversion.
        hidden = torch.zeros(
            self.num_layers,
            batch_size,
            self.hidden_size,
            dtype=input_tensor.dtype,
            device=input_tensor.device,
        )
        output, hidden = self.rnn(input_tensor, hidden)
        # Only the last time-step is returned, so project just that slice rather
        # than running the linear layer over every step and discarding the rest.
        return torch.sigmoid(self.output_layer(output[-1]))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment