Created October 22, 2019 00:41
import torch
import torch.nn as nn
import torchvision.models as models

# BertEmbeddings and BertEncoder live in the BERT modeling module; recent
# transformers versions expose them as below (older versions used
# transformers.modeling_bert).
from transformers.models.bert.modeling_bert import BertEmbeddings, BertEncoder

resnet18 = models.resnet18()  # or your favorite vision model
class MMBDEmbeddings(nn.Module):
    def __init__(self,
                 text_mod_embds,    # e.g. BertEmbeddings(config), or your favorite bidirectional transformer's embeddings
                 vision_mod_embds,  # e.g. resnet18, or your favorite vision model
                 vision_dim=1000,   # output dimension of the vision model (1000 for torchvision's resnet18)
                 text_dim=768):     # hidden dimension of the text model (768 for bert-base)
        super(MMBDEmbeddings, self).__init__()
        self.text_mod_embds = text_mod_embds
        self.vision_mod_embds = vision_mod_embds
        self.vision_to_text_proj = nn.Linear(vision_dim, text_dim)

    def forward(self, input_ids, images):
        image_embds = self.vision_mod_embds(images)
        proj_image_embds = self.vision_to_text_proj(image_embds)
        token_embds = self.text_mod_embds(input_ids)
        return {'image': proj_image_embds, 'text': token_embds}
class MMBDModel(nn.Module):
    def __init__(self,
                 embeddings,  # an MMBDEmbeddings object
                 encoder,     # a BertEncoder for instance
                 pooler):     # can be as simple as taking the [CLS] hidden state
        super(MMBDModel, self).__init__()
        self.embeddings = embeddings
        self.encoder = encoder
        self.pooler = pooler

    def forward(self,
                input_ids,
                images):  # and other arguments such as attention_mask, token_type_ids, etc. (see the encoder)
        embds = self.embeddings(input_ids, images)
        # Concatenate the two sequences of embeddings along the sequence axis.
        image_embds = embds['image']
        if image_embds.dim() == 2:
            image_embds = image_embds.unsqueeze(1)  # one vector per image -> a one-token sequence
        embds_seq = torch.cat([image_embds, embds['text']], dim=1)
        hidden_states = self.encoder(embds_seq)[0]  # the encoder's first output is the sequence of hidden states
        pooled_output = self.pooler(hidden_states)
        outputs = (hidden_states, pooled_output)
        return outputs
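# The pooler is left unspecified above; as a hypothetical sketch (not part of
# the original gist), it can simply take the hidden state of the first token of
# the concatenated image + text sequence, in the spirit of BERT's [CLS] pooling.
class FirstTokenPooler(nn.Module):
    def forward(self, hidden_states):
        # hidden_states: (batch, seq_len, hidden_dim) -> (batch, hidden_dim)
        return hidden_states[:, 0]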
class MMBDForMultiModalClassification(nn.Module):
    def __init__(self,
                 mmbd_model):  # an MMBDModel object
        super(MMBDForMultiModalClassification, self).__init__()
        self.mmbd_model = mmbd_model
        self.classification_head = nn.Linear(768, 2)  # for instance, for a binary classification task

    def forward(self,
                input_ids,
                images):
        _, pooled_output = self.mmbd_model(input_ids, images)
        return self.classification_head(pooled_output)
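# Hypothetical usage sketch (not part of the original gist): wiring the classes
# together with bert-base-uncased embeddings/encoder, torchvision's resnet18,
# and the FirstTokenPooler defined above. The image size and dummy inputs are
# assumptions for illustration only.
from transformers import BertConfig, BertTokenizer

config = BertConfig.from_pretrained('bert-base-uncased')
embeddings = MMBDEmbeddings(text_mod_embds=BertEmbeddings(config),
                            vision_mod_embds=resnet18,
                            vision_dim=1000,
                            text_dim=config.hidden_size)
model = MMBDModel(embeddings, BertEncoder(config), FirstTokenPooler())
classifier = MMBDForMultiModalClassification(model)

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
input_ids = torch.tensor([tokenizer.encode("a photo of a cat")])
images = torch.randn(1, 3, 224, 224)  # dummy batch of one RGB image
logits = classifier(input_ids, images)  # shape (1, 2)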