This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def train(net, criterion, opti, train_loader, val_loader, args): | |
for ep in range(args.max_eps): | |
for it, (seq, attn_masks, labels) in enumerate(train_loader): | |
#Clear gradients | |
opti.zero_grad() | |
#Converting these to cuda tensors | |
seq, attn_masks, labels = seq.cuda(args.gpu), attn_masks.cuda(args.gpu), labels.cuda(args.gpu) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch.nn as nn | |
import torch.optim as optim | |
# Loss on raw model scores: BCEWithLogitsLoss applies the sigmoid internally,
# so `net` is presumably expected to return unnormalized logits — TODO confirm.
criterion = nn.BCEWithLogitsLoss()

# Adam over every parameter exposed by `net`, learning rate 2e-5.
# NOTE(review): if BERT is frozen (freeze_bert=True), its parameters are still
# handed to Adam; confirm that is intended.
opti = optim.Adam(params=net.parameters(), lr=2e-5)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Build the sentiment model, explicitly requesting a frozen BERT encoder.
# NOTE(review): the freezing logic itself is not visible in this snippet.
net = SentimentClassifier(freeze_bert=True)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torch.utils.data import DataLoader | |
# Creating instances of the training and validation sets.
# NOTE(review): `maxlen=30` is presumably the padded/truncated token length
# per sentence; confirm against SSTDataset.
train_set = SSTDataset(filename='data/SST-2/train.tsv', maxlen=30)
val_set = SSTDataset(filename='data/SST-2/dev.tsv', maxlen=30)

# Creating instances of the training and validation dataloaders.
# The training loader must reshuffle every epoch: DataLoader defaults to
# shuffle=False, which would feed identical, file-ordered mini-batches each
# epoch. Validation order does not affect evaluation, so it stays sequential.
train_loader = DataLoader(train_set, batch_size=64, shuffle=True, num_workers=5)
val_loader = DataLoader(val_set, batch_size=64, num_workers=5)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
from transformers import BertModel | |
class SentimentClassifier(nn.Module): | |
def __init__(self, freeze_bert = True): | |
super(SentimentClassifier, self).__init__() | |
#Instantiating BERT model object | |
self.bert_layer = BertModel.from_pretrained('bert-base-uncased') |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch.utils.data import Dataset | |
from transformers import BertTokenizer | |
import pandas as pd | |
class SSTDataset(Dataset): | |
def __init__(self, filename, maxlen): | |
#Store the contents of the file in a pandas dataframe |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# One segment id per (padded) token, all 0 since there is only a single
# sequence as input. `[0] * n` is the idiomatic constant-fill; the
# comprehension it replaces ignored its loop variable.
seg_ids = [0] * len(padded_tokens)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from transformers import BertModel, BertTokenizer | |
# Load the pretrained 'bert-base-uncased' encoder and the tokenizer built
# for the same checkpoint, so token ids match the model's vocabulary.
bert_model = BertModel.from_pretrained('bert-base-uncased')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Specifying the max length (number of tokens) for the example sentence.
T = 12
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from transformers import BertModel | |
# Pretrained BERT-base (uncased) encoder loaded from the named checkpoint.
bert_model = BertModel.from_pretrained('bert-base-uncased')
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Look up the vocabulary index of every (padded) token and show the result.
sent_ids = tokenizer.convert_tokens_to_ids(padded_tokens)
print(sent_ids)
# Out: [101, 1045, 2428, 5632, 2023, 3185, 1037, 2843, 1012, 102, 0, 0]
NewerOlder