@junpenglao
Last active January 13, 2020 15:13
TFP_NUTS_demo.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "TFP_NUTS_demo.ipynb",
"provenance": [],
"collapsed_sections": [
"uiR4-VOt9NFX",
"l75JjxxkXHWp"
],
"toc_visible": true,
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/junpenglao/51cd25c6372f8d2ab3490d4af8f97401/tfp_nuts_demo.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "R_rU1_eX-EWD",
"colab_type": "text"
},
"source": [
"Using the example from [tensorflow_probability/examples/jupyter_notebooks/Modeling_with_JointDistribution](https://github.com/tensorflow/probability/blob/master/tensorflow_probability/examples/jupyter_notebooks/Modeling_with_JointDistribution.ipynb)\n",
"\n",
"Visualized using [Arviz](https://arviz-devs.github.io/arviz/index.html)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "uiR4-VOt9NFX"
},
"source": [
"### Dependencies & Prerequisites\n"
]
},
{
"cell_type": "code",
"metadata": {
"id": "IVUf-En1zda6",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 86
},
"outputId": "91cfcb86-d854-45ba-9084-fc932bc7535c"
},
"source": [
"!pip3 install -q --upgrade tf-nightly-gpu tfp-nightly\n",
"!pip3 install -q --upgrade git+git://github.com/arviz-devs/arviz.git"
],
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"text": [
" Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
" Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
" Preparing wheel metadata ... \u001b[?25l\u001b[?25hdone\n",
" Building wheel for arviz (PEP 517) ... \u001b[?25l\u001b[?25hdone\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "coUnDhkpT5_6",
"colab": {}
},
"source": [
"from __future__ import absolute_import\n",
"from __future__ import division\n",
"from __future__ import print_function\n",
"\n",
"from pprint import pprint\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import seaborn as sns\n",
"import pandas as pd\n",
"import arviz as az\n",
"\n",
"import tensorflow.compat.v1 as tf1\n",
"import tensorflow.compat.v2 as tf\n",
"import tensorflow_probability as tfp\n",
"\n",
"sns.reset_defaults()\n",
"#sns.set_style('whitegrid')\n",
"#sns.set_context('talk')\n",
"sns.set_context(context='talk',font_scale=0.7)\n",
"\n",
"%config InlineBackend.figure_format = 'retina'\n",
"%matplotlib inline\n",
"\n",
"tfd = tfp.distributions\n",
"tfb = tfp.bijectors\n",
"NUTS = tfp.mcmc.NoUTurnSampler\n",
"\n",
"dtype=tf.float32"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "4j4Q5mBtZSiw",
"colab_type": "code",
"outputId": "32794816-d9a0-4d14-a1a6-fc85f688583f",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 103
}
},
"source": [
"import sys\n",
"print(\"Python version\")\n",
"print(sys.version)\n",
"print(\"Version info.\")\n",
"print(sys.version_info)"
],
"execution_count": 3,
"outputs": [
{
"output_type": "stream",
"text": [
"Python version\n",
"3.6.9 (default, Nov 7 2019, 10:44:02) \n",
"[GCC 8.3.0]\n",
"Version info.\n",
"sys.version_info(major=3, minor=6, micro=9, releaselevel='final', serial=0)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "WYEnkKtRNoWO",
"colab_type": "code",
"outputId": "9680652f-039e-401f-82bd-ed97e1769666",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 86
}
},
"source": [
"with tf1.Session() as session:\n",
" pprint(session.list_devices())\n",
"\n",
"if tf.test.gpu_device_name() != '/device:GPU:0':\n",
" USE_XLA = False\n",
"else:\n",
" USE_XLA = True"
],
"execution_count": 4,
"outputs": [
{
"output_type": "stream",
"text": [
"[_DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 268435456, 11814214708597102168),\n",
" _DeviceAttributes(/job:localhost/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 17179869184, 1383684441020902817),\n",
" _DeviceAttributes(/job:localhost/replica:0/task:0/device:XLA_GPU:0, XLA_GPU, 17179869184, 4812920158082693604),\n",
" _DeviceAttributes(/job:localhost/replica:0/task:0/device:GPU:0, GPU, 11330115994, 16556873309225050192)]\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "T6lzcA0-ZlAv",
"colab_type": "code",
"outputId": "66c1dea3-abb4-4faf-bb45-94fbf47a13f4",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 52
}
},
"source": [
"print(\"Eager mode: {}\".format(tf.executing_eagerly()))\n",
"print(\"XLA: {}\".format(USE_XLA))"
],
"execution_count": 5,
"outputs": [
{
"output_type": "stream",
"text": [
"Eager mode: True\n",
"XLA: True\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "l75JjxxkXHWp",
"colab_type": "text"
},
"source": [
"# Simple multidimensional Gaussian example\n",
"With some light benchmark showing the advantage of compiling to XLA."
]
},
{
"cell_type": "code",
"metadata": {
"id": "ImeuQkOkFAB6",
"colab_type": "code",
"colab": {}
},
"source": [
"nsamples, nchains = 500, 10\n",
"nd = 5\n",
"\n",
"theta0 = np.zeros((nchains, nd))\n",
"mu = np.arange(nd)*5.\n",
"w = np.random.randn(nd, nd)*.1\n",
"cov = w*w.T + np.diagflat(np.arange(nd)+1.)\n",
"step_size = np.random.rand(nchains, 1)*.5 + 1.\n",
"\n",
"scale_tril = np.linalg.cholesky(cov)"
],
"execution_count": 0,
"outputs": []
},
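{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check (an optional NumPy sketch added here, not in the original): `cov` is the elementwise product `w*w.T` plus a diagonal, so it is worth confirming that it is symmetric and positive definite, which is what `np.linalg.cholesky` above requires."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# cov should be symmetric with strictly positive eigenvalues so that the\n",
"# Cholesky factorization above is well defined.\n",
"assert np.allclose(cov, cov.T)\n",
"print(np.linalg.eigvalsh(cov))  # all eigenvalues should be positive"
],
"execution_count": 0,
"outputs": []
},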
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "2fCx2k9Y_jTq",
"colab": {}
},
"source": [
"def run_nuts():\n",
" event_size=nd\n",
" batch_size=nchains\n",
" num_steps=nsamples\n",
" initial_state=tf.cast(theta0, dtype=dtype)\n",
" \n",
" def trace_fn(_, pkr):\n",
" return (\n",
" pkr.inner_results.target_log_prob,\n",
" pkr.inner_results.leapfrogs_taken,\n",
" pkr.inner_results.has_divergence,\n",
" pkr.inner_results.energy,\n",
" pkr.inner_results.log_accept_ratio\n",
" )\n",
"\n",
" def target_log_prob_fn(event):\n",
" return tfd.MultivariateNormalTriL(\n",
" tf.cast(mu, dtype=dtype), \n",
" tf.cast(scale_tril, dtype=dtype)).log_prob(event)\n",
"\n",
" unrolled_kernel = tfp.mcmc.DualAveragingStepSizeAdaptation(\n",
" inner_kernel=NUTS(\n",
" target_log_prob_fn,\n",
" step_size=[tf.cast(step_size, dtype=dtype)]),\n",
" num_adaptation_steps=50,\n",
" step_size_setter_fn=lambda pkr, new_step_size: pkr._replace(step_size=new_step_size),\n",
" step_size_getter_fn=lambda pkr: pkr.step_size,\n",
" log_accept_prob_getter_fn=lambda pkr: pkr.\n",
" log_accept_ratio,\n",
")\n",
"\n",
" [chain_state], sampler_stat = tfp.mcmc.sample_chain(\n",
" num_results=num_steps,\n",
" num_burnin_steps=50,\n",
" current_state=[initial_state],\n",
" kernel=unrolled_kernel,\n",
" trace_fn=trace_fn)\n",
" return chain_state, sampler_stat"
],
"execution_count": 0,
"outputs": []
},
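{
"cell_type": "markdown",
"metadata": {},
"source": [
"`trace_fn` above pulls NUTS diagnostics from `pkr.inner_results` because `DualAveragingStepSizeAdaptation` wraps the NUTS kernel and exposes the inner kernel's results under that attribute. A minimal sketch (a toy 2-d standard normal added here for illustration, not part of the benchmark) that lists the available fields:"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Build a small adaptive NUTS kernel and bootstrap its kernel results to see\n",
"# which NUTS fields are available under `.inner_results`.\n",
"toy_kernel = tfp.mcmc.DualAveragingStepSizeAdaptation(\n",
"    inner_kernel=NUTS(\n",
"        tfd.MultivariateNormalDiag(loc=tf.zeros(2, dtype=dtype)).log_prob,\n",
"        step_size=tf.constant(0.1, dtype=dtype)),\n",
"    num_adaptation_steps=10)\n",
"toy_pkr = toy_kernel.bootstrap_results(tf.zeros(2, dtype=dtype))\n",
"print(toy_pkr.inner_results._fields)"
],
"execution_count": 0,
"outputs": []
},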
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "EZgG4sw7ccRX",
"outputId": "00633034-3766-4492-b786-ec56a5735cf3",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 89
}
},
"source": [
"run_nuts_defun = tf.function(run_nuts, autograph=False)\n",
"\n",
"samples, sampler_stat = run_nuts_defun()"
],
"execution_count": 8,
"outputs": [
{
"output_type": "stream",
"text": [
"WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow_core/python/ops/linalg/linear_operator_lower_triangular.py:158: calling LinearOperator.__init__ (from tensorflow.python.ops.linalg.linear_operator) with graph_parents is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Do not pass `graph_parents`. They will no longer be used.\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"outputId": "a78153f0-d51a-49a3-de07-7469b6e64840",
"id": "suQ2aYkBccRi",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"%%timeit\n",
"samples, sampler_stat = run_nuts_defun()"
],
"execution_count": 9,
"outputs": [
{
"output_type": "stream",
"text": [
"1 loop, best of 3: 1min 12s per loop\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KRA5Dv3UxX1j",
"colab_type": "text"
},
"source": [
"**~10x speed up using XLA in GPU :-).**"
]
},
{
"cell_type": "code",
"metadata": {
"id": "Q9ZHuCBqcaLo",
"colab_type": "code",
"colab": {}
},
"source": [
"samples, sampler_stat = tf.xla.experimental.compile(run_nuts)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "AB7jZjv5_jTt",
"outputId": "b12029ac-88b7-49cd-c5d5-5bb096b8dd12",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"%%timeit\n",
"samples, sampler_stat = tf.xla.experimental.compile(run_nuts)"
],
"execution_count": 11,
"outputs": [
{
"output_type": "stream",
"text": [
"1 loop, best of 3: 8.57 s per loop\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "66Xoi-DIp2Id",
"colab_type": "text"
},
"source": [
"**Note that now you can use `experimental_compile` in `tf.function` which compile the function to XLA**"
]
},
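{
"cell_type": "markdown",
"metadata": {},
"source": [
"A toy illustration of the flag (a sketch added here; it assumes the tf-nightly build installed above, where the argument is spelled `experimental_compile` -- later TF releases rename it to `jit_compile`):"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Compile a trivial function with XLA via tf.function's experimental_compile flag.\n",
"@tf.function(experimental_compile=True)\n",
"def xla_square_sum(x):\n",
"  return tf.reduce_sum(tf.square(x))\n",
"\n",
"xla_square_sum(tf.range(10, dtype=dtype))"
],
"execution_count": 0,
"outputs": []
},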
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "vo6JU4-IqICj",
"colab": {}
},
"source": [
"run_nuts_defun_xla = tf.function(run_nuts, autograph=False, experimental_compile=True)\n",
"\n",
"samples, sampler_stat = run_nuts_defun_xla()"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "GQjWnALWqIC1",
"outputId": "62a58dc6-8a1e-4d8f-f883-eede383d8d87",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"%%timeit\n",
"samples, sampler_stat = run_nuts_defun_xla()"
],
"execution_count": 13,
"outputs": [
{
"output_type": "stream",
"text": [
"1 loop, best of 3: 4.97 s per loop\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "A9AK56yaCrFi",
"colab_type": "code",
"colab": {}
},
"source": [
"# using the pymc3 naming convention\n",
"sample_stats_name = ['lp', 'tree_size', 'diverging', 'energy', 'mean_tree_accept']\n",
"\n",
"sample_stats = {k:v.numpy().T for k, v in zip(sample_stats_name, sampler_stat)}\n",
"posterior = {'z': tf.transpose(samples, [1, 0, 2]).numpy()} \n",
"\n",
"az_trace = az.from_dict(posterior=posterior, sample_stats=sample_stats)"
],
"execution_count": 0,
"outputs": []
},
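{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optionally (a usage sketch added here, assuming the ArviZ build installed above provides `az.summary`), inspect effective sample size and R-hat for the chains before plotting:"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Tabulate posterior means, effective sample sizes, and R-hat diagnostics.\n",
"az.summary(az_trace)"
],
"execution_count": 0,
"outputs": []
},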
{
"cell_type": "code",
"metadata": {
"id": "CYGP6dlgG7w0",
"colab_type": "code",
"outputId": "ae92df70-f668-4163-d729-b0fd96baa748",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 743
}
},
"source": [
"az.plot_trace(az_trace);"
],
"execution_count": 15,
"outputs": [
{
"output_type": "display_data",
"data": {