from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel, RobertaConfig

def distill_roberta(
    teacher_model: RobertaPreTrainedModel,
) -> RobertaPreTrainedModel:
    """
    Distills a RoBERTa model (teacher_model) the way DistilBERT was distilled from BERT.
    The student model has the same configuration, except for the number of hidden layers, which is halved.
    The student layers are initialized by copying one out of two layers of the teacher, starting with layer 0.
    The head of the teacher is also copied.
    """
    # Get the teacher configuration as a dictionary
    configuration = teacher_model.config.to_dict()
    # Halve the number of hidden layers
    configuration['num_hidden_layers'] //= 2
    # Convert the dictionary back to a student configuration
    configuration = RobertaConfig.from_dict(configuration)
    # Create an uninitialized student model of the same class as the teacher
    student_model = type(teacher_model)(configuration)
    # Initialize the student's weights from the teacher's
    distill_roberta_weights(teacher=teacher_model, student=student_model)
    # Return the student model
    return student_model
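The helper distill_roberta_weights is defined elsewhere in the gist. As a rough illustration of what the docstring describes (student layer i receives teacher layer 2*i, everything else copied as-is), here is a minimal sketch; the remapping logic and regex-based key rewriting are assumptions, not the author's implementation.

import re
from torch import nn

def distill_roberta_weights(teacher: nn.Module, student: nn.Module) -> None:
    """
    Hypothetical sketch: copy the teacher's weights into the student.
    Encoder layers are remapped so student layer i gets teacher layer 2*i;
    all other weights (embeddings, pooler, task head) are copied unchanged.
    """
    layer_pattern = re.compile(r"encoder\.layer\.(\d+)\.")
    n_student_layers = student.config.num_hidden_layers
    student_state = {}
    for name, weight in teacher.state_dict().items():
        match = layer_pattern.search(name)
        if match is not None:
            teacher_layer = int(match.group(1))
            # Keep only even-numbered teacher layers: 0, 2, 4, ... -> 0, 1, 2, ...
            if teacher_layer % 2 != 0:
                continue
            student_layer = teacher_layer // 2
            if student_layer >= n_student_layers:
                continue
            name = layer_pattern.sub(f"encoder.layer.{student_layer}.", name)
        student_state[name] = weight.clone()
    student.load_state_dict(student_state)

Taking every other layer rather than, say, the first half mirrors DistilBERT's initialization and keeps layers from different depths of the teacher. A quick way to try it (assuming a masked-LM checkpoint): load RobertaForMaskedLM.from_pretrained("roberta-base") as the teacher and call distill_roberta on it to get a 6-layer student.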