import torch
from transformers import BertModel, BertPreTrainedModel


class SentenceEncoder(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config)
        self.init_weights()

    def forward(self, input_ids, attention_mask):
        model_output = self.bert(input_ids, attention_mask=attention_mask)
        # First element of model_output holds the per-token embeddings
        # (batch, sequence, hidden).
        token_embeddings = model_output[0]
        # Mean pooling: average the token embeddings, using the attention
        # mask to exclude padding positions from both the sum and the count.
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return sum_embeddings / sum_mask

s = SentenceEncoder.from_pretrained("sentence-transformers/bert-base-nli-mean-tokens")
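
# Optional sanity check (a sketch, not in the original gist): tokenize a
# sample sentence and confirm the pooled embedding shape. Assumes the model
# repository on the Hugging Face hub ships a compatible tokenizer.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/bert-base-nli-mean-tokens")
encoded = tokenizer("a sample query", return_tensors="pt")
with torch.no_grad():
    embedding = s(encoded["input_ids"], encoded["attention_mask"])
print(embedding.shape)  # expected: (1, 768) for a bert-base model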
input_names = ["input_ids", "attention_mask"]
output_names = ["contextual"]
# Dummy inputs for tracing: batch size 1, up to 32 query terms.
input_ids = torch.ones(1, 32, dtype=torch.int64)
attention_mask = torch.ones(1, 32, dtype=torch.int64)
args = (input_ids, attention_mask)
# Only the batch dimension is marked dynamic; the sequence length stays
# fixed at 32 from the dummy inputs above.
torch.onnx.export(s,
                  args=args,
                  f="sentence_mean_encoder.onnx",
                  input_names=input_names,
                  output_names=output_names,
                  dynamic_axes={
                      "input_ids": {0: "batch"},
                      "attention_mask": {0: "batch"},
                      "contextual": {0: "batch"},
                  },
                  opset_version=11)
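
# Verification sketch (not in the original gist): run the exported model with
# onnxruntime and compare it against the PyTorch output. Assumes the
# onnxruntime package is installed.
import numpy
import onnxruntime

session = onnxruntime.InferenceSession("sentence_mean_encoder.onnx",
                                       providers=["CPUExecutionProvider"])
onnx_embeddings = session.run(output_names,
                              {"input_ids": input_ids.numpy(),
                               "attention_mask": attention_mask.numpy()})[0]
with torch.no_grad():
    torch_embeddings = s(input_ids, attention_mask).numpy()
print(numpy.allclose(onnx_embeddings, torch_embeddings, atol=1e-5))  # expect True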