__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to
==================================================================================================
 input_1 (InputLayer)        [(None, 502)]                0         []
 input_2 (InputLayer)        [(None, 502)]                0         []
 encoder (Encoder)           (None, 502, 128)             2738688   ['input_1[0][0]']
 decoder (Decoder)           (None, 502, 128)             4884864   ['input_2[0][0]',
                                                                     'encoder[0][0]']
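In this summary both the source and target sequences are padded to 502 tokens; the encoder consumes the source ids, and the decoder attends over both the target ids and the encoder output. The summary can be reproduced by running one dummy batch through the model so Keras builds every layer. This is a sketch, assuming the model accepts a (source, target) pair, as the fit() data below suggests:

import numpy as np

# Build the model by running one dummy batch through it, then print the summary above
dummy_source = np.zeros((1, 502), dtype=np.int64)  # padded Spanish token ids
dummy_target = np.zeros((1, 502), dtype=np.int64)  # padded English token ids
transformer((dummy_source, dummy_target))
transformer.summary()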
# Train the model
transformer.fit(
    train_dataProvider,
    validation_data=val_dataProvider,
    epochs=configs.train_epochs,
    callbacks=[
        warmupCosineDecay,
        checkpoint,
        tb_callback,
        reduceLROnPlat,
        earlystopper,  # assumed; the original snippet is truncated after reduceLROnPlat
    ],
)
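After training, the best checkpoint written by ModelCheckpoint can be reloaded for inference. The following is a hypothetical sketch, not part of the original: whether Keras needs the custom classes listed in custom_objects depends on how Transformer, MaskedLoss, and MaskedAccuracy register themselves for serialization.

import tensorflow as tf

# Hypothetical reload of the best checkpoint saved by ModelCheckpoint;
# custom classes are passed so Keras can deserialize them
best_model = tf.keras.models.load_model(
    f"{configs.model_path}/model.h5",
    custom_objects={"Transformer": Transformer, "MaskedLoss": MaskedLoss, "MaskedAccuracy": MaskedAccuracy},
    compile=False,  # compile manually if further training is needed
)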
# Define callbacks (WarmupCosineDecay comes from the project's mltu utilities)
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard, ReduceLROnPlateau

warmupCosineDecay = WarmupCosineDecay(
    lr_after_warmup=configs.lr_after_warmup,
    final_lr=configs.final_lr,
    warmup_epochs=configs.warmup_epochs,
    decay_epochs=configs.decay_epochs,
    initial_lr=configs.init_lr,
)
earlystopper = EarlyStopping(monitor="val_masked_accuracy", patience=5, verbose=1, mode="max")
checkpoint = ModelCheckpoint(f"{configs.model_path}/model.h5", monitor="val_masked_accuracy", verbose=1, save_best_only=True, mode="max", save_weights_only=False)
# The original snippet is truncated here; tb_callback and reduceLROnPlat used by fit()
# are assumed to monitor the same metric:
tb_callback = TensorBoard(f"{configs.model_path}/logs")
reduceLROnPlat = ReduceLROnPlateau(monitor="val_masked_accuracy", factor=0.9, patience=3, verbose=1, mode="max")
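For intuition, a warmup-plus-cosine-decay schedule is commonly computed as below. This is a minimal sketch of the general technique, assuming a linear warmup followed by a cosine taper; mltu's WarmupCosineDecay may differ in details.

import math

def warmup_cosine_lr(epoch, initial_lr, lr_after_warmup, final_lr, warmup_epochs, decay_epochs):
    # Linear warmup from initial_lr up to lr_after_warmup
    if epoch < warmup_epochs:
        return initial_lr + (lr_after_warmup - initial_lr) * epoch / warmup_epochs
    # Cosine decay from lr_after_warmup down to final_lr over decay_epochs
    progress = min((epoch - warmup_epochs) / decay_epochs, 1.0)
    return final_lr + 0.5 * (lr_after_warmup - final_lr) * (1 + math.cos(math.pi * progress))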
import tensorflow as tf

optimizer = tf.keras.optimizers.Adam(learning_rate=configs.init_lr, beta_1=0.9, beta_2=0.98, epsilon=1e-9)

# Compile the model
transformer.compile(
    loss=MaskedLoss(),
    optimizer=optimizer,
    metrics=[MaskedAccuracy()],
    run_eagerly=False,
)
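MaskedLoss and MaskedAccuracy come from mltu and implement cross-entropy and accuracy that ignore padded positions, so the zero padding added in preprocess_inputs does not dominate the loss. A minimal sketch of the idea, assuming padding id 0 and logits output; mltu's exact implementation may differ:

import tensorflow as tf

def masked_loss(y_true, y_pred):
    # Per-token cross-entropy over the vocabulary (from_logits=True is an assumption)
    loss = tf.keras.losses.sparse_categorical_crossentropy(y_true, y_pred, from_logits=True)
    # Ignore positions where the target is the padding id (assumed to be 0)
    mask = tf.cast(tf.not_equal(y_true, 0), loss.dtype)
    return tf.reduce_sum(loss * mask) / tf.reduce_sum(mask)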
# Create TensorFlow Transformer Model
transformer = Transformer(
    num_layers=configs.num_layers,
    d_model=configs.d_model,
    num_heads=configs.num_heads,
    dff=configs.dff,
    input_vocab_size=len(tokenizer) + 1,    # +1 reserves an id for padding
    target_vocab_size=len(detokenizer) + 1,
    dropout_rate=configs.dropout_rate,
    encoder_input_size=tokenizer.max_length,
    decoder_input_size=detokenizer.max_length,  # assumed; the original snippet is truncated here
)
# Create Training Data Provider
train_dataProvider = DataProvider(
    train_dataset,
    batch_size=configs.batch_size,
    batch_postprocessors=[preprocess_inputs],
    use_cache=True,
)

# Create Validation Data Provider
val_dataProvider = DataProvider(
    val_dataset,
    batch_size=configs.batch_size,
    batch_postprocessors=[preprocess_inputs],
    use_cache=True,  # assumed to mirror the train provider; the snippet is truncated here
)
import numpy as np

def preprocess_inputs(data_batch, label_batch):
    encoder_input = np.zeros((len(data_batch), tokenizer.max_length)).astype(np.int64)
    decoder_input = np.zeros((len(label_batch), detokenizer.max_length)).astype(np.int64)
    decoder_output = np.zeros((len(label_batch), detokenizer.max_length)).astype(np.int64)
    data_batch_tokens = tokenizer.texts_to_sequences(data_batch)
    label_batch_tokens = detokenizer.texts_to_sequences(label_batch)
    for index, (data, label) in enumerate(zip(data_batch_tokens, label_batch_tokens)):
        encoder_input[index][:len(data)] = data
        # The original snippet is truncated here; the standard teacher-forcing shift is assumed:
        decoder_input[index][:len(label) - 1] = label[:-1]   # target shifted right (drop end token)
        decoder_output[index][:len(label) - 1] = label[1:]   # next-token targets (drop start token)
    return (encoder_input, decoder_input), decoder_output
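To make the teacher-forcing shift concrete: for a tokenized label the decoder is fed the sequence without the end token and must predict the sequence without the start token, i.e. the next character at every position. A hypothetical 4-token example:

label = [1, 7, 8, 2]        # hypothetical ids for [START, h, i, END]
decoder_in = label[:-1]     # [1, 7, 8] -- what the decoder sees
decoder_out = label[1:]     # [7, 8, 2] -- what it must predict, one step ahead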
# Prepare the Spanish tokenizer (input language)
tokenizer = CustomTokenizer(char_level=True)
tokenizer.fit_on_texts(es_training_data)
tokenizer.save(configs.model_path + "/tokenizer.json")

# Prepare the English tokenizer (output language)
detokenizer = CustomTokenizer(char_level=True)
detokenizer.fit_on_texts(en_training_data)
detokenizer.save(configs.model_path + "/detokenizer.json")
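With char_level=True each character maps to a single id, so a sentence becomes a sequence of character ids rather than word ids. A hypothetical illustration of the interface (the exact ids depend on the fitted vocabulary):

seq = tokenizer.texts_to_sequences(["hola"])[0]
print(seq)              # e.g. [23, 11, 4, 15] -- one id per character
print(len(tokenizer))   # vocabulary size, used for input_vocab_size above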
def read_files(path):
    # Read one sentence per line, dropping the trailing empty split
    with open(path, "r", encoding="utf-8") as f:
        dataset = f.read().split("\n")[:-1]
    return dataset

en_training_data = read_files(en_training_data_path)
en_validation_data = read_files(en_validation_data_path)
es_training_data = read_files(es_training_data_path)
es_validation_data = read_files(es_validation_data_path)
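The OPUS files are line-aligned, so each Spanish line pairs with the English line at the same index. A small optional sanity check (not in the original), plus a hypothetical construction of the train_dataset and val_dataset pairs the DataProviders above consume:

assert len(es_training_data) == len(en_training_data), "train files are not line-aligned"
assert len(es_validation_data) == len(en_validation_data), "dev files are not line-aligned"

# Hypothetical (source, target) pairing fed to the DataProviders above
train_dataset = [[es, en] for es, en in zip(es_training_data, en_training_data)]
val_dataset = [[es, en] for es, en in zip(es_validation_data, en_validation_data)]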
# Paths to the dataset files
en_training_data_path = "Datasets/en-es/opus.en-es-train.en"
en_validation_data_path = "Datasets/en-es/opus.en-es-dev.en"
es_training_data_path = "Datasets/en-es/opus.en-es-train.es"
es_validation_data_path = "Datasets/en-es/opus.en-es-dev.es"