from transformers import TFBertForSequenceClassification
import tensorflow as tf

# recommended learning rates for Adam: 5e-5, 3e-5, 2e-5
learning_rate = 2e-5
# a single epoch is used for illustration; more epochs may help as long as the model does not overfit
number_of_epochs = 1

# model initialization
model = TFBertForSequenceClassification.from_pretrained('bert-base-uncased')

# Adam is the recommended optimizer
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate, epsilon=1e-08)

# labels are integer class ids (not one-hot vectors), so use sparse categorical cross-entropy and accuracy
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')

model.compile(optimizer=optimizer, loss=loss, metrics=[metric])
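
As a minimal follow-up sketch (not part of the original gist), the compiled model could then be fine-tuned as below; the example texts, labels, and batch size are hypothetical placeholders used only to make the snippet self-contained.

from transformers import BertTokenizer

# hypothetical toy data, for illustration only
texts = ["I loved this movie!", "This was a waste of time."]
labels = [1, 0]

# tokenize into TensorFlow tensors with padding and truncation
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
encodings = tokenizer(texts, truncation=True, padding=True, return_tensors='tf')

# build a tf.data pipeline pairing the encoded inputs with the labels
train_dataset = tf.data.Dataset.from_tensor_slices((dict(encodings), labels)).batch(2)

# fine-tune using the optimizer, loss, and metric configured above
model.fit(train_dataset, epochs=number_of_epochs)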