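# pix2pix: image-to-image translation with a conditional GAN (U-Net generator
# plus PatchGAN discriminator), trained on the CMP Facades dataset. The code
# follows the TensorFlow pix2pix tutorial.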
import datetime
import os
import time

import tensorflow as tf
from absl import app
from matplotlib import pyplot as plt
# Download and extract the facades dataset.
_URL = 'https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/facades.tar.gz'
path_to_zip = tf.keras.utils.get_file('facades.tar.gz',
                                      origin=_URL,
                                      extract=True)
PATH = os.path.join(os.path.dirname(path_to_zip), 'facades/')

BUFFER_SIZE = 400
BATCH_SIZE = 1
IMG_WIDTH = 256
IMG_HEIGHT = 256
OUTPUT_CHANNELS = 3
LAMBDA = 100  # weight of the L1 term in the generator loss
EPOCHS = 150

loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)
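# from_logits=True: the discriminator's final Conv2D layer emits raw logits
# (no sigmoid activation), so the loss applies the sigmoid internally, which
# is more numerically stable.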
def load(image_file):
    # Each facades file is one JPEG holding a side-by-side pair: the real
    # photo on the left half, the architectural label map on the right half.
    image = tf.io.read_file(image_file)
    image = tf.image.decode_jpeg(image)

    w = tf.shape(image)[1] // 2
    real_image = tf.cast(image[:, :w, :], tf.float32)
    input_image = tf.cast(image[:, w:, :], tf.float32)

    return input_image, real_image
def resize(input_image, real_image, height, width):
    input_image = tf.image.resize(input_image, [height, width],
                                  method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    real_image = tf.image.resize(real_image, [height, width],
                                 method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    return input_image, real_image

def random_crop(input_image, real_image):
    # Stack the pair so both images receive the same random crop window.
    stacked_image = tf.stack([input_image, real_image], axis=0)
    cropped_image = tf.image.random_crop(
        stacked_image, size=[2, IMG_HEIGHT, IMG_WIDTH, 3])
    return cropped_image[0], cropped_image[1]
@tf.function()
def random_jitter(input_image, real_image):
    # Resizing to 286 x 286 x 3
    input_image, real_image = resize(input_image, real_image, 286, 286)

    # Randomly cropping back to 256 x 256 x 3
    input_image, real_image = random_crop(input_image, real_image)

    if tf.random.uniform(()) > 0.5:
        # Random mirroring
        input_image = tf.image.flip_left_right(input_image)
        real_image = tf.image.flip_left_right(real_image)

    return input_image, real_image

# Normalizing the images to [-1, 1]
def normalize(input_image, real_image):
    input_image = (input_image / 127.5) - 1
    real_image = (real_image / 127.5) - 1
    return input_image, real_image
def load_image_train(image_file):
    input_image, real_image = load(image_file)
    input_image, real_image = random_jitter(input_image, real_image)
    input_image, real_image = normalize(input_image, real_image)
    return input_image, real_image

def load_image_test(image_file):
    input_image, real_image = load(image_file)
    input_image, real_image = resize(input_image, real_image,
                                     IMG_HEIGHT, IMG_WIDTH)
    input_image, real_image = normalize(input_image, real_image)
    return input_image, real_image
def downsample(filters, size, apply_batchnorm=True):
    # Conv2D with stride 2 halves the spatial resolution.
    initializer = tf.random_normal_initializer(0., 0.02)

    result = tf.keras.Sequential()
    result.add(
        tf.keras.layers.Conv2D(filters, size, strides=2, padding='same',
                               kernel_initializer=initializer, use_bias=False))
    if apply_batchnorm:
        result.add(tf.keras.layers.BatchNormalization())
    result.add(tf.keras.layers.LeakyReLU())

    return result

def upsample(filters, size, apply_dropout=False):
    # Conv2DTranspose with stride 2 doubles the spatial resolution.
    initializer = tf.random_normal_initializer(0., 0.02)

    result = tf.keras.Sequential()
    result.add(
        tf.keras.layers.Conv2DTranspose(filters, size, strides=2,
                                        padding='same',
                                        kernel_initializer=initializer,
                                        use_bias=False))
    result.add(tf.keras.layers.BatchNormalization())
    if apply_dropout:
        result.add(tf.keras.layers.Dropout(0.5))
    result.add(tf.keras.layers.ReLU())

    return result
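# The generator is a U-Net: an encoder that downsamples 256x256 all the way
# to 1x1, and a decoder that upsamples back, with skip connections
# concatenating each encoder activation onto the mirrored decoder layer so
# fine spatial detail can bypass the bottleneck.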
def Generator():
    inputs = tf.keras.layers.Input(shape=[256, 256, 3])

    down_stack = [
        downsample(64, 4, apply_batchnorm=False),  # (bs, 128, 128, 64)
        downsample(128, 4),  # (bs, 64, 64, 128)
        downsample(256, 4),  # (bs, 32, 32, 256)
        downsample(512, 4),  # (bs, 16, 16, 512)
        downsample(512, 4),  # (bs, 8, 8, 512)
        downsample(512, 4),  # (bs, 4, 4, 512)
        downsample(512, 4),  # (bs, 2, 2, 512)
        downsample(512, 4),  # (bs, 1, 1, 512)
    ]

    up_stack = [
        upsample(512, 4, apply_dropout=True),  # (bs, 2, 2, 1024)
        upsample(512, 4, apply_dropout=True),  # (bs, 4, 4, 1024)
        upsample(512, 4, apply_dropout=True),  # (bs, 8, 8, 1024)
        upsample(512, 4),  # (bs, 16, 16, 1024)
        upsample(256, 4),  # (bs, 32, 32, 512)
        upsample(128, 4),  # (bs, 64, 64, 256)
        upsample(64, 4),  # (bs, 128, 128, 128)
    ]

    initializer = tf.random_normal_initializer(0., 0.02)
    last = tf.keras.layers.Conv2DTranspose(OUTPUT_CHANNELS, 4,
                                           strides=2,
                                           padding='same',
                                           kernel_initializer=initializer,
                                           activation='tanh')  # (bs, 256, 256, 3)

    x = inputs

    # Downsampling through the model
    skips = []
    for down in down_stack:
        x = down(x)
        skips.append(x)

    skips = reversed(skips[:-1])

    # Upsampling and establishing the skip connections
    for up, skip in zip(up_stack, skips):
        x = up(x)
        x = tf.keras.layers.Concatenate()([x, skip])

    x = last(x)

    return tf.keras.Model(inputs=inputs, outputs=x)
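# The discriminator is a PatchGAN: it sees the (input, target) pair stacked
# on the channel axis and outputs a 30x30 grid of logits, each classifying an
# overlapping patch of the pair as real or fake, rather than one score for
# the whole image.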
def Discriminator():
    initializer = tf.random_normal_initializer(0., 0.02)

    inp = tf.keras.layers.Input(shape=[256, 256, 3], name='input_image')
    tar = tf.keras.layers.Input(shape=[256, 256, 3], name='target_image')

    x = tf.keras.layers.concatenate([inp, tar])  # (bs, 256, 256, channels*2)

    down1 = downsample(64, 4, False)(x)  # (bs, 128, 128, 64)
    down2 = downsample(128, 4)(down1)  # (bs, 64, 64, 128)
    down3 = downsample(256, 4)(down2)  # (bs, 32, 32, 256)

    zero_pad1 = tf.keras.layers.ZeroPadding2D()(down3)  # (bs, 34, 34, 256)
    conv = tf.keras.layers.Conv2D(512, 4, strides=1,
                                  kernel_initializer=initializer,
                                  use_bias=False)(zero_pad1)  # (bs, 31, 31, 512)

    batchnorm1 = tf.keras.layers.BatchNormalization()(conv)
    leaky_relu = tf.keras.layers.LeakyReLU()(batchnorm1)

    zero_pad2 = tf.keras.layers.ZeroPadding2D()(leaky_relu)  # (bs, 33, 33, 512)
    last = tf.keras.layers.Conv2D(1, 4, strides=1,
                                  kernel_initializer=initializer)(zero_pad2)  # (bs, 30, 30, 1)

    return tf.keras.Model(inputs=[inp, tar], outputs=last)
def generator_loss(disc_generated_output, gen_output, target):
    # Adversarial term: the generator wants the discriminator to call its
    # output real (all ones).
    gan_loss = loss_object(tf.ones_like(disc_generated_output), disc_generated_output)

    # Mean absolute error between the generated and the target image.
    l1_loss = tf.reduce_mean(tf.abs(target - gen_output))

    total_gen_loss = gan_loss + (LAMBDA * l1_loss)

    return total_gen_loss, gan_loss, l1_loss

def discriminator_loss(disc_real_output, disc_generated_output):
    real_loss = loss_object(tf.ones_like(disc_real_output), disc_real_output)
    generated_loss = loss_object(tf.zeros_like(disc_generated_output), disc_generated_output)

    total_disc_loss = real_loss + generated_loss

    return total_disc_loss
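# Together these form the pix2pix objective: the generator minimizes
#     BCE(1, D(x, G(x))) + LAMBDA * mean|y - G(x)|
# while the discriminator minimizes
#     BCE(1, D(x, y)) + BCE(0, D(x, G(x))).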
def generate_images(model, test_input, tar):
    # training=True is intentional here: we want BatchNorm to use the
    # statistics of the test batch rather than accumulated moving averages.
    prediction = model(test_input, training=True)
    plt.figure(figsize=(15, 15))

    display_list = [test_input[0], tar[0], prediction[0]]
    title = ['Input Image', 'Ground Truth', 'Predicted Image']

    for i in range(3):
        plt.subplot(1, 3, i + 1)
        plt.title(title[i])
        # Map the pixel values from [-1, 1] back to [0, 1] for plotting.
        plt.imshow(display_list[i] * 0.5 + 0.5)
        plt.axis('off')
    plt.show()
@tf.function
def train_step(input_image, target, epoch, generator, discriminator,
               generator_optimizer, discriminator_optimizer, summary_writer):
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        gen_output = generator(input_image, training=True)

        disc_real_output = discriminator([input_image, target], training=True)
        disc_generated_output = discriminator([input_image, gen_output], training=True)

        gen_total_loss, gen_gan_loss, gen_l1_loss = generator_loss(
            disc_generated_output, gen_output, target)
        disc_loss = discriminator_loss(disc_real_output, disc_generated_output)

    generator_gradients = gen_tape.gradient(gen_total_loss,
                                            generator.trainable_variables)
    discriminator_gradients = disc_tape.gradient(disc_loss,
                                                 discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(generator_gradients,
                                            generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(discriminator_gradients,
                                                discriminator.trainable_variables))

    with summary_writer.as_default():
        tf.summary.scalar('gen_total_loss', gen_total_loss, step=epoch)
        tf.summary.scalar('gen_gan_loss', gen_gan_loss, step=epoch)
        tf.summary.scalar('gen_l1_loss', gen_l1_loss, step=epoch)
        tf.summary.scalar('disc_loss', disc_loss, step=epoch)
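# Each train_step call updates both networks once: the two forward passes are
# recorded on separate gradient tapes, then each optimizer applies the
# gradients of its own loss to its own network's variables.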
def main(_):
    # Check the data structure on one sample pair.
    inp, re = load(PATH + 'train/100.jpg')
    # Scale to [0, 1] so matplotlib can show the float image.
    plt.figure()
    plt.imshow(inp / 255.0)
    plt.show()
    plt.figure()
    plt.imshow(re / 255.0)
    plt.show()

    # Show a few random-jittered versions of the same input.
    plt.figure(figsize=(6, 6))
    for i in range(4):
        rj_inp, rj_re = random_jitter(inp, re)
        plt.subplot(2, 2, i + 1)
        plt.imshow(rj_inp / 255.0)
        plt.axis('off')
    plt.show()
    # Build the training dataset.
    train_dataset = tf.data.Dataset.list_files(PATH + 'train/*.jpg')
    train_dataset = train_dataset.map(load_image_train,
                                      num_parallel_calls=tf.data.experimental.AUTOTUNE)
    train_dataset = train_dataset.shuffle(BUFFER_SIZE)
    train_dataset = train_dataset.batch(BATCH_SIZE)

    # Build the test dataset.
    test_dataset = tf.data.Dataset.list_files(PATH + 'test/*.jpg')
    test_dataset = test_dataset.map(load_image_test)
    test_dataset = test_dataset.batch(BATCH_SIZE)
    # Sanity-check downsample and upsample on one image.
    down_model = downsample(3, 4)
    down_result = down_model(tf.expand_dims(inp, 0))
    print(down_result.shape)
    assert down_result.shape == (1, 128, 128, 3)

    up_model = upsample(3, 4)
    up_result = up_model(down_result)
    print(up_result.shape)
    assert up_result.shape == (1, 256, 256, 3)

    generator = Generator()
    tf.keras.utils.plot_model(generator, to_file='generator.png', show_shapes=True, dpi=64)

    # Show the generator output before any training; the tanh output is in
    # [-1, 1], so shift it to [0, 1] for display.
    gen_output = generator(inp[tf.newaxis, ...], training=False)
    plt.imshow(gen_output[0, ...] * 0.5 + 0.5)
    plt.show()

    discriminator = Discriminator()
    tf.keras.utils.plot_model(discriminator, to_file='discriminator.png', show_shapes=True, dpi=64)

    disc_out = discriminator([inp[tf.newaxis, ...], gen_output], training=False)
    plt.imshow(disc_out[0, ..., -1], vmin=-20, vmax=20, cmap='RdBu_r')
    plt.colorbar()
    plt.show()
    # Set up the optimizers.
    generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
    discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

    # Set up checkpointing.
    checkpoint_dir = './training_checkpoints'
    checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
    checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                     discriminator_optimizer=discriminator_optimizer,
                                     generator=generator,
                                     discriminator=discriminator)

    # Test the generate_images function.
    for example_input, example_target in test_dataset.take(1):
        generate_images(generator, example_input, example_target)

    # Set up the TensorBoard summary writer.
    log_dir = "logs/"
    summary_writer = tf.summary.create_file_writer(
        log_dir + "fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
    # Train
    for epoch in range(EPOCHS):
        start = time.time()
        print("Epoch: ", epoch)

        for n, (input_image, target) in train_dataset.enumerate():
            print('.', end='')
            if (n + 1) % 100 == 0:
                print()
            train_step(input_image, target, epoch, generator, discriminator,
                       generator_optimizer, discriminator_optimizer, summary_writer)
        print()

        # Save a checkpoint every 20 epochs.
        if (epoch + 1) % 20 == 0:
            checkpoint.save(file_prefix=checkpoint_prefix)

        print('Time taken for epoch {} is {} sec\n'.format(epoch + 1,
                                                           time.time() - start))
    checkpoint.save(file_prefix=checkpoint_prefix)

    # Restore the latest checkpoint in checkpoint_dir.
    checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

    # Run the trained model on a few examples from the test dataset.
    for inp, tar in test_dataset.take(5):
        generate_images(generator, inp, tar)
if __name__ == '__main__':
    app.run(main)