tf.dataset tutorial
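This script compares how the order of map, shuffle, repeat, batch, and prefetch in a tf.data input pipeline (TensorFlow 1.x graph API) changes which samples appear in which training step. Each case* method builds the same toy pipeline (100 integer samples, batch size 10, 2 epochs) with a different transform order, and run() records a per-step sample footprint and saves it as heatmap figures for side-by-side comparison.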
import argparse

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import tensorflow as tf


class DatasetTutorial:
    def __init__(self):
        self.num_samples = 100
        self.batch_size = 10
        self.repeat = 2

    def _parse_function(self, image, label):
        # Identity map; a stand-in for real per-sample preprocessing
        # (e.g. image decoding and augmentation).
        image *= 1
        label *= 1
        return image, label

    def ours(self):
        self.dataset = self.dataset.map(self._parse_function)
        self.dataset = self.dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)
        self.dataset = self.dataset.repeat(self.repeat)
        self.dataset = self.dataset.shuffle(buffer_size=self.num_samples)
        self.dataset = self.dataset.batch(self.batch_size)
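
    # Note on ours: prefetch sits immediately after map, so it buffers
    # individual (unbatched) elements rather than ready batches, and repeat
    # runs before shuffle, so the shuffle buffer mixes samples across epoch
    # boundaries: a sample can reappear before every sample has been seen once.
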
    def case1(self):
        self.dataset = self.dataset.map(self._parse_function)
        self.dataset = self.dataset.repeat(self.repeat)
        self.dataset = self.dataset.shuffle(buffer_size=self.num_samples)
        self.dataset = self.dataset.batch(self.batch_size)
        self.dataset = self.dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)
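
    # Note on case1: prefetch moves to the end of the pipeline, the usual
    # recommendation, so the next batch is prepared while the current one is
    # consumed. Like ours, it still repeats before shuffling, so consecutive
    # epochs stay blended in the shuffle buffer.
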
    def case2(self):
        self.dataset = self.dataset.map(self._parse_function)
        self.dataset = self.dataset.shuffle(buffer_size=self.num_samples)
        self.dataset = self.dataset.repeat(self.repeat)
        self.dataset = self.dataset.batch(self.batch_size)
        self.dataset = self.dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)
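
    # Note on case2: shuffling before repeating preserves epoch boundaries.
    # With buffer_size == num_samples the shuffle is uniform, and every sample
    # is seen exactly once per epoch before any sample repeats.
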
    def case3(self):
        self.dataset = self.dataset.map(self._parse_function)
        self.dataset = self.dataset.shuffle(buffer_size=self.num_samples // 2)
        self.dataset = self.dataset.repeat(self.repeat)
        self.dataset = self.dataset.batch(self.batch_size)
        self.dataset = self.dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)

    def case4(self):
        self.dataset = self.dataset.map(self._parse_function)
        self.dataset = self.dataset.shuffle(buffer_size=self.num_samples // 4)
        self.dataset = self.dataset.repeat(self.repeat)
        self.dataset = self.dataset.batch(self.batch_size)
        self.dataset = self.dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)

    def case5(self):
        self.dataset = self.dataset.map(self._parse_function)
        self.dataset = self.dataset.shuffle(buffer_size=self.num_samples // 5)
        self.dataset = self.dataset.repeat(self.repeat)
        self.dataset = self.dataset.batch(self.batch_size)
        self.dataset = self.dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)
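
    # Note on cases 3-5: identical to case2 except for the shuffle buffer
    # size (1/2, 1/4, and 1/5 of the dataset). A buffer smaller than the
    # dataset only randomizes locally, so early samples tend to be emitted
    # early; this shows up as a diagonal band in the heatmaps. A uniform
    # shuffle needs buffer_size >= num_samples.
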
    def case6(self):
        self.dataset = self.dataset.shuffle(buffer_size=self.num_samples)
        self.dataset = self.dataset.map(self._parse_function)
        self.dataset = self.dataset.batch(self.batch_size)
        self.dataset = self.dataset.shuffle(buffer_size=self.num_samples // 2)
        self.dataset = self.dataset.repeat(self.repeat)
        self.dataset = self.dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)
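
    # Note on case6: two shuffles. The element-level shuffle before batching
    # randomizes batch membership, and the second shuffle after batching
    # randomizes batch order, so both are mixed.
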
    def case7(self):
        self.dataset = self.dataset.map(self._parse_function)
        self.dataset = self.dataset.batch(self.batch_size)
        self.dataset = self.dataset.shuffle(buffer_size=self.num_samples // 2)
        self.dataset = self.dataset.repeat(self.repeat)
        self.dataset = self.dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)
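
    # Note on case7: batching first and shuffling afterwards shuffles whole
    # batches as units. Batch order is random, but every batch always contains
    # the same consecutive run of elements (0-9, 10-19, ...).
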
    def case8(self):
        self.dataset = self.dataset.shuffle(buffer_size=self.num_samples)
        self.dataset = self.dataset.map(self._parse_function)
        self.dataset = self.dataset.batch(self.batch_size)
        self.dataset = self.dataset.repeat(self.repeat)
        self.dataset = self.dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)
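
    # Note on case8: the order often recommended for tf.data pipelines:
    # shuffle elements, map, batch, repeat, prefetch last. Epoch boundaries
    # are preserved, and with reshuffle_each_iteration (the default) each
    # epoch gets a fresh order.
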
    def run(self, func):
        TF_SESSION_CONFIG = tf.ConfigProto(
            gpu_options=tf.GPUOptions(allow_growth=True),
            log_device_placement=False,
            device_count={"GPU": 1})
        grid_size = 2
        fig1, axes1 = plt.subplots(nrows=grid_size, ncols=grid_size, figsize=(8, 8))
        fig2, axes2 = plt.subplots(nrows=grid_size, ncols=grid_size, figsize=(8, 8))
        # Run the same pipeline grid_size**2 times; panel-to-panel differences
        # come from shuffle randomness alone.
        for n in range(grid_size ** 2):
            nrow = n % grid_size
            ncol = n // grid_size
            print(nrow, ncol)
            with tf.Session(config=TF_SESSION_CONFIG) as sess:
                images = tf.constant(np.arange(0, self.num_samples, 1))
                labels = tf.constant(np.arange(0, self.num_samples, 1))
                self.dataset = tf.data.Dataset.from_tensor_slices((images, labels))
                # Look up the requested pipeline builder by name
                # (getattr is safer than eval here).
                getattr(self, func)()
                iterator = self.dataset.make_one_shot_iterator()
                next_element = iterator.get_next()
                assert self.num_samples % self.batch_size == 0
                niter = self.num_samples // self.batch_size * self.repeat
                # footprints[s, i] == 1 iff sample s appeared in step i's batch.
                image_footprints = np.zeros(shape=(self.num_samples, niter))
                label_footprints = np.zeros(shape=(self.num_samples, niter))
                for i in range(niter):
                    image_batch, label_batch = sess.run(next_element)
                    image_footprints[image_batch, i] = 1
                    label_footprints[label_batch, i] = 1
                # Images and labels must stay paired through every transform.
                np.testing.assert_equal(image_footprints, label_footprints)
                sns.heatmap(image_footprints, ax=axes1[nrow, ncol])
                sns.heatmap(image_footprints.cumsum(axis=1), ax=axes2[nrow, ncol])
        fig1.tight_layout()
        fig1.savefig(f"{func}_1.png")
        fig2.tight_layout()
        fig2.savefig(f"{func}_2.png")
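
# Reading the figures: each of the four panels in <method>_1.png is an
# independent run of the same pipeline (rows: sample index, columns: step;
# a cell is 1 when that sample appeared in that step's batch), so variation
# across panels reflects shuffle randomness. <method>_2.png plots the
# cumulative footprint, showing how quickly the full dataset is covered.
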
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument('--method', default="ours", type=str)
args = parser.parse_args()

dataset = DatasetTutorial()
dataset.run(args.method)
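
# Example usage (assuming this script is saved as dataset_tutorial.py):
#   python dataset_tutorial.py --method case2
# writes case2_1.png and case2_2.png to the working directory.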