Created July 8, 2018 14:20
Save phizaz/ab3cbcfe5788dd22fe4efba89b241424 to your computer and use it in GitHub Desktop.
TensorFlow summary v2 graph
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf
import numpy as np

# Directory that receives the summary event files.
save_path = 'summary_path'

graph = tf.Graph()
with graph.as_default():
    # Training-step counter for this graph; the optimizer later increments it
    # via minimize(global_step=...).
    global_step = tf.train.create_global_step()
    # Summary-v2 file writer. Summary ops must be created inside
    # `writer.as_default()` to be attached to this writer.
    writer = tf.contrib.summary.create_file_writer(save_path)
    # NOTE: no `always_record_summaries()` call here. That function returns a
    # context manager, so calling it without `with` has no effect. Recording
    # is controlled by the recording context wrapped around the summary ops.
# Simulate a dataset: 1000 samples of 100 float32 features, each with an
# integer class label.
fake_dataset = np.random.randn(1000, 100).astype(np.float32)
# `high` is exclusive in np.random.randint, so high=10 yields labels 0..9.
# This covers every output of the 10-unit classifier; high=9 never
# produced label 9.
fake_label = np.random.randint(low=0, high=10, size=len(fake_dataset))

# Build a shuffled, batched input pipeline from the in-memory arrays.
with graph.as_default():
    data = tf.data.Dataset.from_tensor_slices((fake_dataset, fake_label))
    # A buffer at least as large as the dataset gives a full shuffle.
    data = data.shuffle(buffer_size=len(fake_dataset))
    data = data.batch(32)
    # Initializable iterator: it must be started explicitly by running
    # `data_itr.initializer` in a session.
    data_itr = data.make_initializable_iterator()
    # x: feature batch, y: label batch (the final batch may hold fewer than 32).
    x, y = data_itr.get_next()
# Define the classifier, its training step and the loss summary.
with graph.as_default(), writer.as_default():
    # Two-layer MLP: 300 ReLU hidden units followed by 10 linear outputs (logits).
    hidden_layer = tf.keras.layers.Dense(300, activation=tf.nn.relu)
    logits_layer = tf.keras.layers.Dense(10)
    net = tf.keras.Sequential([hidden_layer, logits_layer])
    prediction_op = net(x)  # unnormalized logits, consumed by the softmax loss

    loss_op = tf.losses.sparse_softmax_cross_entropy(y, prediction_op)

    optimizer = tf.train.AdamOptimizer(0.001)
    # Passing the global step makes each training step increment it.
    opt_op = optimizer.minimize(loss_op,
                                global_step=tf.train.get_global_step())

    # Log the loss on every global step (n=1). Because this runs inside
    # writer.as_default(), the summary is written through `writer`.
    with tf.contrib.summary.record_summaries_every_n_global_steps(1):
        tf.contrib.summary.scalar('loss', loss_op)
    summary_op = tf.contrib.summary.all_summary_ops()
# Run one pass over the dataset, writing the loss summary at every step.
with graph.as_default(), writer.as_default():
    # Build the flush op before the session starts so that no ops are added
    # to the graph mid-run. With no argument, it flushes the default writer
    # (`writer`, set by the enclosing as_default()).
    flush_op = tf.contrib.summary.flush()
    with tf.Session() as sess:
        # Initialize the summary writer and write the graph to the event file.
        tf.contrib.summary.initialize(graph=tf.get_default_graph(),
                                      session=sess)
        # Initialize all variables (model, optimizer slots, global step).
        sess.run(tf.global_variables_initializer())
        # Start the input pipeline at the beginning of the dataset.
        sess.run(data_itr.initializer)
        # Train until the iterator is exhausted (end of the single epoch).
        while True:
            try:
                sess.run([summary_op, opt_op])
            except tf.errors.OutOfRangeError:
                break
        # Push any buffered summaries to disk before the session closes.
        sess.run(flush_op)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.