Last active
August 31, 2021 18:02
Using the Keras TensorBoard callback with train_on_batch
# Requires TensorFlow 1.x (tf.Summary) and a Keras version whose
# TensorBoard callback exposes a `writer` attribute.
import numpy as np
import tensorflow as tf
from keras.callbacks import TensorBoard
from keras.layers import Input, Dense
from keras.models import Model


def write_log(callback, names, logs, batch_no):
    # Write each (name, value) pair as a scalar summary at step batch_no.
    for name, value in zip(names, logs):
        summary = tf.Summary()
        summary_value = summary.value.add()
        summary_value.simple_value = value
        summary_value.tag = name
        callback.writer.add_summary(summary, batch_no)
        callback.writer.flush()


# A single dense layer mapping 3 input features to 1 output.
net_in = Input(shape=(3,))
net_out = Dense(1)(net_in)
model = Model(net_in, net_out)
model.compile(loss='mse', optimizer='sgd', metrics=['mae'])

# Attach the model so the callback creates its summary writer.
log_path = './logs'
callback = TensorBoard(log_path)
callback.set_model(model)
train_names = ['train_loss', 'train_mae']
val_names = ['val_loss', 'val_mae']

for batch_no in range(100):
    # Random data stands in for a real training batch.
    X_train, Y_train = np.random.rand(32, 3), np.random.rand(32, 1)
    logs = model.train_on_batch(X_train, Y_train)
    write_log(callback, train_names, logs, batch_no)

    if batch_no % 10 == 0:
        # Evaluate every 10th batch; test_on_batch does not update the weights.
        X_val, Y_val = np.random.rand(32, 3), np.random.rand(32, 1)
        logs = model.test_on_batch(X_val, Y_val)
        write_log(callback, val_names, logs, batch_no // 10)
Here's a simpler solution, which uses the TensorBoard callback directly:
https://gist.github.com/erenon/91f526302cd8e9d21b73f24c0f9c4bb8
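For readers who cannot open the link, here is a rough sketch of what "using the callback directly" could look like. It is not taken from the linked gist: the helper name `named_logs` and its `prefix` parameter are made up for illustration, and the commented usage assumes the `model` and `callback` objects from the script above, plus the standard `on_epoch_end(epoch, logs)` and `on_train_end(logs)` callback hooks.

```python
def named_logs(metric_names, values, prefix=''):
    """Pair metric names with the values returned by train_on_batch.

    `prefix` is a hypothetical convenience for tagging validation
    metrics, e.g. prefix='val_'.
    """
    return {prefix + name: float(value)
            for name, value in zip(metric_names, values)}


# Hypothetical usage with the model and callback from the script above:
#
#   logs = model.train_on_batch(X_train, Y_train)
#   callback.on_epoch_end(batch_no, named_logs(model.metrics_names, logs))
#   ...
#   callback.on_train_end(None)
```

The idea is to hand the callback a dict of named values and let it write the summaries itself, so no manual `tf.Summary()` handling is needed.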
thanks bro, it really helped
Hello, I am trying to upgrade my code to TensorFlow 2.0, which does not have tf.Summary().
Could anyone tell me how to write the write_log function so that I can still visualize the results in TensorBoard?
Thank you
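One possible way to port write_log to TensorFlow 2.x is sketched below. It assumes the TF 2 summary API (`tf.summary.create_file_writer` and `tf.summary.scalar`) and changes the first argument from the Keras callback to a summary writer. That signature change is a choice made for this sketch, not something from the original gist.

```python
import tensorflow as tf


def write_log(writer, names, logs, step):
    """Write one scalar per (name, value) pair at the given step.

    `writer` is assumed to come from tf.summary.create_file_writer.
    """
    with writer.as_default():
        for name, value in zip(names, logs):
            tf.summary.scalar(name, value, step=step)
    writer.flush()


# Usage, mirroring the training loop in the gist (model and data as before):
#
#   writer = tf.summary.create_file_writer('./logs')
#   logs = model.train_on_batch(X_train, Y_train)
#   write_log(writer, ['train_loss', 'train_mae'], logs, batch_no)
```

Creating separate writers for training and validation (for example under './logs/train' and './logs/val') should let TensorBoard show them as separate runs.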