Skip to content

Instantly share code, notes, and snippets.

@ltiao
Created July 10, 2022 17:06
Show Gist options
  • Save ltiao/7c4b8dcdc97074897d784ba302394bd8 to your computer and use it in GitHub Desktop.
Save ltiao/7c4b8dcdc97074897d784ba302394bd8 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "da2f1945",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-07-10 18:01:56.974325: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudart.so.11.0\n"
]
}
],
"source": [
"import numpy as np\n",
"\n",
"import pandas as pd\n",
"import seaborn as sns\n",
"\n",
"from bayesian_benchmarks.data import Wilson_3droad\n",
"from gpflow_decomposed.benchmarking.plotting import WIDTH, HEIGHT"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "fb925dc9",
"metadata": {},
"outputs": [],
"source": [
"# Global plotting style for every figure in this notebook.\n",
"# Figure size comes from the project's shared plotting constants (WIDTH, HEIGHT).\n",
"# NOTE: \"text.usetex\": True makes matplotlib render all text through LaTeX, so a\n",
"# LaTeX installation must be available wherever this notebook is executed.\n",
"rc = {\n",
"    \"figure.figsize\": (WIDTH, HEIGHT),\n",
"    \"figure.dpi\": 300,\n",
"    \"font.serif\": [\"Palatino\"],\n",
"    \"text.usetex\": True,\n",
"}\n",
"# `sns.set` is only a legacy alias of `sns.set_theme` (seaborn >= 0.11, which the\n",
"# \"crest\" palette already requires); call the preferred interface directly.\n",
"sns.set_theme(context=\"talk\", style=\"ticks\", palette=\"crest\", font=\"serif\", rc=rc)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "99a6a1e8",
"metadata": {},
"outputs": [],
"source": [
"# Experiment configuration.\n",
"split = 0  # passed as `split=` to the dataset constructor in the next cell\n",
"frac = 0.5  # NOTE(review): unused in the cells shown here; presumably a subsampling fraction - confirm\n",
"\n",
"# Fixed seed so that any randomness drawn from `random_state` is reproducible.\n",
"seed = 8888\n",
"random_state = np.random.RandomState(seed=seed)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "bb8ec495",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<bayesian_benchmarks.data.Wilson_3droad at 0x7f6d7426f9d0>"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Load the selected split of the `Wilson_3droad` benchmark dataset;\n",
"# the bare expression on the last line displays the object as the cell output.\n",
"dataset = Wilson_3droad(split=split)\n",
"dataset"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "f44f4dc0",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>0</th>\n",
" <th>1</th>\n",
" <th>2</th>\n",
" <th>hue</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0.868575</td>\n",
" <td>0.765707</td>\n",
" <td>1.558066</td>\n",
" <td>-0.419957</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>-1.911345</td>\n",
" <td>-1.488496</td>\n",
" <td>-0.652020</td>\n",
" <td>1.146774</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>0.142015</td>\n",
" <td>1.227965</td>\n",
" <td>0.865307</td>\n",
" <td>0.011288</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>-0.505180</td>\n",
" <td>1.050859</td>\n",
" <td>1.666834</td>\n",
" <td>-0.883993</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>0.324750</td>\n",
" <td>-0.006274</td>\n",
" <td>-0.045162</td>\n",
" <td>-0.772770</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>43483</th>\n",
" <td>-0.437499</td>\n",
" <td>0.614072</td>\n",
" <td>-1.246918</td>\n",
" <td>-0.578184</td>\n",
" </tr>\n",
" <tr>\n",
" <th>43484</th>\n",
" <td>-1.986455</td>\n",
" <td>1.286251</td>\n",
" <td>1.251058</td>\n",
" <td>-0.932163</td>\n",
" </tr>\n",
" <tr>\n",
" <th>43485</th>\n",
" <td>1.230195</td>\n",
" <td>-1.982785</td>\n",
" <td>-1.363562</td>\n",
" <td>-0.199199</td>\n",
" </tr>\n",
" <tr>\n",
" <th>43486</th>\n",
" <td>-0.241771</td>\n",
" <td>0.417908</td>\n",
" <td>-0.177831</td>\n",
" <td>0.364876</td>\n",
" </tr>\n",
" <tr>\n",
" <th>43487</th>\n",
" <td>1.337289</td>\n",
" <td>0.884293</td>\n",
" <td>0.385982</td>\n",
" <td>4.331796</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>43488 rows × 4 columns</p>\n",
"</div>"
],
"text/plain": [
" 0 1 2 hue\n",
"0 0.868575 0.765707 1.558066 -0.419957\n",
"1 -1.911345 -1.488496 -0.652020 1.146774\n",
"2 0.142015 1.227965 0.865307 0.011288\n",
"3 -0.505180 1.050859 1.666834 -0.883993\n",
"4 0.324750 -0.006274 -0.045162 -0.772770\n",
"... ... ... ... ...\n",
"43483 -0.437499 0.614072 -1.246918 -0.578184\n",
"43484 -1.986455 1.286251 1.251058 -0.932163\n",
"43485 1.230195 -1.982785 -1.363562 -0.199199\n",
"43486 -0.241771 0.417908 -0.177831 0.364876\n",
"43487 1.337289 0.884293 0.385982 4.331796\n",
"\n",
"[43488 rows x 4 columns]"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Test inputs as a DataFrame (one column per input dimension), with the test\n",
"# outputs `Y_test` attached as a `hue` column (presumably the colour variable in\n",
"# later plots). `squeeze(axis=-1)` assumes Y_test has a trailing axis of size 1.\n",
"data = pd.DataFrame(dataset.X_test).assign(\n",
"    hue=dataset.Y_test.squeeze(axis=-1),\n",
")\n",
"data"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "1324ceb7",
"metadata": {},
"outputs": [
{
"data": {
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment