import torch
import torch.nn as nn
from torch.nn import functional as F
class TextCNN(nn.Module):
    def __init__(self, batch_size, output_size, in_channels, out_channels, kernel_heights,
                 stride, padding, keep_probab, vocab_size, embedding_dim, weights):
        """
        Arguments
        ---------
        batch_size : Size of each batch, the same as the batch_size of the data returned by the TorchText BucketIterator
        output_size : Number of labels
        in_channels : Number of input channels. Here it is 1, as the input data has dimension = (batch_size, num_seq, embedding_dim)
        out_channels : Number of output channels produced by each convolution
        kernel_heights : A list of 3 different kernel heights. Convolution is performed 3 times and the results from each kernel height are concatenated.
        stride : Number of tokens the sliding convolution window moves at each step
        padding : Zero-padding added to the input before convolution
        keep_probab : Dropout probability passed to nn.Dropout, i.e. the probability of zeroing an activation
        vocab_size : Size of the vocabulary containing unique words
        embedding_dim : Embedding dimension of the GloVe word embeddings
        weights : Pre-trained GloVe word embeddings used to initialize the word-embedding look-up table
        """
        super(TextCNN, self).__init__()

        self.batch_size = batch_size
        self.output_size = output_size
        self.out_channels = out_channels
        self.kernel_heights = kernel_heights
        self.stride = stride
        self.padding = padding
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim

        # Frozen embedding look-up table initialized from the pre-trained GloVe weights
        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.word_embeddings.weight = nn.Parameter(weights, requires_grad=False)

        # Define convolutions, one per kernel height; each kernel spans the full embedding dimension
        self.conv1 = nn.Conv2d(in_channels, out_channels, (kernel_heights[0], embedding_dim), stride, padding)
        self.conv2 = nn.Conv2d(in_channels, out_channels, (kernel_heights[1], embedding_dim), stride, padding)
        self.conv3 = nn.Conv2d(in_channels, out_channels, (kernel_heights[2], embedding_dim), stride, padding)

        self.dropout = nn.Dropout(keep_probab)
        self.label = nn.Linear(len(kernel_heights) * out_channels, output_size)
    def conv_block(self, input, conv_layer):
        """
        Parameters
        ----------
        input : Batch of embedded tokens of shape (batch_size, 1, num_seq, embedding_dim)
        conv_layer : Convolution layer to apply

        Returns
        -------
        Max-pooled ReLU activations of the convolution: the maximum activation of each
        output channel over the sequence, with shape (batch_size, out_channels).
        """
        conv_out = conv_layer(input)                                         # (batch_size, out_channels, dim, 1)
        activation = F.relu(conv_out.squeeze(3))                             # (batch_size, out_channels, dim)
        max_out = F.max_pool1d(activation, activation.size()[2]).squeeze(2)  # (batch_size, out_channels)
        return max_out
    def forward(self, input_text, batch_size=None):
        """
        Define how the model runs from input to output.

        Parameters
        ----------
        input_text : input sentences of shape (batch_size, num_seq)
        batch_size : default = None. Used only for prediction on a single sentence after training (batch_size = 1)

        Returns
        -------
        Output of the linear layer containing the logits for each class.
        logits.size() = (batch_size, output_size)
        """
        data_in = self.word_embeddings(input_text)  # (batch_size, num_seq, embedding_dim)
        data_in = data_in.unsqueeze(1)              # (batch_size, 1, num_seq, embedding_dim)

        max_out1 = self.conv_block(data_in, self.conv1)
        max_out2 = self.conv_block(data_in, self.conv2)
        max_out3 = self.conv_block(data_in, self.conv3)

        out = torch.cat((max_out1, max_out2, max_out3), 1)  # (batch_size, num_kernels*out_channels)
        out = self.dropout(out)                             # (batch_size, num_kernels*out_channels)
        logits = self.label(out)                            # (batch_size, output_size)
        return logits
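
A minimal usage sketch follows. The hyperparameter values, the random weights tensor standing in for pre-trained GloVe vectors, and the random token batch are illustrative assumptions, not part of the original gist.

# Usage sketch: illustrative values; `weights` would normally hold pre-trained GloVe vectors.
vocab_size, embedding_dim, output_size = 10000, 300, 2
weights = torch.randn(vocab_size, embedding_dim)  # stand-in for pre-trained embeddings

model = TextCNN(
    batch_size=32,
    output_size=output_size,
    in_channels=1,
    out_channels=100,
    kernel_heights=[3, 4, 5],
    stride=1,
    padding=0,
    keep_probab=0.5,
    vocab_size=vocab_size,
    embedding_dim=embedding_dim,
    weights=weights,
)

tokens = torch.randint(0, vocab_size, (32, 50))  # (batch_size, num_seq) of token ids
logits = model(tokens)                           # (batch_size, output_size)
print(logits.shape)                              # torch.Size([32, 2])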