Created
June 17, 2016 13:03
-
-
Save hiromu/cce292b0dd17331f475e5c0b72ecc6e6 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
From 0bf4196dfa0caba0cee28c7edc751c145c834bc1 Mon Sep 17 00:00:00 2001
From: Hiromu Yakura <[email protected]>
Date: Tue, 24 May 2016 11:33:33 +0900
Subject: [PATCH] add an option to specify the log directory
---
main.py | 5 +++--
model.py | 5 +++--
2 files changed, 6 insertions(+), 4 deletions(-)
diff --git a/main.py b/main.py
index ac7aaab..0dabc38 100644
--- a/main.py
+++ b/main.py
@@ -17,6 +17,7 @@ flags.DEFINE_integer("image_size", 108, "The size of image to use (will be cente
flags.DEFINE_string("dataset", "celebA", "The name of dataset [celebA, mnist, lsun]")
flags.DEFINE_string("checkpoint_dir", "checkpoint", "Directory name to save the checkpoints [checkpoint]")
flags.DEFINE_string("sample_dir", "samples", "Directory name to save the image samples [samples]")
+flags.DEFINE_string("log_dir", "logs", "Directory name to save the log files for tensorboard [logs]")
flags.DEFINE_boolean("is_train", False, "True for training, False for testing [False]")
flags.DEFINE_boolean("is_crop", False, "True for training, False for testing [False]")
flags.DEFINE_boolean("visualize", False, "True for visualizing, False for nothing [False]")
@@ -33,10 +34,10 @@ def main(_):
with tf.Session() as sess:
if FLAGS.dataset == 'mnist':
dcgan = DCGAN(sess, image_size=FLAGS.image_size, batch_size=FLAGS.batch_size, y_dim=10,
- dataset_name=FLAGS.dataset, is_crop=FLAGS.is_crop, checkpoint_dir=FLAGS.checkpoint_dir)
+ dataset_name=FLAGS.dataset, is_crop=FLAGS.is_crop, checkpoint_dir=FLAGS.checkpoint_dir, log_dir=FLAGS.log_dir)
else:
dcgan = DCGAN(sess, image_size=FLAGS.image_size, batch_size=FLAGS.batch_size,
- dataset_name=FLAGS.dataset, is_crop=FLAGS.is_crop, checkpoint_dir=FLAGS.checkpoint_dir)
+ dataset_name=FLAGS.dataset, is_crop=FLAGS.is_crop, checkpoint_dir=FLAGS.checkpoint_dir, log_dir=FLAGS.log_dir)
if FLAGS.is_train:
dcgan.train(FLAGS)
diff --git a/model.py b/model.py
index 4dd8e1c..06d32a5 100644
--- a/model.py
+++ b/model.py
@@ -11,7 +11,7 @@ class DCGAN(object):
batch_size=64, sample_size = 64, image_shape=[64, 64, 3],
y_dim=None, z_dim=100, gf_dim=64, df_dim=64,
gfc_dim=1024, dfc_dim=1024, c_dim=3, dataset_name='default',
- checkpoint_dir=None):
+ checkpoint_dir=None, log_dir=None):
"""
Args:
@@ -59,6 +59,7 @@ class DCGAN(object):
self.dataset_name = dataset_name
self.checkpoint_dir = checkpoint_dir
+ self.log_dir = log_dir
self.build_model()
def build_model(self):
@@ -118,7 +119,7 @@ class DCGAN(object):
self.g_sum = tf.merge_summary([self.z_sum, self.d__sum,
self.G_sum, self.d_loss_fake_sum, self.g_loss_sum])
self.d_sum = tf.merge_summary([self.z_sum, self.d_sum, self.d_loss_real_sum, self.d_loss_sum])
- self.writer = tf.train.SummaryWriter("./logs", self.sess.graph_def)
+ self.writer = tf.train.SummaryWriter(self.log_dir, self.sess.graph_def)
sample_z = np.random.uniform(-1, 1, size=(self.sample_size , self.z_dim))
sample_files = data[0:self.sample_size]
--
2.5.4 (Apple Git-61)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment