import os

import gin
from meta_dataset.data import config
from meta_dataset.data import dataset_spec as dataset_spec_lib
from meta_dataset.data import learning_spec
from meta_dataset.data import pipeline
import numpy as np
import tensorflow as tf
import torch
GIN_FILE_PATH = 'metadataset/meta_dataset/learn/gin/setups/data_config.gin'
ALL_DATASETS = [
    'aircraft', 'cu_birds', 'dtd', 'fungi', 'ilsvrc_2012', 'omniglot',
    'quickdraw', 'vgg_flower'
]

gin.parse_config_file(GIN_FILE_PATH)

# Comment out the next line to disable eager execution.
tf.enable_eager_execution()
# Convert NHWC TensorFlow outputs to NCHW PyTorch tensors; the transpose is
# (0, 3, 1, 2), which maps NHWC to NCHW without swapping height and width.
# The np_to_torch_* variants take eager tf.Tensors, the to_torch_* variants
# take numpy arrays.
np_to_torch_labels = lambda a: torch.from_numpy(a.numpy()).long()
np_to_torch_imgs = lambda a: torch.from_numpy(
    np.transpose(a.numpy(), (0, 3, 1, 2)))
to_torch_labels = lambda a: torch.from_numpy(a).long()
to_torch_imgs = lambda a: torch.from_numpy(np.transpose(a, (0, 3, 1, 2)))
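
# Shape sanity check for the converters, as a comment-only sketch (the dummy
# shape is an illustrative assumption, not from this file):
#   x = np.zeros((5, 84, 84, 3), dtype=np.float32)  # NHWC batch
#   to_torch_imgs(x).shape  # -> torch.Size([5, 3, 84, 84]), i.e. NCHW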
def iterate_dataset(dataset, num_batches, batch_size):
  """Yields (support_images, support_labels, query_images, query_labels).

  Each meta_dataset episode is a 6-tuple; indices 0/1 are the support
  images/labels and 3/4 the query images/labels (2 and 5 are class ids).
  """
  if not tf.executing_eagerly():
    # Graph mode: yield one episode at a time; batch_size is not used here.
    iterator = dataset.make_one_shot_iterator()
    next_element = iterator.get_next()
    with tf.Session() as sess:
      for idx in range(num_batches):
        episode, source_id = sess.run(next_element)
        yield (to_torch_imgs(episode[0]), to_torch_labels(episode[1]),
               to_torch_imgs(episode[3]), to_torch_labels(episode[4]))
  else:
    # Eager mode: stack batch_size episodes into each yielded batch.
    batch_count = 0
    curr_batch = []
    for idx, (episode, source_id) in enumerate(dataset):
      if batch_count == num_batches:
        break
      batch_entry = [
          np_to_torch_imgs(episode[0]), np_to_torch_labels(episode[1]),
          np_to_torch_imgs(episode[3]), np_to_torch_labels(episode[4])
      ]
      curr_batch.append(batch_entry)
      if len(curr_batch) == batch_size:
        data_support = torch.stack([k[0] for k in curr_batch])
        labels_support = torch.stack([k[1] for k in curr_batch])
        data_query = torch.stack([k[2] for k in curr_batch])
        labels_query = torch.stack([k[3] for k in curr_batch])
        curr_batch = []
        batch_count += 1
        yield data_support, labels_support, data_query, labels_query
def pytorch_loader(fixed=True,
                   train=False,
                   test=False,
                   valid=False,
                   dataset=None,
                   num_support=None,
                   num_ways=None,
                   num_query=None,
                   batch_size=16,
                   num_batches=2,
                   base_path=None):
  """PyTorch loader.

  We use the fixed ways and shots approach. See the repo for the others.
  """
  print('Dataset: ', dataset)
  if not train and not test and not valid:
    raise ValueError('One of train, test, or valid must be True.')
  if train:
    split = learning_spec.Split.TRAIN
  elif test:
    split = learning_spec.Split.TEST
  elif valid:
    split = learning_spec.Split.VALID
  dataset_records_path = os.path.join(base_path, dataset)
  dataset_spec = [dataset_spec_lib.load_dataset_spec(dataset_records_path)]
  fixed_ways_shots = config.EpisodeDescriptionConfig(num_ways=num_ways,
                                                     num_support=num_support,
                                                     num_query=num_query)
  # The ontology lists must match the length of dataset_spec_list (a single
  # entry here), not len(ALL_DATASETS).
  dataset = pipeline.make_multisource_episode_pipeline(
      dataset_spec_list=dataset_spec,
      use_dag_ontology_list=[False],
      use_bilevel_ontology_list=[False],
      split=split,
      image_size=84,
      episode_descr_config=fixed_ways_shots)
  return iterate_dataset(dataset, num_batches, batch_size)
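
# Example usage, a minimal sketch: the dataset name, episode sizes, and
# records path below are illustrative assumptions, not values from this file.
if __name__ == '__main__':
  loader = pytorch_loader(train=True,
                          dataset='omniglot',
                          num_ways=5,
                          num_support=5,
                          num_query=15,
                          batch_size=4,
                          num_batches=2,
                          base_path='/path/to/records')
  for data_support, labels_support, data_query, labels_query in loader:
    print(data_support.shape, labels_support.shape,
          data_query.shape, labels_query.shape)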