autoencoder
from __future__ import division, print_function, absolute_import

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

# Import MNIST data
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data", one_hot=False)

EPOCHS = 150
BATCH_SIZE = 100
EPOCH_SIZE = 50000
GAMMA = 0.00008

tf.app.flags.DEFINE_boolean("train", False, "train if true, eval if false")
tf.app.flags.DEFINE_boolean("plot_point", False, "plot points")
FLAGS = tf.app.flags.FLAGS
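# Usage note (not part of the original gist): with the two boolean flags above,
# a typical invocation might look roughly like the following, assuming this
# file is saved as autoencoder.py (exact flag parsing depends on the TF 1.x
# version):
#   python autoencoder.py --train=True        # train and save model.ckpt
#   python autoencoder.py --plot_point=True   # scatter-plot the 2-D codes
#   python autoencoder.py                     # reconstruct 10 test digits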
def _truncated_normal_initializer(stddev=0.1):
    def _initializer(shape, dtype=tf.float32, partition_info=None):
        return tf.truncated_normal(shape, stddev=stddev, dtype=dtype)
    return _initializer

def _variable_on_cpu(name, shape, initializer=None, reuse=None):
    # Keep variables on the CPU so they can be shared across devices.
    with tf.device("/cpu:0"):
        with tf.variable_scope(name, reuse=reuse):
            return tf.get_variable(name, shape=shape,
                                   initializer=initializer, dtype=tf.float32)
def _affine(x, weights_shape=None, bias_shape=None, decay=None, name=None):
    W = _variable_on_cpu("weights",
                         shape=weights_shape,
                         initializer=_truncated_normal_initializer())
    b = _variable_on_cpu("biases",
                         shape=bias_shape,
                         initializer=tf.constant_initializer(0.0))
    # add L2 regularization
    if decay is not None:
        wd = tf.multiply(decay, tf.nn.l2_loss(W), name="l2_weight_loss")
        tf.add_to_collection("losses", wd)
    return tf.add(tf.matmul(x, W), b, name=name)
def _conv_2d(x, kernel_shape, bias_shape, stride=1, padding="SAME",
             decay=None, name=None, reuse=None):
    W = _variable_on_cpu(name="weights",
                         shape=kernel_shape,
                         initializer=_truncated_normal_initializer(),
                         reuse=reuse)
    b = _variable_on_cpu(name="biases",
                         shape=bias_shape,
                         initializer=tf.constant_initializer(0.0),
                         reuse=reuse)
    # add L2 regularization
    if decay is not None:
        wd = tf.multiply(decay, tf.nn.l2_loss(W), name="l2_weight_loss")
        tf.add_to_collection("losses", wd)
    # stride of 1 in batch + channel dimensions for standard convolutions
    conv = tf.nn.conv2d(x,
                        filter=W,
                        strides=[1, stride, stride, 1],
                        padding=padding,
                        data_format="NHWC")
    return tf.add(conv, b, name=name)
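# Note: with padding="SAME" and stride 1, the convolution above preserves the
# input's spatial dimensions; all downsampling in the encoder comes from the
# 2x2 max-pooling layers with stride 2 (28 -> 14 -> 7).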
def encode(x):
    # Architecture: conv(5x5, 32) -> pool -> conv(5x5, 64) -> pool
    #               -> fc 1024 -> fc 512 -> fc 2 (the latent code)
    with tf.variable_scope("encoder"):
        with tf.variable_scope("conv1") as scope:
            c1z = _conv_2d(tf.reshape(x, [-1, 28, 28, 1]), kernel_shape=[5, 5, 1, 32],
                           bias_shape=[32], decay=0.01)
            c1y = tf.nn.relu(c1z, name=scope.name)
            p1 = tf.nn.max_pool(c1y, ksize=[1, 2, 2, 1],
                                strides=[1, 2, 2, 1], padding="SAME")
        with tf.variable_scope("conv2") as scope:
            c2z = _conv_2d(p1, kernel_shape=[5, 5, 32, 64],
                           bias_shape=[64], decay=0.01)
            c2y = tf.nn.relu(c2z, name=scope.name)
            p2 = tf.nn.max_pool(c2y, ksize=[1, 2, 2, 1],
                                strides=[1, 2, 2, 1], padding="SAME")
        with tf.variable_scope("affine1") as scope:
            resh = tf.reshape(p2, [-1, 7 * 7 * 64])
            a1z = _affine(resh, weights_shape=[7 * 7 * 64, 1024],
                          bias_shape=[1024], decay=0.01)
            a1y = tf.nn.relu(a1z, name=scope.name)
        with tf.variable_scope("affine2") as scope:
            a2z = _affine(a1y, weights_shape=[1024, 512],
                          bias_shape=[512], decay=0.01)
            a2y = tf.nn.relu(a2z, name=scope.name)
        with tf.variable_scope("affine3") as scope:
            a3z = _affine(a2y, weights_shape=[512, 2],
                          bias_shape=[2], decay=0.01)
            a3y = tf.nn.relu(a3z, name=scope.name)
        return a3y
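# Shapes through the encoder (batch dimension omitted): 28x28x1 input
# -> 28x28x32 after conv1 -> 14x14x32 after pooling -> 14x14x64 after conv2
# -> 7x7x64 after pooling -> 3136 flattened -> 1024 -> 512 -> 2-D code.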
def decode(z):
    with tf.variable_scope("decoder"):
        with tf.variable_scope("affine1") as scope:
            a1z = _affine(z, weights_shape=[2, 512],
                          bias_shape=[512], decay=0.01)
            a1y = tf.nn.relu(a1z, name=scope.name)
        with tf.variable_scope("affine2") as scope:
            a2z = _affine(a1y, weights_shape=[512, 1024],
                          bias_shape=[1024], decay=0.01)
            a2y = tf.nn.relu(a2z, name=scope.name)
        with tf.variable_scope("affine3") as scope:
            a3z = _affine(a2y, weights_shape=[1024, 7 * 7 * 64],
                          bias_shape=[7 * 7 * 64], decay=0.01)
            a3y = tf.nn.relu(a3z, name=scope.name)
        with tf.variable_scope("deconv1") as scope:
            resh = tf.reshape(a3y, [-1, 7, 7, 64])
            up1 = tf.image.resize_images(resh, [14, 14],
                                         method=tf.image.ResizeMethod.BILINEAR)
            c1z = _conv_2d(up1, kernel_shape=[5, 5, 64, 32],
                           bias_shape=[32], decay=0.01)
            c1y = tf.nn.relu(c1z, name=scope.name)
        with tf.variable_scope("deconv2") as scope:
            up2 = tf.image.resize_images(c1y, [28, 28],
                                         method=tf.image.ResizeMethod.BILINEAR)
            c2z = _conv_2d(up2, kernel_shape=[5, 5, 32, 1],
                           bias_shape=[1], decay=0.01)
            c2y = tf.reshape(tf.nn.relu(c2z, name=scope.name), [-1, 784])
        return c2y
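# Shapes through the decoder (batch dimension omitted): 2-D code -> 512 -> 1024
# -> 3136 -> reshaped to 7x7x64 -> bilinearly resized to 14x14 and convolved
# to 32 channels -> resized to 28x28 and convolved to 1 channel -> flattened
# back to 784, matching the input layout. Upsampling here is resize-then-conv
# rather than transposed convolution.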
def run():
    with tf.Graph().as_default() as g:
        # Target Y is the input X itself (reconstruction objective).
        X = Y = tf.placeholder("float", shape=[None, 784])
        with tf.device("/gpu:0"):
            encode_op = encode(X)
            pred = decode(encode_op)
        loss = tf.reduce_mean(tf.pow(Y - pred, 2))  # MSE
        optim = tf.train.AdamOptimizer(GAMMA)
        train_op = optim.minimize(loss)
        init = tf.global_variables_initializer()
        saver = tf.train.Saver()
        sess = tf.Session()
        if FLAGS.train:
            sess.run(init)
            for e in range(EPOCHS):
                batches = int(EPOCH_SIZE / BATCH_SIZE)
                for i in range(batches):
                    xs, _ = mnist.train.next_batch(BATCH_SIZE)
                    _, l = sess.run([train_op, loss], feed_dict={X: xs})
                print("Epoch: {}/{}, Loss: {}".format(e, EPOCHS, l))
            save_path = saver.save(sess, "model.ckpt")
            print("Done training")
            sess.close()
        else:
            saver.restore(sess, "model.ckpt")
            idxs = np.random.choice(len(mnist.test.images), 10)
            if FLAGS.plot_point:
                # Scatter-plot the 2-D codes of the test set, colored by digit.
                points = sess.run(encode_op, feed_dict={X: mnist.test.images[:10000]})
                labels = mnist.test.labels[:10000]
                sess.close()
                cs = dict()
                for i in range(10):
                    cs[i] = ([], [])
                for point, label in zip(points, labels):
                    x = point[0]
                    y = point[1]
                    cs[label][0].append(x)
                    cs[label][1].append(y)
                data = [cs[i] for i in range(10)]
                colors = ['C0', 'C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7', 'C8', 'C9']
                groups = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]
                fig = plt.figure()
                ax = fig.add_subplot(1, 1, 1, facecolor="1.0")
                for xy, color, group in zip(data, colors, groups):
                    x, y = xy
                    ax.scatter(x, y, alpha=0.8, c=color, edgecolors='none', s=30,
                               label=group)
                # Create plot
                plt.title('2-D latent codes of MNIST test digits')
                plt.legend(loc=2)
                plt.show()
            else:
                # Reconstruct 10 random test digits and save a before/after grid.
                images_to_show = np.take(mnist.test.images, idxs, axis=0)
                encode_decode = sess.run(pred, feed_dict={X: images_to_show})
                sess.close()
                f, a = plt.subplots(2, 10, figsize=(10, 2))
                for i in range(10):
                    a[0][i].imshow(np.reshape(images_to_show[i], (28, 28)))
                    a[0][i].axis('off')
                    a[1][i].axis('off')
                    a[1][i].imshow(np.reshape(encode_decode[i], (28, 28)))
                f.savefig("plot.png")
                # plt.draw()
                # plt.waitforbuttonpress()

run()
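# Note (a sketch, not part of the original gist): the weight-decay terms that
# _affine and _conv_2d add to the "losses" collection are never combined with
# the MSE objective above, so the decay arguments currently have no effect on
# training. If regularization is wanted, one option is to fold them in when
# building the loss, e.g.:
#
#   reg = tf.add_n(tf.get_collection("losses"))
#   loss = tf.reduce_mean(tf.pow(Y - pred, 2)) + reg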