def train(net, criterion, opti, train_loader, val_loader, args):
    for ep in range(args.max_eps):
        for it, (seq, attn_masks, labels) in enumerate(train_loader):
            #Clear gradients
            opti.zero_grad()
            #Converting these to cuda tensors
            seq, attn_masks, labels = seq.cuda(args.gpu), attn_masks.cuda(args.gpu), labels.cuda(args.gpu)
            #Remaining steps of the loop (a sketch of standard fine-tuning, not the original code):
            #forward pass, loss, backpropagation, parameter update
            logits = net(seq, attn_masks)
            loss = criterion(logits.squeeze(-1), labels.float())
            loss.backward()
            opti.step()

import torch.nn as nn
import torch.optim as optim

criterion = nn.BCEWithLogitsLoss()
opti = optim.Adam(net.parameters(), lr = 2e-5)

net = SentimentClassifier(freeze_bert = True)

from torch.utils.data import DataLoader

#Creating instances of training and validation set
train_set = SSTDataset(filename = 'data/SST-2/train.tsv', maxlen = 30)
val_set = SSTDataset(filename = 'data/SST-2/dev.tsv', maxlen = 30)
#Creating instances of training and validation dataloaders
train_loader = DataLoader(train_set, batch_size = 64, num_workers = 5)
val_loader = DataLoader(val_set, batch_size = 64, num_workers = 5)

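With the loaders, loss, optimizer, and model from the snippets above, the train function at the top of the page can be invoked roughly as follows. This is a hedged sketch: the args fields max_eps and gpu are assumptions inferred from how train uses them, and the epoch count and device index are placeholders.

from types import SimpleNamespace

#Hypothetical args object; only the fields train() actually reads are filled in
args = SimpleNamespace(max_eps = 2, gpu = 0)
net.cuda(args.gpu)  #move the classifier to the same GPU the batches are sent to
train(net, criterion, opti, train_loader, val_loader, args)
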
import torch
import torch.nn as nn
from transformers import BertModel

class SentimentClassifier(nn.Module):

    def __init__(self, freeze_bert = True):
        super(SentimentClassifier, self).__init__()
        #Instantiating BERT model object
        self.bert_layer = BertModel.from_pretrained('bert-base-uncased')

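The snippet above cuts off right after the BERT backbone is created; the rest of the class is not shown on this page. It would typically freeze BERT's parameters when freeze_bert is True, add a single-logit linear head, and pool the [CLS] token in forward, which is what the BCEWithLogitsLoss above expects. The self-contained sketch below illustrates that pooling step; it is an assumption rather than the author's exact code, and it presumes a transformers version whose BertModel returns an output object with last_hidden_state.

import torch
from transformers import BertModel

bert = BertModel.from_pretrained('bert-base-uncased')
head = torch.nn.Linear(768, 1)  #single sentiment logit on the 768-dim [CLS] representation

#Token IDs and attention mask for one padded sentence (the IDs printed in the walkthrough below)
ids = torch.tensor([[101, 1045, 2428, 5632, 2023, 3185, 1037, 2843, 1012, 102, 0, 0]])
mask = (ids != 0).long()

out = bert(ids, attention_mask = mask)
cls_rep = out.last_hidden_state[:, 0]  #(1, 768) hidden state of the [CLS] token
logit = head(cls_rep)                  #(1, 1) raw score suitable for BCEWithLogitsLoss
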
import torch
from torch.utils.data import Dataset
from transformers import BertTokenizer
import pandas as pd

class SSTDataset(Dataset):

    def __init__(self, filename, maxlen):
        #Store the contents of the file in a pandas dataframe
        self.df = pd.read_csv(filename, delimiter = '\t')
        #The rest of the constructor is a sketch, not the original code:
        #keep a BERT tokenizer and the maximum sequence length around for __getitem__
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.maxlen = maxlen

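The Dataset's __len__ and __getitem__ are not shown on this page. Per-example preprocessing usually looks something like the self-contained sketch below; the helper name encode_example is hypothetical and the exact steps are assumptions rather than the author's code. It produces the (seq, attn_masks, labels) triples that the training loop above consumes.

import torch
from transformers import BertTokenizer

def encode_example(sentence, label, tokenizer, maxlen):
    #Tokenize and add the [CLS] / [SEP] special tokens BERT expects
    tokens = ['[CLS]'] + tokenizer.tokenize(sentence) + ['[SEP]']
    #Pad with [PAD] (or truncate) to a fixed length so examples can be batched
    if len(tokens) < maxlen:
        tokens = tokens + ['[PAD]'] * (maxlen - len(tokens))
    else:
        tokens = tokens[:maxlen - 1] + ['[SEP]']
    seq = torch.tensor(tokenizer.convert_tokens_to_ids(tokens))
    #Attention mask: 1 for real tokens, 0 for [PAD]
    attn_mask = (seq != 0).long()
    return seq, attn_mask, torch.tensor(label)

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
print(encode_example('I really enjoyed this movie a lot.', 1, tokenizer, 30))
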
seg_ids = [0 for _ in range(len(padded_tokens))] #Since we only have a single sequence as input

import torch
from transformers import BertModel, BertTokenizer

#Creating instance of BertModel
bert_model = BertModel.from_pretrained('bert-base-uncased')
#Creating instance of tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
#Specifying the max length
T = 12

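The step that builds padded_tokens is not among these snippets. Roughly, it tokenizes an example sentence, wraps it in [CLS] and [SEP], and pads it with [PAD] up to the maximum length T. The sketch below reconstructs that step from the token IDs printed at the bottom of the page, so the exact sentence and variable choices are assumptions.

sentence = 'I really enjoyed this movie a lot.'
tokens = ['[CLS]'] + tokenizer.tokenize(sentence) + ['[SEP]']
padded_tokens = tokens + ['[PAD]'] * (T - len(tokens))
print(padded_tokens)
# Expected: ['[CLS]', 'i', 'really', 'enjoyed', 'this', 'movie', 'a', 'lot', '.', '[SEP]', '[PAD]', '[PAD]']
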
from transformers import BertModel

bert_model = BertModel.from_pretrained('bert-base-uncased')

# Obtaining indices for each token
sent_ids = tokenizer.convert_tokens_to_ids(padded_tokens)
print(sent_ids)
# Out: [101, 1045, 2428, 5632, 2023, 3185, 1037, 2843, 1012, 102, 0, 0]
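From here the walkthrough would typically turn the token IDs, segment IDs, and an attention mask into tensors and push them through bert_model to get contextual embeddings. The sketch below continues the snippets above and is an assumption; it presumes a transformers version whose BertModel returns an output object with last_hidden_state and pooler_output (older versions return a plain tuple instead).

token_ids = torch.tensor(sent_ids).unsqueeze(0)        #shape (1, 12)
attn_mask = (token_ids != 0).long()                    #1 for real tokens, 0 for [PAD]
token_type_ids = torch.tensor(seg_ids).unsqueeze(0)    #all zeros, since there is a single sequence

output = bert_model(token_ids, attention_mask = attn_mask, token_type_ids = token_type_ids)
hidden_reps = output.last_hidden_state                 #(1, 12, 768) contextual embedding of each token
cls_rep = output.pooler_output                         #(1, 768) pooled representation of the [CLS] token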