@chychen
Last active May 24, 2021 13:07
Load Omniglot from TFDS, then build tf.data.Dataset train/valid episode pipelines on top of it.
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
import os
import matplotlib.pyplot as plt
from tensorflow.keras.utils import to_categorical
from absl import flags, logging
from tqdm.auto import tqdm
FLAGS = flags.FLAGS
LOGGER = logging.get_absl_logger()

class Omniglot:
    SAVE_PATH = '/tmp/omniglot/'

    def __init__(self,
                 n_train_episode=5000,
                 n_valid_episode=1000,
                 img_size=(28, 28),
                 force_update=False,
                 n_train_class=1200):
        self.n_train_episode = n_train_episode
        self.n_valid_episode = n_valid_episode
        self.c_way = FLAGS.c_way
        self.k_shot = FLAGS.k_shot
        self.img_shape = img_size + (1, )
        self.n_train_class = n_train_class
        self.np_dset = self._get_np_dset(img_size, force_update)
        # get mean, std
        self.mean = np.mean(self.np_dset[:self.n_train_class])
        self.std = np.std(self.np_dset[:self.n_train_class])
        LOGGER.info('mean: {}, stddev: {}'.format(self.mean, self.std))
        # normalize
        self.np_dset = (self.np_dset - self.mean) / self.std
        self.data = self._get_tf_dsets(self.np_dset)

    def get_data(self):
        return self.data

    def _get_np_dset(self, img_size, force_update):
        if not os.path.exists(Omniglot.SAVE_PATH):
            os.makedirs(Omniglot.SAVE_PATH)
        dset_path = os.path.join(Omniglot.SAVE_PATH,
                                 'dset_{}.npy'.format(img_size))
        if not os.path.exists(dset_path) or force_update:
            [dset_all], info = tfds.load("omniglot:3.0.0",
                                         with_info=True,
                                         split=['train+test'])

            def resize(x):
                x['image'] = tf.image.resize(x['image'], size=img_size)
                x['image'] = tf.image.rgb_to_grayscale(x['image'])
                return x

            dset_all = dset_all.map(resize)
            # one row per class, 20 samples each, resized to img_size
            np_dset = np.zeros(
                shape=[info.features['label'].num_classes, 20, *img_size, 1],
                dtype=np.float32)
            counter = np.zeros(shape=[info.features['label'].num_classes],
                               dtype=np.int32)
            for v in tqdm(dset_all):
                c_idx = v['label'].numpy()
                np_dset[c_idx, counter[c_idx]] = v['image'].numpy()
                counter[c_idx] += 1
            np.save(dset_path, np_dset)
        else:
            np_dset = np.load(dset_path)
        return np_dset

    def _get_tf_dsets(self, np_dset):
        train_dataset = tf.data.Dataset.from_generator(
            lambda: self._episode_generator(np_dset,
                                            self.n_train_episode,
                                            True,
                                            c_way=self.c_way,
                                            k_shot=self.k_shot,
                                            img_shape=self.img_shape),
            output_types=((np.float32, np.float32),
                          np.float32)).cache().shuffle(10000).batch(
                              FLAGS.bs, drop_remainder=True).prefetch(
                                  buffer_size=tf.data.experimental.AUTOTUNE)
        valid_dataset = tf.data.Dataset.from_generator(
            lambda: self._episode_generator(np_dset,
                                            self.n_valid_episode,
                                            False,
                                            c_way=self.c_way,
                                            k_shot=self.k_shot,
                                            img_shape=self.img_shape),
            output_types=((np.float32, np.float32),
                          np.float32)).cache().batch(
                              FLAGS.bs, drop_remainder=False).prefetch(
                                  buffer_size=tf.data.experimental.AUTOTUNE)
        return (train_dataset, valid_dataset)

    def _episode_generator(self,
                           np_dset,
                           n_episode,
                           is_train,
                           c_way=5,
                           k_shot=1,
                           img_shape=(28, 28, 1)):
        n_samples = np_dset.shape[1]  # n_samples=20
        if is_train:
            pool = list(range(0, self.n_train_class))
            n_query_per_c = (n_samples - k_shot)
            n_query = n_query_per_c * c_way  # n_query=19*5
        else:
            pool = list(range(self.n_train_class, np_dset.shape[0]))
            n_query_per_c = 1
            n_query = n_query_per_c * c_way  # n_query=1*5
        for i in range(n_episode):
            support_set = np.zeros(shape=[n_query, c_way, k_shot, *img_shape],
                                   dtype=np.float32)
            query_set = np.zeros(shape=[n_query, c_way, k_shot, *img_shape],
                                 dtype=np.float32)
            y = np.zeros(shape=[n_query, c_way], dtype=np.float32)
            selected_c_idx = np.random.choice(pool, c_way, replace=False)
            selected_c_data = np_dset[selected_c_idx]
            # rotation augment classes
            rotate_degree = np.random.randint(low=0, high=4, size=c_way)
            for j, v in enumerate(rotate_degree):
                selected_c_data[j] = np.rot90(selected_c_data[j],
                                              v,
                                              axes=(-3, -2))
            for c in range(c_way):
                perm_idx = np.random.permutation(n_samples)
                selected_s_idx = perm_idx[:k_shot]
                # support set
                selected_s_per_query = selected_c_data[
                    c, selected_s_idx]  # k_shot data with same class
                support_set[:, c:c + 1] = np.tile(selected_s_per_query,
                                                  [n_query, 1, 1, 1, 1, 1])
                # query set
                selected_q_idx = perm_idx[k_shot:k_shot +
                                          n_query_per_c]  # n_query_per_c
                selected_q = selected_c_data[
                    c, selected_q_idx]  # n_query_per_c data with same class
                tmp_q = np.repeat(selected_q, k_shot * c_way, axis=0)
                tmp_q = np.reshape(tmp_q,
                                   [n_query_per_c, c_way, k_shot, *img_shape])
                query_set[c * n_query_per_c:(c + 1) * n_query_per_c] = tmp_q
                y[c * n_query_per_c:(c + 1) * n_query_per_c, c] = np.ones(
                    shape=[n_query_per_c], dtype=np.float32)
            yield (support_set, query_set), y

    def vis(self, num_samples=10):
        tfdset = self.data[0]  # train
        v = next(iter(tfdset))
        s = v[0][0][0]  # tuple -> support_set -> first item in batch
        q = v[0][1][0]  # tuple -> query_set -> first item in batch
        y = v[1][0]
        n_cols = q.shape[1] + 1  # c_way support panels + 1 query panel
        sampled_idx = np.random.permutation(q.shape[0])[:num_samples]
        for n in sampled_idx:
            plt.figure(figsize=[n_cols, 1])
            k = 0
            for c in range(q.shape[1]):
                plt.subplot(1, n_cols, c + 1)
                plt.imshow(s[n, c, k][..., 0], cmap='gray')
            plt.subplot(1, n_cols, n_cols)
            plt.title('query:{}'.format(y[n]))
            plt.imshow(q[n, 0, k][..., 0], cmap='gray')
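
A minimal driver sketch for running the class standalone (without tf_rlib): the gist reads FLAGS.c_way, FLAGS.k_shot, and FLAGS.bs but never defines them, so they are declared here. The flag defaults, variable names, and the main() wrapper below are assumptions for illustration, not part of the original gist.

# Hypothetical standalone driver: defines the absl flags the Omniglot class
# reads, builds the datasets, and inspects one training batch.
from absl import app, flags

flags.DEFINE_integer('c_way', 5, 'classes per episode')
flags.DEFINE_integer('k_shot', 1, 'support samples per class')
flags.DEFINE_integer('bs', 32, 'episodes per batch')


def main(_):
    omniglot = Omniglot(n_train_episode=5000, n_valid_episode=1000)
    train_ds, valid_ds = omniglot.get_data()
    # each element: ((support_set, query_set), y) with shapes
    # support/query: [bs, n_query, c_way, k_shot, 28, 28, 1], y: [bs, n_query, c_way]
    (support, query), y = next(iter(train_ds))
    print(support.shape, query.shape, y.shape)
    omniglot.vis(num_samples=3)  # draws with matplotlib; intended for notebook use


if __name__ == '__main__':
    app.run(main)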
chychen commented Jul 22, 2020

@dkk
please see https://gitlab.com/chychen/tf_rlib for tf_rlib, a personal library for TensorFlow 2.
But I think it's fine if you just remove the tf_rlib parts here; I will update it later.
