from typing import List

import numpy as np
from torch import Tensor

def aug_test(self,
             imgs: List[Tensor],
             img_metas: List[dict],
             rescale: bool = False) -> Tensor:
    # Accumulators for boxes (x1, y1, x2, y2) and per-class scores across augmentations.
    # Note: boxes below are 4 columns (the score is split off), so the accumulator
    # must be (0, 4), not (0, 5), for np.vstack to work.
    acc_boxes = np.zeros((0, 4))
    acc_score = np.zeros((0, self.roi_head.bbox_head.num_classes))
    for img, img_meta in zip(imgs, img_metas):
        # simple_test returns, per image, one (N, 5) array of [x1, y1, x2, y2, score] per class
        for label, dets in enumerate(self.simple_test(img, img_meta, None, rescale)[0]):
            boxes, scores = dets[:, :-1], dets[:, -1]
            acc_boxes = np.vstack((acc_boxes, boxes))
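The preview cuts off inside the inner loop, before the scores are accumulated or the results merged. A hedged sketch of how the score side could be handled; the helper name and the final NMS note are assumptions, not the gist's verified code:

import numpy as np

def scatter_class_scores(scores: np.ndarray, label: int, num_classes: int) -> np.ndarray:
    """Hypothetical helper: place per-class (scores) into column (label) of a dense matrix."""
    dense = np.zeros((len(scores), num_classes))
    dense[:, label] = scores
    return dense

# Inside the inner loop, the accumulation could then continue as:
#   acc_score = np.vstack((acc_score, scatter_class_scores(scores, label, acc_score.shape[1])))
# followed, after both loops, by a per-class NMS over acc_boxes / acc_score
# to fuse the detections from all augmentations before returning.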
## Imports
from typing import Tuple

import torch
from torch import Tensor
from torch.nn import Module, CrossEntropyLoss, CosineEmbeddingLoss
from transformers.models.roberta.modeling_roberta import (
    RobertaPreTrainedModel, RobertaConfig, RobertaModel, RobertaEncoder
)

## Function
import torch
from torch import Tensor
from torch.nn import CrossEntropyLoss, CosineEmbeddingLoss

def distillation_loss(
    teacher_logits: Tensor,
    student_logits: Tensor,
    labels: Tensor,
    temperature: float = 1.0,
) -> Tensor:
    """
    Computes the distillation loss between (teacher_logits) and (student_logits), given the hard (labels).
    """
    # The preview stops at the docstring; the body below is an assumed DistilBERT-style
    # reconstruction (hard loss + soft loss + cosine term), not the author's verified code.
    hard_loss = CrossEntropyLoss()(student_logits, labels)  # student vs. ground-truth labels
    soft_loss = CrossEntropyLoss()(student_logits / temperature, (teacher_logits / temperature).softmax(dim=1))  # student vs. temperature-softened teacher
    cosine_loss = CosineEmbeddingLoss()(teacher_logits, student_logits, torch.ones(teacher_logits.size(0)))  # align logit directions (target 1 = similar)
    return (hard_loss + soft_loss + cosine_loss) / 3  # equal-weight average of the three terms
from torch import Tensor
from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel

def get_logits(
    model: RobertaPreTrainedModel,
    input_ids: Tensor,
    attention_mask: Tensor,
) -> Tensor:
    """
    Given a RoBERTa (model) for classification and the couple of (input_ids) and (attention_mask),
    returns the logits corresponding to the prediction.
    """
    # The preview stops at the docstring; the one-liner below is the natural body it describes.
    return model(input_ids=input_ids, attention_mask=attention_mask).logits
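A usage sketch: tokenize a sentence and fetch the logits. The checkpoint name and label count are illustrative, not from the gist:

from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Illustrative checkpoint; any RobertaForSequenceClassification works the same way
tokenizer = AutoTokenizer.from_pretrained("roberta-base")
model = AutoModelForSequenceClassification.from_pretrained("roberta-base", num_labels=2)
batch = tokenizer(["Distillation shrinks models."], return_tensors="pt")
logits = get_logits(model, batch["input_ids"], batch["attention_mask"])
print(logits.shape)  # torch.Size([1, 2])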
from transformers.models.roberta.modeling_roberta import RobertaEncoder, RobertaModel
from torch.nn import Module

def distill_roberta_weights(
    teacher: Module,
    student: Module,
) -> None:
    """
    Recursively copies the weights of the (teacher) to the (student).
    This function is meant to be first called on a RobertaFor... model, and is then called recursively on every child of that model.
    """
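The preview ends at the docstring. Below is a sketch of a body consistent with the docstring and the imports above; the recursion structure and the 2*i layer indexing mirror the DistilBERT recipe and are assumptions, not the gist's verified code:

def distill_roberta_weights(teacher: Module, student: Module) -> None:
    # Wrapper modules (RobertaModel, RobertaFor...): recurse into matching children
    if isinstance(teacher, RobertaModel) or type(teacher).__name__.startswith("RobertaFor"):
        for teacher_part, student_part in zip(teacher.children(), student.children()):
            distill_roberta_weights(teacher_part, student_part)
    # The encoder: copy one out of two teacher layers into each student layer
    elif isinstance(teacher, RobertaEncoder):
        for i, student_layer in enumerate(student.layer):
            student_layer.load_state_dict(teacher.layer[2 * i].state_dict())
    # Leaf modules (embeddings, heads, ...): copy the weights as-is
    else:
        student.load_state_dict(teacher.state_dict())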
from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel, RobertaConfig

def distill_roberta(
    teacher_model: RobertaPreTrainedModel,
) -> RobertaPreTrainedModel:
    """
    Distills a RoBERTa (teacher_model), as DistilBERT does for a BERT model.
    The student model has the same configuration, except for the number of hidden layers, which is halved (// 2).
    The student layers are initialized by copying one out of two layers of the teacher, starting with layer 0.
    The head of the teacher is also copied.
    """
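Again the preview stops at the docstring. A body matching it, assuming the distill_roberta_weights function from the previous snippet; the config round-trip through a dict is a common pattern, so treat this as a sketch rather than the gist's code:

def distill_roberta(teacher_model: RobertaPreTrainedModel) -> RobertaPreTrainedModel:
    # Copy the teacher's configuration, halving the number of hidden layers
    configuration = teacher_model.config.to_dict()
    configuration["num_hidden_layers"] //= 2
    configuration = RobertaConfig.from_dict(configuration)
    # Build an uninitialized student of the same class, then copy the weights over
    student_model = type(teacher_model)(configuration)
    distill_roberta_weights(teacher=teacher_model, student=student_model)
    return student_model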
from typing import Any
from transformers import AutoModelForMaskedLM

roberta = AutoModelForMaskedLM.from_pretrained("roberta-large")

def visualize_children(
    object: Any,
    level: int = 0,
) -> None:
    """
    Prints the tree of children of (object), indenting each child by its nesting (level).
    """
    # The preview stops at the docstring; the body below is the obvious recursive print.
    print(f"{'   ' * level}{type(object).__name__}")
    if hasattr(object, "children"):
        for child in object.children():
            visualize_children(child, level + 1)
from transformers import AutoModelForMaskedLM

roberta = AutoModelForMaskedLM.from_pretrained("roberta-large")
print(roberta)
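A companion check, not in the gist, that makes the case for distillation concrete: counting the teacher's parameters.

# Hypothetical addition: roberta-large carries roughly 355M parameters,
# which is what makes halving the encoder depth attractive.
n_params = sum(p.numel() for p in roberta.parameters())
print(f"{n_params / 1e6:.0f}M parameters")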
import matplotlib.pyplot as plt
import seaborn as sns

def average_word_count(list_of_texts):
    """
    Returns the average word count of a list of texts.
    """
    total_count = 0
    for text in list_of_texts:
        text = text.replace("'", ' ')  # split contractions so "don't" counts as two words
        total_count += len(text.split())  # reconstructed: the preview stops at the line above
    return total_count / len(list_of_texts)
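The matplotlib/seaborn imports suggest the gist goes on to plot text statistics. A minimal sketch of what that could look like; the corpus variable is hypothetical:

# Hypothetical usage: distribution of word counts over a small corpus
corpus = ["RoBERTa is large.", "Distillation makes it smaller and faster."]
word_counts = [len(t.replace("'", " ").split()) for t in corpus]
sns.histplot(word_counts)
plt.xlabel("words per text")
plt.title(f"average = {average_word_count(corpus):.1f}")
plt.show()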