""" | |
Download this gist as with wget: | |
!wget -O meter_classification_model.py https://gist.githubusercontent.com/MagedSaeed/fdb22182d4ccbaff73c989ece01e0661/raw/meter_classification_model.py | |
""" | |
import torch
import torch.nn.functional as F
from huggingface_hub import PyTorchModelHubMixin


class MeterClassificationModel(torch.nn.Module, PyTorchModelHubMixin):
    def __init__(
        self,
        device=None,
        vocab_size=77,
        num_layers=5,
        gru_hiddens=256,
        gru_dropout=0.25,
        dropout_prob=0.333,
        learning_rate=0.001,
        embedding_size=256,
        max_bait_length=128,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.num_layers = num_layers
        self.gru_hiddens = gru_hiddens
        self.dropout_prob = dropout_prob
        self.learning_rate = learning_rate
        self.embedding_size = embedding_size
        # the 16 classical Arabic meters plus prose, in Arabic alphabetical order
        self.meter_classes = [
            'البسيط',  # al-Basit
            'الخفيف',  # al-Khafif
            'الرجز',  # al-Rajaz
            'الرمل',  # al-Ramal
            'السريع',  # al-Sari'
            'الطويل',  # al-Tawil
            'الكامل',  # al-Kamil
            'المتدارك',  # al-Mutadarik
            'المتقارب',  # al-Mutaqarib
            'المجتث',  # al-Mujtath
            'المديد',  # al-Madid
            'المضارع',  # al-Mudari'
            'المقتضب',  # al-Muqtadab
            'المنسرح',  # al-Munsarih
            'الهزج',  # al-Hazaj
            'الوافر',  # al-Wafir
            'نثر',  # nathr (prose, i.e. not a meter)
        ]
        self.number_of_classes = len(self.meter_classes)
        self.max_bait_length = max_bait_length
        # map between class indices and meter names in both directions
        self.class_to_meter_name = lambda meter_class: self.meter_classes[meter_class]
        self.meter_name_to_class = lambda meter_name: self.meter_classes.index(meter_name)
        if device is None:
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        else:
            self.device = device
        self.embedding_layer = torch.nn.Embedding(
            num_embeddings=self.vocab_size,
            embedding_dim=self.embedding_size,
        )
        self.gru_layer = torch.nn.GRU(
            input_size=self.embedding_size,
            hidden_size=self.gru_hiddens,
            num_layers=self.num_layers,
            dropout=gru_dropout,
            batch_first=True,
            bidirectional=True,
        )
        # in_features stays gru_hiddens (not 2 * gru_hiddens) because the
        # forward pass sums the two GRU directions instead of concatenating
        self.first_dense_layer = torch.nn.Linear(
            in_features=self.gru_hiddens,
            out_features=128,
        )
        self.dropout_layer = torch.nn.Dropout(p=self.dropout_prob)
        self.relu = torch.nn.ReLU()
        self.second_dense_layer = torch.nn.Linear(
            in_features=128,
            out_features=self.number_of_classes,
        )
    def forward(self, x, hiddens=None):
        outputs = self.embedding_layer(x)
        # pass any provided initial hidden state through to the GRU
        outputs, hiddens = self.gru_layer(outputs, hiddens)
        # the GRU is bidirectional, so sum the forward and backward halves
        # of the output rather than keeping the 2 * gru_hiddens concatenation
        # https://stackoverflow.com/a/50914946/4412324
        outputs = outputs[:, :, : self.gru_hiddens] + outputs[:, :, self.gru_hiddens :]
        outputs = self.first_dense_layer(outputs)
        outputs = self.dropout_layer(outputs)
        outputs = self.relu(outputs)
        outputs = self.second_dense_layer(outputs)
        return outputs
    def classify(self, texts, tokenizer):
        if isinstance(texts, str):
            texts = [texts]
        encoded_items = tokenizer.batch_encode(
            texts,
            padding=True,
            truncation=True,
            max_length=self.max_bait_length,
        )
        inputs = torch.LongTensor(encoded_items).to(self.device)
        outputs = self(inputs)
        outputs = outputs[:, -1, :]  # take the logits at the last time-step
        outputs = F.softmax(outputs, dim=-1)
        confidences, outputs = torch.topk(outputs, k=1, dim=-1)
        outputs = outputs.squeeze(1).tolist()
        confidences = confidences.squeeze(1).tolist()
        outputs_classes = [self.class_to_meter_name(meter_class) for meter_class in outputs]
        return outputs, outputs_classes, confidences
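

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original gist). `classify` expects a
# tokenizer exposing `batch_encode(texts, padding=..., truncation=...,
# max_length=...)` that returns a rectangular list of token-id lists; the
# CharTokenizer below is a hypothetical stand-in for whatever tokenizer the
# gist is meant to be paired with, and the model is instantiated with random
# weights purely to exercise the pipeline. Trained weights would instead be
# loaded with MeterClassificationModel.from_pretrained(<hub repo id>), which
# PyTorchModelHubMixin provides.
# ---------------------------------------------------------------------------
class CharTokenizer:
    """Hypothetical character-level tokenizer matching the interface
    that `classify` expects."""

    def __init__(self, vocab):
        # reserve id 0 for padding and unknown characters
        self.char_to_id = {char: index + 1 for index, char in enumerate(vocab)}

    def batch_encode(self, texts, padding=True, truncation=True, max_length=128):
        encoded = [[self.char_to_id.get(char, 0) for char in text] for text in texts]
        if truncation:
            encoded = [ids[:max_length] for ids in encoded]
        if padding:
            width = max(len(ids) for ids in encoded)
            encoded = [ids + [0] * (width - len(ids)) for ids in encoded]
        return encoded


if __name__ == "__main__":
    model = MeterClassificationModel()
    model.to(model.device)
    model.eval()
    # illustrative character inventory; the real vocabulary (vocab_size=77)
    # is fixed by the tokenizer the published weights were trained with
    arabic_chars = "ء آ أ ؤ إ ئ ا ب ة ت ث ج ح خ د ذ ر ز س ش ص ض ط ظ ع غ ف ق ك ل م ن ه و ى ي".split()
    tokenizer = CharTokenizer(arabic_chars + [" "])
    with torch.no_grad():
        classes, names, confidences = model.classify(
            "قفا نبك من ذكرى حبيب ومنزل",  # opening hemistich of Imru' al-Qays's mu'allaqa
            tokenizer,
        )
    print(classes, names, confidences)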