Created
July 10, 2022 17:06
-
-
Save ltiao/7c4b8dcdc97074897d784ba302394bd8 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "da2f1945", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"2022-07-10 18:01:56.974325: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudart.so.11.0\n" | |
] | |
} | |
], | |
"source": [ | |
"import numpy as np\n", | |
"\n", | |
"import pandas as pd\n", | |
"import seaborn as sns\n", | |
"\n", | |
"from bayesian_benchmarks.data import Wilson_3droad\n", | |
"from gpflow_decomposed.benchmarking.plotting import WIDTH, HEIGHT" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "fb925dc9", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"rc = {\n", | |
" \"figure.figsize\": (WIDTH, HEIGHT),\n", | |
" \"figure.dpi\": 300,\n", | |
" \"font.serif\": [\"Palatino\"],\n", | |
" \"text.usetex\": True,\n", | |
"}\n", | |
"sns.set(context=\"talk\", style=\"ticks\", palette=\"crest\", font=\"serif\", rc=rc)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "99a6a1e8", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"split = 0\n", | |
"frac = 0.5\n", | |
"\n", | |
"seed = 8888\n", | |
"random_state = np.random.RandomState(seed)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "bb8ec495", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<bayesian_benchmarks.data.Wilson_3droad at 0x7f6d7426f9d0>" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"dataset = Wilson_3droad(split=split)\n", | |
"dataset" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "f44f4dc0", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>0</th>\n", | |
" <th>1</th>\n", | |
" <th>2</th>\n", | |
" <th>hue</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>0.868575</td>\n", | |
" <td>0.765707</td>\n", | |
" <td>1.558066</td>\n", | |
" <td>-0.419957</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>-1.911345</td>\n", | |
" <td>-1.488496</td>\n", | |
" <td>-0.652020</td>\n", | |
" <td>1.146774</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>0.142015</td>\n", | |
" <td>1.227965</td>\n", | |
" <td>0.865307</td>\n", | |
" <td>0.011288</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>-0.505180</td>\n", | |
" <td>1.050859</td>\n", | |
" <td>1.666834</td>\n", | |
" <td>-0.883993</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>0.324750</td>\n", | |
" <td>-0.006274</td>\n", | |
" <td>-0.045162</td>\n", | |
" <td>-0.772770</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>...</th>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>43483</th>\n", | |
" <td>-0.437499</td>\n", | |
" <td>0.614072</td>\n", | |
" <td>-1.246918</td>\n", | |
" <td>-0.578184</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>43484</th>\n", | |
" <td>-1.986455</td>\n", | |
" <td>1.286251</td>\n", | |
" <td>1.251058</td>\n", | |
" <td>-0.932163</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>43485</th>\n", | |
" <td>1.230195</td>\n", | |
" <td>-1.982785</td>\n", | |
" <td>-1.363562</td>\n", | |
" <td>-0.199199</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>43486</th>\n", | |
" <td>-0.241771</td>\n", | |
" <td>0.417908</td>\n", | |
" <td>-0.177831</td>\n", | |
" <td>0.364876</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>43487</th>\n", | |
" <td>1.337289</td>\n", | |
" <td>0.884293</td>\n", | |
" <td>0.385982</td>\n", | |
" <td>4.331796</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"<p>43488 rows × 4 columns</p>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" 0 1 2 hue\n", | |
"0 0.868575 0.765707 1.558066 -0.419957\n", | |
"1 -1.911345 -1.488496 -0.652020 1.146774\n", | |
"2 0.142015 1.227965 0.865307 0.011288\n", | |
"3 -0.505180 1.050859 1.666834 -0.883993\n", | |
"4 0.324750 -0.006274 -0.045162 -0.772770\n", | |
"... ... ... ... ...\n", | |
"43483 -0.437499 0.614072 -1.246918 -0.578184\n", | |
"43484 -1.986455 1.286251 1.251058 -0.932163\n", | |
"43485 1.230195 -1.982785 -1.363562 -0.199199\n", | |
"43486 -0.241771 0.417908 -0.177831 0.364876\n", | |
"43487 1.337289 0.884293 0.385982 4.331796\n", | |
"\n", | |
"[43488 rows x 4 columns]" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"data = pd.DataFrame(data=dataset.X_test).assign(hue=dataset.Y_test.squeeze(axis=-1))\n", | |
"data" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "1324ceb7", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment