Created
February 18, 2019 06:50
-
-
Save lidopypy/b53b2f9757357057d389c92945612e88 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.contrib.tensorboard.plugins import projector

# Data preprocessing.
# Path to the Fashion-MNIST test CSV. It is a raw string, so single
# backslashes are literal Windows separators.
test_csv_path = r'C:\Users\lido_lee\Downloads\fasion_mnist\fashion-mnist_test.csv'
# Load the whole CSV as a float32 array. Column 0 is used as the class label
# and the remaining columns as pixel values (see the slicing below).
test_data = np.array(pd.read_csv(test_csv_path), dtype='float32')
# Number of test images to embed. 1600 = 40 x 40, so the sprite grid is square.
embed_count = 1600
# Scale pixel values into [0, 1]; assumes the raw values are 0..255.
x_test = test_data[:embed_count, 1:] / 255
y_test = test_data[:embed_count, 0]
# Folder that receives the TensorBoard logs: the checkpoint, projector
# config, metadata TSV and sprite image written by this script all go here.
logdir = r'C:\Users\lido_lee\Downloads\fmnist_callbacks'
# Create it explicitly; the metadata and sprite files are written into it later.
os.makedirs(logdir, exist_ok=True)

# Set up the summary writer and the embedding tensor (one row per test image).
summary_writer = tf.summary.FileWriter(logdir)
embedding_var = tf.Variable(x_test, name='fmnist_embedding')
config = projector.ProjectorConfig()
embedding = config.embeddings.add()
embedding.tensor_name = embedding_var.name
embedding.metadata_path = os.path.join(logdir, 'metadata.tsv')
embedding.sprite.image_path = os.path.join(logdir, 'sprite.png')
# Each thumbnail in the sprite image is 28 x 28 pixels.
embedding.sprite.single_image_dim.extend([28, 28])
projector.visualize_embeddings(summary_writer, config)
# The writer is not used after this point. Close it so its event file is
# flushed and the file handle is released.
summary_writer.close()

# Run a session to initialise the variable and save the model checkpoint
# that holds the embedding values.
with tf.Session() as sesh:
    sesh.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    saver.save(sesh, os.path.join(logdir, 'model.ckpt'))
# Create the sprite image and the metadata file referenced by the projector
# config: one thumbnail and one label row per embedded test image.
img_rows, img_cols = 28, 28  # height and width of a single image
# Class names, indexed by the integer label stored in y_test.
class_names = ['t_shirt', 'trouser', 'pullover', 'dress', 'coat',
               'sandal', 'shirt', 'sneaker', 'bag', 'ankle_boot']
num_images = x_test.shape[0]
# Side length of the square thumbnail grid. ceil(sqrt(N)) leaves room for
# every embedded point even when N is not a perfect square; the last grid
# row is then only partly filled. For N = 1600 it is exactly 40.
sprite_dim = int(np.ceil(np.sqrt(num_images)))
# Start from an all-white canvas (1.0) so unused grid cells stay blank.
sprite_image = np.ones((img_rows * sprite_dim, img_cols * sprite_dim))
for index in range(num_images):
    # Thumbnails are placed row-first: left to right, then top to bottom.
    grid_row, grid_col = divmod(index, sprite_dim)
    # Invert intensities (1 - x) so items appear dark on a white background.
    sprite_image[
        grid_row * img_rows:(grid_row + 1) * img_rows,
        grid_col * img_cols:(grid_col + 1) * img_cols,
    ] = 1 - x_test[index].reshape(img_rows, img_cols)

# One metadata row per embedded point, in embedding-row order, so the row
# count always matches the embedding tensor.
labels = [class_names[int(y)] for y in y_test]
with open(embedding.metadata_path, 'w') as meta:
    meta.write('Index\tLabel\n')
    for index, label_name in enumerate(labels):
        meta.write('{}\t{}\n'.format(index, label_name))

plt.imsave(embedding.sprite.image_path, sprite_image, cmap='gray')
plt.imshow(sprite_image, cmap='gray')
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment