Created
April 11, 2019 18:59
-
-
Save bkj/8ae8da3c84bbb0fa06d144a6e7da8570 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
""" | |
simple-random-nasbench.py | |
""" | |
import numpy as np | |
import pandas as pd | |
from tqdm import tqdm, trange | |
from matplotlib import pyplot as plt | |
from nasbench.api import NASBench | |
# Pin the RNG so repeated executions of this script sample identically.
np.random.seed(123)

# --
# Helpers

# Short aliases for the cumulative reductions used throughout.
cummax = np.maximum.accumulate
cumsum = np.cumsum
def cumargmax(x):
    """Return, for each position i, the argmax of x[:i+1].

    Positions where x does not set a new running maximum inherit the index
    of the most recent maximum (forward fill).
    """
    # `np.float` was removed in NumPy 1.24 -- use the builtin `float` dtype.
    z = np.arange(x.shape[0], dtype=float)
    # Mark every position that is not a new running maximum as missing.
    z[x != np.maximum.accumulate(x)] = np.nan
    # Carry the last argmax forward. `fillna(method='ffill')` is deprecated
    # in pandas 2.x (removed in 3.x); `.ffill()` is the supported spelling.
    z = pd.Series(z).ffill()
    return z.values.astype(int)
def sample_one_column(x):
    """Return a 1-D array holding one uniformly-random entry per row of `x`."""
    n_rows, n_cols = x.shape
    # One independent column draw per row (same RNG call as before).
    picked_cols = np.random.choice(n_cols, n_rows, replace=True)
    return x[np.arange(n_rows), picked_cols]
# --
# IO

# NB: assumes the NAS-Bench-101 "only108" tfrecord has been downloaded here.
path = 'data/nasbench_only108.tfrecord'
api = NASBench(path)

# --
# ETL

# Pull per-run metrics (3 training runs, 108-epoch budget) for every model.
hashes = np.array(list(api.hash_iterator()))
n_models = len(hashes)
n_runs = 3

valid_rows, test_rows, cost_rows = [], [], []
for h in tqdm(hashes, total=n_models):
    _, metrics = api.get_metrics_from_hash(h)
    runs = metrics[108]
    valid_rows.append([r['final_validation_accuracy'] for r in runs])
    test_rows.append([r['final_test_accuracy'] for r in runs])
    cost_rows.append([r['final_training_time'] for r in runs])

valid_acc = np.array(valid_rows)
test_acc = np.array(test_rows)
cost = np.array(cost_rows)

# Per-model averages over the 3 runs.
mean_valid_acc = valid_acc.mean(axis=-1)
mean_test_acc = test_acc.mean(axis=-1)
mean_cost = cost.mean(axis=-1)
# -- | |
# Random runs | |
def random_run(valid_acc, mean_test_acc, mean_cost, models_per_run=int(1e4)):
    """Simulate one random-search trajectory over the benchmark.

    Samples `models_per_run` distinct architectures in a random order and,
    after each evaluation, tracks the mean test accuracy of the incumbent
    (the model with the best observed validation accuracy so far) along
    with the cumulative training cost.

    Returns (test_acc_trajectory, cumulative_cost), both 1-D arrays of
    length `models_per_run`.
    """
    total_models = valid_acc.shape[0]

    # Draw this run's architectures without replacement.
    chosen = np.random.choice(total_models, models_per_run, replace=False)

    # Each evaluation observes a single (noisy) validation accuracy.
    observed_valid = sample_one_column(valid_acc[chosen])

    # Index of the incumbent after each evaluation.
    incumbent = cumargmax(observed_valid)

    # Incumbent's mean test accuracy, plus running training-time cost.
    return mean_test_acc[chosen][incumbent], cumsum(mean_cost[chosen])
# Simulate 500 independent random-search runs.
rand_results = [random_run(valid_acc, mean_test_acc, mean_cost) for _ in trange(500)]
test_acc_runs, cum_cost_runs = list(zip(*rand_results))

# Average test acc of selected models
mean_test_acc_run = np.stack(test_acc_runs).mean(axis=0)

# Average cumulative cost of random runs
mean_cum_cost_run = np.stack(cum_cost_runs).mean(axis=0)

# Plot test regret (gap to the benchmark's best mean test accuracy) versus
# cumulative training time, on log-log axes.
# Fix: the original called `plt.legend()` with no labeled artists, which
# emits a "No handles with labels found" warning and draws an empty legend;
# label the curve so the legend has an entry.
_ = plt.plot(mean_cum_cost_run, mean_test_acc.max() - mean_test_acc_run,
             c='red', label='random search')
_ = plt.xscale('log')
_ = plt.yscale('log')
_ = plt.ylim(1e-3, 1e-1)
_ = plt.legend()
_ = plt.grid(which='both', alpha=0.5)
# Reference line roughly at the paper's reported final regret.
_ = plt.axhline(4e-3, c='grey', alpha=0.26)
plt.show()

# Performance at 1e7 seconds is `5.5 * 1e-3`, compared to about `4.1 * 1e-3` in the paper
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment