@remi-or
Created January 18, 2022 17:36
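
from torch import Tensor
from transformers import RobertaPreTrainedModel


def get_logits(
    model: RobertaPreTrainedModel,
    input_ids: Tensor,
    attention_mask: Tensor,
) -> Tensor:
    """
    Given a RoBERTa (model) for sequence classification (e.g. a
    RobertaForSequenceClassification) and the pair (input_ids, attention_mask),
    return the logits corresponding to its prediction.
    """
    # Run the encoder; element 0 of its output is the last hidden state,
    # of shape (batch_size, sequence_length, hidden_size).
    hidden_states = model.roberta(input_ids=input_ids, attention_mask=attention_mask)[0]
    # The classification head pools the first (<s>) token and maps it to logits.
    return model.classifier(hidden_states)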
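
As a sanity check, the sketch below builds a tiny, randomly initialised RobertaForSequenceClassification and verifies that calling `model.roberta` followed by `model.classifier` reproduces the logits of a regular forward call. The config sizes are arbitrary assumptions chosen so that nothing is downloaded, and the check assumes the transformers layout in which the classification model exposes a `.roberta` encoder and a `.classifier` head. `compare_with_forward` is a hypothetical helper written only for this illustration.

```python
import torch
from torch import Tensor
from transformers import RobertaConfig, RobertaForSequenceClassification


def get_logits(model, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
    # Same decomposition as get_logits above: encoder, then classification head.
    hidden_states = model.roberta(input_ids=input_ids, attention_mask=attention_mask)[0]
    return model.classifier(hidden_states)


def compare_with_forward(seed: int = 0) -> tuple[Tensor, Tensor]:
    # Hypothetical helper: returns (full forward logits, split logits).
    torch.manual_seed(seed)
    # Tiny, randomly initialised config (illustrative sizes; nothing is downloaded).
    config = RobertaConfig(
        vocab_size=100,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=64,
        max_position_embeddings=64,
        num_labels=3,
    )
    # eval() disables dropout so both passes are deterministic.
    model = RobertaForSequenceClassification(config).eval()
    # Token ids start at 3 to avoid the default special ids 0-2 (<s>, <pad>, </s>).
    input_ids = torch.randint(3, config.vocab_size, (2, 8))
    attention_mask = torch.ones_like(input_ids)
    with torch.no_grad():
        full = model(input_ids=input_ids, attention_mask=attention_mask).logits
        split = get_logits(model, input_ids, attention_mask)
    return full, split


full, split = compare_with_forward()
print(full.shape)  # torch.Size([2, 3])
print(torch.allclose(full, split))
```

If the two results ever diverge, check that the model is in eval mode: both the encoder and the classification head apply dropout during training.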