from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel, RobertaConfig
def distill_roberta(
    teacher_model: RobertaPreTrainedModel,
) -> RobertaPreTrainedModel:
    """
    Distils a RoBERTa model (teacher_model) the way DistilBERT was distilled from BERT.
    The student model has the same configuration, except for the number of hidden layers, which is halved.
    The student layers are initialized by copying one out of two layers of the teacher, starting with layer 0.
    The head of the teacher is also copied.
    """
    # Get the teacher configuration as a dictionary
    configuration = teacher_model.config.to_dict()
    # Halve the number of hidden layers
    configuration['num_hidden_layers'] //= 2
    # Convert the dictionary back into a student configuration
    configuration = RobertaConfig.from_dict(configuration)
    # Create the uninitialized student model
    student_model = type(teacher_model)(configuration)
    # Initialize the student's weights by copying from the teacher
    distill_roberta_weights(teacher=teacher_model, student=student_model)
    # Return the student model
    return student_model
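
# ---------------------------------------------------------------------------
# The helper distill_roberta_weights is defined in a companion part of the gist.
# Below is a minimal, hedged sketch of what that weight copy could look like,
# assuming a RobertaForMaskedLM-style layout (roberta.embeddings,
# roberta.encoder.layer, lm_head); it is illustrative, not the gist's own code.

def distill_roberta_weights_sketch(
    teacher: RobertaPreTrainedModel,
    student: RobertaPreTrainedModel,
) -> None:
    # Copy the embeddings unchanged
    student.roberta.embeddings.load_state_dict(teacher.roberta.embeddings.state_dict())
    # Copy one teacher encoder layer out of two, starting with layer 0
    for i, student_layer in enumerate(student.roberta.encoder.layer):
        student_layer.load_state_dict(teacher.roberta.encoder.layer[2 * i].state_dict())
    # Copy the teacher's head (here, the MLM head)
    student.lm_head.load_state_dict(teacher.lm_head.state_dict())

# Example usage (the checkpoint name is illustrative):
# from transformers import RobertaForMaskedLM
# teacher = RobertaForMaskedLM.from_pretrained('roberta-base')
# student = distill_roberta(teacher)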