Created October 8, 2019 14:34
import torch
from transformers import BertModel, BertTokenizer

# Creating an instance of BertModel
bert_model = BertModel.from_pretrained('bert-base-uncased')

# Creating an instance of the tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Specifying the maximum sequence length
T = 12

sentence = 'I really enjoyed this movie a lot.'

# Step 1: Tokenize
tokens = tokenizer.tokenize(sentence)

# Step 2: Add [CLS] and [SEP]
tokens = ['[CLS]'] + tokens + ['[SEP]']

# Step 3: Pad the tokens to length T and build the attention mask
padded_tokens = tokens + ['[PAD]' for _ in range(T - len(tokens))]
attn_mask = [1 if token != '[PAD]' else 0 for token in padded_tokens]

# Step 4: Segment ids (all zeros, since there is only one segment)
seg_ids = [0 for _ in range(len(padded_tokens))]  # Optional for single-sentence inputs

# Step 5: Get the BERT vocabulary index for each token
token_ids = tokenizer.convert_tokens_to_ids(padded_tokens)

# Converting everything to torch tensors before feeding them to bert_model
token_ids = torch.tensor(token_ids).unsqueeze(0)  # Shape: [1, 12]
attn_mask = torch.tensor(attn_mask).unsqueeze(0)  # Shape: [1, 12]
seg_ids = torch.tensor(seg_ids).unsqueeze(0)      # Shape: [1, 12]

# Feed them to BERT
hidden_reps, cls_head = bert_model(token_ids, attention_mask=attn_mask,
                                   token_type_ids=seg_ids)

print(hidden_reps.shape)
# Out: torch.Size([1, 12, 768])
print(cls_head.shape)
# Out: torch.Size([1, 768])
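As an aside (not part of the original gist): a minimal sketch, assuming a recent transformers release (>= 3.0), where the tokenizer can perform steps 1-5 in a single call:

# One-call equivalent of steps 1-5: tokenize, add [CLS]/[SEP], pad to max_length,
# and build the attention mask and token type ids
encoded = tokenizer(sentence, padding='max_length', max_length=T, return_tensors='pt')
token_ids = encoded['input_ids']        # Shape: [1, 12]
attn_mask = encoded['attention_mask']   # Shape: [1, 12]
seg_ids = encoded['token_type_ids']     # Shape: [1, 12]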
Hello Kabir, thank you for this tutorial. I've tried running this example but encountered an issue where the hidden_reps variable returns the string last_hidden_state instead of a tensor.
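A likely cause, assuming the example is run with transformers 4.x: BertModel now returns a ModelOutput object by default, so tuple-unpacking the result yields the field names (strings such as 'last_hidden_state') rather than the tensors themselves. A minimal sketch of a fix, reusing the tensors defined in the gist above, is to access the output by attribute, or to request the legacy tuple format with return_dict=False:

# Access the fields of the ModelOutput directly
outputs = bert_model(token_ids, attention_mask=attn_mask, token_type_ids=seg_ids)
hidden_reps = outputs.last_hidden_state  # Shape: [1, 12, 768]
cls_head = outputs.pooler_output         # Shape: [1, 768]

# Alternatively, keep the original unpacking by asking for plain tuples
hidden_reps, cls_head = bert_model(token_ids, attention_mask=attn_mask,
                                   token_type_ids=seg_ids, return_dict=False)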