Skip to content

Instantly share code, notes, and snippets.

@CihanSoylu
Last active January 11, 2020 23:03
Show Gist options
  • Save CihanSoylu/48f1be6038291cf446b3c3e4420922ed to your computer and use it in GitHub Desktop.
Save CihanSoylu/48f1be6038291cf446b3c3e4420922ed to your computer and use it in GitHub Desktop.
tensorflow_probability.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "tensorflow_probability.ipynb",
"provenance": [],
"collapsed_sections": [],
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/CihanSoylu/48f1be6038291cf446b3c3e4420922ed/tensorflow_probability.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "RqLHXJ7FhiFT",
"colab_type": "code",
"outputId": "8452311e-bdcb-43ce-b3f4-ca0b83018f98",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"%tensorflow_version 2.x"
],
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"text": [
"TensorFlow 2.x selected.\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "bYU2iEhGh0Q4",
"colab_type": "code",
"outputId": "7eb0fb68-fa2f-4680-a805-9b5db247ea7b",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"import numpy as np\n",
"import tensorflow as tf\n",
"import tensorflow_probability as tfp\n",
"print(tf.__version__)\n",
"\n",
"tfd = tfp.distributions"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"2.0.0-beta1\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "aPqN8COsG3TS",
"colab_type": "text"
},
"source": [
"# Distributions"
]
},
{
"cell_type": "code",
"metadata": {
"id": "E0mmSQ7pHce-",
"colab_type": "code",
"outputId": "9ea0130d-6181-4c5d-ca65-4e64e1e2ac0c",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"# Create a normal distribution\n",
"n = tfd.Normal(loc=0., scale=1.)\n",
"n"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<tfp.distributions.Normal 'Normal/' batch_shape=[] event_shape=[] dtype=float32>"
]
},
"metadata": {
"tags": []
},
"execution_count": 3
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EbSxp2NoJoNn",
"colab_type": "text"
},
"source": [
"Event shape = ( ) means it is a scalar distribution. "
]
},
{
"cell_type": "code",
"metadata": {
"id": "W12A0REIIKU7",
"colab_type": "code",
"outputId": "a8c572aa-b7dc-4970-ea80-d952e20e80d8",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"# Draw a sample from it\n",
"n.sample()"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<tf.Tensor: id=22, shape=(), dtype=float32, numpy=0.95941013>"
]
},
"metadata": {
"tags": []
},
"execution_count": 4
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "lWms-qtRINqW",
"colab_type": "code",
"outputId": "7388d2d3-4135-4188-a8fb-e3a5dbeef9ab",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"# Draw three samples from it\n",
"n.sample(3)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<tf.Tensor: id=47, shape=(3,), dtype=float32, numpy=array([-1.2171098, 1.2464374, -1.3972762], dtype=float32)>"
]
},
"metadata": {
"tags": []
},
"execution_count": 5
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "gAeua0UVIZlv",
"colab_type": "code",
"outputId": "522885d3-9d13-4402-9e78-1ed023e761b0",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"# Evaluate the log prob\n",
"n.log_prob(0.)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<tf.Tensor: id=59, shape=(), dtype=float32, numpy=-0.9189385>"
]
},
"metadata": {
"tags": []
},
"execution_count": 6
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "LJCEvsfGI4iU",
"colab_type": "code",
"outputId": "d472e034-c4bc-49f2-a477-77077083b544",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"b = tfd.Bernoulli(probs=0.7)\n",
"b"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<tfp.distributions.Bernoulli 'Bernoulli/' batch_shape=[] event_shape=[] dtype=int32>"
]
},
"metadata": {
"tags": []
},
"execution_count": 7
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "6OYr77IFI-G9",
"colab_type": "code",
"outputId": "a2f80ceb-3f8c-4fe7-e357-8610de143560",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"b.sample(8)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<tf.Tensor: id=90, shape=(8,), dtype=int32, numpy=array([0, 1, 1, 1, 1, 1, 1, 1], dtype=int32)>"
]
},
"metadata": {
"tags": []
},
"execution_count": 8
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "Qn95ZtCnJFt7",
"colab_type": "code",
"outputId": "7a48c8a5-2d22-486c-b63e-0edb3012c9b5",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 122
}
},
"source": [
"b.log_prob([1,0])"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"WARNING: Logging before flag parsing goes to stderr.\n",
"W0904 13:10:26.867488 140202921670528 deprecation.py:323] From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/nn_impl.py:182: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Use tf.where in 2.0, which has the same broadcast rule as np.where\n"
],
"name": "stderr"
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<tf.Tensor: id=113, shape=(2,), dtype=float32, numpy=array([-0.35667494, -1.2039728 ], dtype=float32)>"
]
},
"metadata": {
"tags": []
},
"execution_count": 9
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "NC7M7LbcJSz2",
"colab_type": "code",
"outputId": "78aa4804-52f4-4ac1-f2a8-578941657def",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"# Multivariate normal with diagonal covariance\n",
"nd = tfd.MultivariateNormalDiag(loc=[0., 10.], scale_diag=[1., 4.])\n",
"nd"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<tfp.distributions.MultivariateNormalDiag 'MultivariateNormalDiag/' batch_shape=[] event_shape=[2] dtype=float32>"
]
},
"metadata": {
"tags": []
},
"execution_count": 10
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xSchAFkSJv2a",
"colab_type": "text"
},
"source": [
"Event shape = (2) means the sample space is 2 dimensional."
]
},
{
"cell_type": "code",
"metadata": {
"id": "3_8nsDkjJh0E",
"colab_type": "code",
"outputId": "889ef915-5520-49cd-8949-2de70262fdfc",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"nd.sample()"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<tf.Tensor: id=194, shape=(2,), dtype=float32, numpy=array([-0.7479573, 16.969402 ], dtype=float32)>"
]
},
"metadata": {
"tags": []
},
"execution_count": 11
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "6aIWET43KGzK",
"colab_type": "code",
"outputId": "5e576817-a61f-4ffd-89b2-7a8156554043",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 102
}
},
"source": [
"nd.sample(4)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<tf.Tensor: id=246, shape=(4, 2), dtype=float32, numpy=\n",
"array([[ 0.33959457, 14.147322 ],\n",
" [ 0.05219507, 17.290096 ],\n",
" [-0.2815908 , 6.8896832 ],\n",
" [-0.32239607, 10.473448 ]], dtype=float32)>"
]
},
"metadata": {
"tags": []
},
"execution_count": 12
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "iK6Qm0NCKavr",
"colab_type": "code",
"outputId": "d5ffe56c-2f42-4ac2-a678-a03b80c42955",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"nd.log_prob([0., 10])"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<tf.Tensor: id=306, shape=(), dtype=float32, numpy=-3.2241714>"
]
},
"metadata": {
"tags": []
},
"execution_count": 13
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "mV1JiAp7KkfG",
"colab_type": "code",
"outputId": "c6651cb9-f1b4-48bf-eec6-da71a3d7c584",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 281
}
},
"source": [
"import matplotlib.pyplot as plt\n",
"nd = tfd.MultivariateNormalFullCovariance(\n",
" loc = [0., 5], covariance_matrix = [[1., .7], [.7, 1.]])\n",
"data = nd.sample(200)\n",
"plt.scatter(data[:, 0], data[:, 1], color='blue', alpha=0.4)\n",
"plt.axis([-5, 5, 0, 10])\n",
"plt.title(\"Data set\")\n",
"plt.show()"
],
"execution_count": 0,
"outputs": [
{
"output_type": "display_data",
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAEICAYAAABPgw/pAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJztnXlwXeWZ5p9XkrXZknVlJBnJxsa2\nsCDgLWYJDiHExnQStlRNZhKSXkKqqGEmTbqr6W1CJTSpmenqUOnuqXTRYQLTrmA6nU6YhM4yNnYg\nS7ME23gBW1g2q2QjC1uybEuWLOubPx59dY6u75XuJt2rc59flUq+5557zqcL9Xzveb73fT9zzkEI\nIcTMpyTfAxBCCJEbJOhCCBERJOhCCBERJOhCCBERJOhCCBERJOhCCBERJOhCCBERJOhiRmBmb5nZ\noJmdMrM+M3vezP6zmaX0/7CZLTYzZ2ZlUz3WuPt+1Mw6p/OeoniRoIuZxG3OuRoAiwD8NYA/B/BY\nfockROEgQRczDufcSefc0wD+E4DfN7MrAcDMPmlmr5hZv5m9a2YPhj72q7HffWZ22sw+ZGZLzewX\nZnbczN43s81mVpfonkb+1syOjV1/X+i+FWb2sJm9Y2bdZvaPZlZlZrMB/BxA89g9T5tZ85R9MaLo\nkaCLGYtz7rcAOgHcMHboDIDfA1AH4JMA7jWzO8fe+8jY7zrn3Bzn3AsADMD/BNAM4HIACwE8mOR2\nG8eucRmAuQD+I4DjY+/99djxVQCWAWgB8FXn3BkAHwdwZOyec5xzR7L9u4VIhgRdzHSOAKgHAOfc\nc865fc65UefcXgD/DODGZB90zh1yzj3jnBtyzvUA+OYE558DUAOgDYA55w44546amQG4B8AfO+dO\nOOdOAfgfAD6Ts79QiBSZ1gUiIaaAFgAnAMDMrgWj5SsBlAOoAPCvyT5oZk0A/h6M8GvAAKc30bnO\nuV+Y2bcA/AOARWb2FID7AVQCqAawk9rOSwMozfYPEyJdFKGLGYuZXQ0K+m/GDj0J4GkAC51zcwH8\nIyiuAJCorej/GDt+lXOuFsDnQ+dfgHPufznnPgjgCtBi+VMA7wMYBPAB51zd2M9c59ycCe4rxJQg\nQRczDjOrNbNbAXwPwBPOuX1jb9UAOOGcO2tm1wC4K/SxHgCjAJaEjtUAOA3gpJm1gAKd7J5Xm9m1\nZjYL9OrPAhh1zo0C+N8A/tbMGsfObTGzW8Y+2g1gnpnNzfLPFmJSJOhiJvFvZnYKwLsAvgJ63l8I\nvf9fADw0ds5XAXzfv+GcGwDw3wH8+1ge+3UA/grAGgAnAfwUwFMT3LsWFO5eAG+DC6LfGHvvzwEc\nAvCimfUD2AZg+dh920Ev/42x+yrLRUwZpg0uhBAiGihCF0KIiDCpoJvZ42PFFK+GjtWb2TNm1jH2\nOza1wxRCCDEZqUTo/wTgd+KO/QWA7c65VgDbx14LIYTIIyl56Ga2GMBPnHO+1Pl1AB8dK6y4GMBz\nzrnlUzlQIYQQE5NpYVGTc+7o2L/fA9CU7EQzuwespMPs2bM/2NbWluEthRCiONm5c+f7zrmGyc7L\nulLUOefMLGmY75x7FMCjALB27Vq3Y8eObG8phBBFhZm9ncp5mWa5dI9ZLRj7fSzD6wghhMgRmQr6\n0wB+f+zfvw/gx7kZjhBCiExJJW3xnwG8AGC5mXWa2RfBBkg3m1kHgA1jr4UQQuSRST1059xnk7y1\nPsdjEUIIkQWqFBVCiIggQRdCiIggQRdCiIggQRdCiIggQRdCiIggQRdCiIggQRdCiIggQRdCiIgg\nQRdCiIggQRdCiIggQRdCiIggQRdCiIggQRdCiIggQRdCiIggQRdCiIggQRdCiIggQRdCiIggQRdC\niIggQRdCiIggQRdCiIggQRdCiIggQRdCiIgg
QRdCiIggQRdCiIggQRdCiIggQRdCiIggQRdCiIgg\nQRdCiIggQRdCiIggQRdCiIggQRdCiIggQRdCiIggQRdCiIhQlu8BCCGiR0cHsG0b0NUFtLQAGzYA\nra35HlX0UYQuhMgpHR3AY48Bp04BCxbw92OP8biYWrISdDP7YzN7zcxeNbN/NrPKXA1MCDEz2bYN\nqK8H6uqAkhL+rq/ncTG1ZCzoZtYC4D4Aa51zVwIoBfCZXA1MCDEz6eoCamvHH6utBY4cyc94iols\nLZcyAFVmVgagGoD+kwlR5LS0AP3944/19wPNzfkZTzGRsaA757oAPAzgHQBHAZx0zm2NP8/M7jGz\nHWa2o6enJ/ORCiFmBBs2ACdOAH19wOgof584weNiasnGcokBuAPApQCaAcw2s8/Hn+ece9Q5t9Y5\nt7ahoSHzkQohZgStrcAXvwjU1NB+qanha2W5TD3ZpC1uAPCmc64HAMzsKQDXA3giFwMTQsxcWlsl\n4PkgGw/9HQDXmVm1mRmA9QAO5GZYQggh0iUbD/0lAD8AsAvAvrFrPZqjcQkhhEiTrCpFnXNfA/C1\nHI1FCCFEFqhSVAghIoJ6uQghska9WwoDRehCiKxQ75bCQYIuhMgK9W4pHCToQoisUO+WwkGCLoTI\nCvVuKRwk6EKIrFDvlsJBgi6EyAr1bikclLYohMg67VC9WwoDRehCFDlKO4wOEnQhihylHUYHWS5C\nRIhMrJOuLkbmYWpreVzMLCToQkQEb53U11Og+/v5erIFSp92WFcHdHcD7e3Ae+8BjY2B7aKy/pmB\nLBchIkKm1olPO+zoAP7934HeXqCsjJPCww8D3/iG/PWZggRdiCzp6AAeeQR44AH+zpfYZVqx6dMO\nOzuBc+eAWAz48Id5/NgxoKdH/vpMQZaLEGkS9qlnzeLvZcvSszmmgrB14km1YrO1FVi6FLjxRgq3\n5+xZwGz8ufLXCxdF6EKkQXyK3+7dwMGDwNBQ+hFsriP7bCs2E5XwV1YCFRXjj6msv3CRoIuiIlsR\njfeph4ZoUbS3B+ekYnNMRe53thWbiSaExkagoUFl/TMFWS6iaMg0CyRMfIpfLAYMDlLoPKlEsOGJ\nAQh+b9uWnVWTTcWmnxC8ndTcDNx/fzAuf+xTn1KWS6EiQRdFQy5END7Fr78f2LuXr48eBaqqGMF+\n6lMTXydXud+53iko2YTQ2hrca9MmpS8WKrJcRNGQi77d8Sl+588Dl1xCa2LLFmBgYHzEn8ziyUXL\n2Y4OphRu3Qrs2sXf3/jG1GTZqD3AzECCLoqGXIhoohS/jRuBT38auPNOoKlpvJiHRfCNN4AvfQm4\n915G94cPZ+dNb97MBVkzjsOMrzdvTv0aqaL2ADMDWS6iaNiwgQILMDLv70/NHoknWYpfvGUSFsHu\nbmDfPqC0lOJ96aWAc/TfT53KzJt+6SUKeXU1X1dX85ovvcTXubRj1B5gZiBBF0VDokW/TBf4kuV8\nl5bSWunqAnbuBK67jue0twOzZzMNsK+Px5YtYybKvfdOfr9E4gxcmCNuxp9cLACn8vcqfbGwkKCL\noiJXfbsTRfuHDlFMq6spoq+9Bjz7LLB+PcvpfUaMF8VwhDtRNJ1MnFtbuSBrxoni7Fk+cdx0E691\n/jywZ09w7+bmzLNocvV0I6YWeehCZECinO+WFlox3mdes4bn7trFY729XDRta+NxH+FOtuCYzL+e\nNw9YvpwefG8vfy9fDtx1F+2dvXsp8rEYy/d/+lPgO9/JLP9euxLNDBShC5Eh8dH+Aw9wUdTT1ASs\nWAFs3850xjNngOuvDwp1fIQbn045NAS8/jpw333A7bdTnFesGH/v2loK//338/NHjnBy8JF9by/t\nn6oqntfVxYjdLJgwvvhFXitVn127EhU+EnQhckS8z9zdzSj5ssuAW25hVsvevSylv+qqwL/ftClY\ncOzuZjpk
dXUgvocP038Pi6mP7pOJbPiJ4OhRRvalpbyOH9/mzYzg6+vZk2brVuC73+VY77pL4j0T\nkeUiRI6IL53ftYvH16yhoLa20k+/6iouhHrBDKdT+sVTgFbJ0BCv9cMfAj//OcU5lRTHFSt4H1/o\nVFnJ+/iJo7aW2TD19bzHCy9wApk/n/1plGM+M5GgC5Ej4n3ms2e5QBm2YRIVMoUnghMnKOADA8BF\nFzFanzMHWLiQ5yYqXkrEhg2MyFeuBK65hjZPael4/96Pp72dTwTV1ZwAhoeVYz5TkeUiRA4JWyCP\nPELLBAh2Ajp6lALf0RGcF06n9GmI69YFQmtG4a2uptj/5CfsU37VVck97/A1YzHg+HFG7WH//rrr\nKOw+jRLgJFRXpxzzmYoEXYgpwqf69fTQOy8tBcrLKcxf+hIzYsKi3NoafKaigh54ZSXF24zWSH8/\nI+h9+zgpfPe7rFT93OcuFPbw5BJOi/T59wDvVV7OqL+khAu3q1Ypx3ymYs65abvZ2rVr3Y4dO6bt\nfkIAuW9gle69v/pVivL8+bRRXn89SD9ctYrRcnz/l23bgKefpqCPjFB0OzspuCUljKRnzwYWL+Zn\nli8PrpHO39vRwcXRrVv55LBqVeC7Ky2xcDCznc65tZOeJ0EXUSPZjkLhgpipEqtEYuqzWEpKgF/+\nksVFVVW0Ou64g78TVYz6/PTf/paTwa5dvEZZGaP9kRHaKH19bENQUxNE+PX16f29ftzx6Y+iMEhV\n0LOyXMysDsB3AFwJwAG42zn3QjbXFCIb4qsqt2yhdbFwYRAVA4xKm5pyG7Unq+isrAzSGb1f7b1q\nILlf7X3wt9+mB19TE7TpLS2ldeMrT/01Mm0RrBzzaJCth/73AP6fc+4/mFk5gOocjEmIjElUpON3\nFPLZJoODtBjuvDO3+4DGN+Pyi6BVVbzP0qUUXl/VuWoVP3f4MO2UBx5g9G3GTo5+onnoIY7v/Hl6\n5wBw+jSwZAm979WrA89bTbSKm4wF3czmAvgIgD8AAOfcMIDh3AxLiMxIZUeh3bsp7rneLcjf2xcH\nzZ5Nq6S7m10QOzv509VFcR8dpZ2yfTv7qZ86xQXUqiqmO546xf7mPk+9t5cTRmkprZHqak4KFRXj\nq07VRKt4ySZCvxRAD4D/Y2YrAewE8GXn3JnwSWZ2D4B7AOCSSy7J4nZCTE58tWZbGwUzFqOA9vcD\nb71Fgfvxj3leWxvT+bKNYv29fXFQVRUj6Pnzgblzmely222cYHbv5v1PnqQdtHAhsH8/Fz3nzmVf\n87Y2/u7pYfVm2A8HAs+7pmZ810g10SpesiksKgOwBsAjzrnVAM4A+Iv4k5xzjzrn1jrn1jY0NGRx\nOyEmJ75as6KCGSArV1KwBwYots4x0t2/H/infwJ+8AO+zsW9jx7lfQcGgmZcnZ1cxPQNtqqrKd6n\nTvGYGd+fPZtj7+vjxBCLMU2xp4edE3/7W2bNAFxE/frXx1edqolWcZNNhN4JoNM5N9ZOHz9AAkEX\nYjpJttFxuNjnQx8CXnyRVoiPpN9+mzaIL/jJJNUxfhFz/nz6201NwHPP8XfYjvF+eXs7cPnlHMe5\nc4yqW1sp6hUVjOi/9z166HPm8PVEnr8WOIuXjAXdOfeemb1rZsudc68DWA9gf+6GJkRmJBO0jg7m\ndgMUS5/6V13N/PClS4Ny91Q3h0gk/H4R06cO9vVRvFtaxtsx5eWMwJ0D3nkn6KFeVsaGXrt387rH\njwcLpceP8/0rrsje8xfRI9sslz8EsHksw+UNAF/IfkhC5B6fUlhZydfnzlHQlywJ2symm/qXLE1x\n/Xpe79lned611/IpYft22jHz59OKmTOHPxUVPD4wwCeKNWs40axcSatleJjjNaOYz57NZloVFdP3\n/YmZQVaC7pzbDWDSZHchsiXbak8v0mvW0PLwTajeeYdWSDqpf/GVnL6bYl
0dBfjhhynqd9wRLEou\nXnyhHbNxI6+3axej9dtuu/Dv6u4G3n038NfnzuU9jx9n1osQYdRtURQ8k+3okwpdXRTlpiY2vlqw\ngMU9AwP01H3q34YN49vZAhTVLVuAHTuAv/orphL6plsAJ4jubv47vPiZaCHzoYeAq6/mxLF/P731\ngQHg05/m+5s2jd9RaMUKntvQwAmpspKLqeXlfC1EGDXnEgVPptWPYcLpjE1NjIbD2SeDgxTK++7j\nb+fYjXDOnMA6uekm+tq+8rS+npNCdXVQuHTs2IWLnz4X/bHHmHHz0kvAK69wErn8ckbv3/42cMMN\n9PHDnv2GDcCPfsQWBmfPMs2xtJQ7H115Zc6/ajHDkaCLgifV6seJbJlEmxyXljJiBhh1HzzIRcqq\nKgq93/ThootoqzQ1UYxjMUbdAKPsmprEi5+jo7RL+vt5zqFDwL/8S9BU6/x5Wj5mnByOHOF4/YT1\n5JMsOKqvZ6Q/dy7H0dLCsU+0wYUoTtScSxQ8vq94uPqxr49RdWMjBbysjAuLvrw+UVOqZA2oHnmE\nrQDMGG0DvPboKKPiO+6gRw4A//ZvvE53N0U5FmNUfvIk7ZG6OuD994E33+SYh4bo1c+aRWulooIR\neyzG6w0M8FqXXMK/6dpraa8cOsSngauvZjXo6dMsTLroIo7LP2nEtwlQ1ks0mZbmXEJMB4mi68OH\naYv4Pim+CZfvapjIlkmWztjVFfR88VRW8npmgVXT3U3Rff/9IFums5ORc3U1fe0bbgBefhn4zW8Y\nRY+O8uf8eVo7IyP8zKlT/Jwv6TejQPf0sCPj2bP8+15+Gdi5kxkvCxbws+vXc8IJW0HhjZ8l6sWL\nFkVFwZOo+vHii9kS11de+m3T2tuDzyXa7i0RLS1BAY/n7Fkeu/baoPL0wAHe2+etnznDCPvtt3n/\n3t6gV/m8ebxOWRmv4xyF3be/7e+n3XLyJN8fHqYov/UWr3n6NEXePyUcOkQL6ORJ/s0HD/LvnTeP\n/66r07ZxQhG6mCHER9cPPEDB9tTVUQjDTbhSbUq1YQMXKQ8epPCaUcSXLwfuuovnbNtGAV64kNHy\n8DCjZd8DZmSE1ouP4puaKNDOBWJeXs5rjY6OF+u6Ok4QpaX03OvqeP2KCn6mtJRCXl3NSB8Yv22c\n/5tT7aqYzw0/xNSiCF3MSOJTC9vaGCGXl1Mo/b6Zky0cenErLeVi5YkTFNkVK/gUsGkT39+wgdu8\nrV7NTZfffZfRtnO85/AwLZH2dgptVRVtmepqZsqY8XVjIyP04WF+dtas8W0Aysv5eb9/qBfw8+f5\n3rlztGTeeosLsj09gbCnMoHlIgVUFC6K0MWMJN5Xr6hguXxLS7BIWlkJfPObFPdYbPz+nR0dzCLZ\nsiXYeu3SSyno69ezqrO6OvDsfQXo9u2B1dHTw4XY2bMpuKOjnFQ+8AF6/B/+MIuGfOm+LwpqbGS0\n71zQxOvkSU4Q58/zfG/ZnD4dVJC2tTEXvreXY92/n+defnkwgU3WVTEXKaCicFGWi5ixJMta8VGo\n3xBiYIALmbW1FNXPfIb7er7+OoUW4Dnr1lFgX3mFkXh8Vk1NDX37xx9njvnZs5xEFi0KMlvmzwc+\n+1med+gQ8OqrjKa7uym4S5cy0t67N9htqL+f1zp/nq9rajgBdXdzUonFmBO/d29QENXXF/j1s2Yl\nrjJNxAMPBAvHntFRToJf/3rO/xOJHKEsF5EyM91TjY9JfBS6Zw/fO348iH5LS4G/+Rvg1luDzBYv\n6u3tzFJ5913u0el3Herr42RQUsIFUN8m4MAB+uaxGG2V5mZmnPi9QW++efy4/Pfst7+rqOB3XlLC\na5SWsmq1uZn+/OrVwaYWS5YwWl+5MrEYx+9Hmoz4fvGANsCIEhL0IidZg6lCT3+baNy+EKmvj4JY\nURHkgcdiFOKuLv777Fn63X7T5v5+Ln
wePszovrqa4tfbG2yM4cXwyisp/m++SaG96SYujsaPM36y\nBIA33uD1vQUzOMgJx2+2UVl5YcT8yCPZi3GiFFBtgBEdtCha5IQ9VZ+/PRPS3yYadzgKPXmSYj48\nHGyqPG8eo++2NqYeDg5S7MvLaZM0NAA//CGj5HPnKPqjoxTAzk7ev64uaGO7aBEj+qqq8eKabAFy\n2TJG41ddxcj+xAmK+c03M3JPJtLxm3ekuvAbJpwCum8f7aWTJ/m9aWF05iNBL3J806owqeZvTzUd\nHYxKH3hgfMMq39f8ueeY8eEbY/lxe+FrbqZw9vXRXvGpjStXBvnhH/oQBXTPHi5U7t/P4wsW8Jz9\n+zkZrFtH2+PYMd6rrY2Lort3M3L/+c85GYTFNdmkc+gQRXXJEua5L10KfPKT4xc3E4l0rnYjam3l\n9WtqaOusXKlsl6ggy6XIKTRP1VsU+/bR9lixYnzDKp9pUllJ73twkAuUfkGzuXn8rkWrVzPTZO7c\nwJ4pLeXC6KFD9MRLSmiX7NkDvPce8NOfMlJesiRIN2xq4rlHjrAfS2UlI9tz55jeCARevGeiHjTh\nvPqwLdPcPH5/0HhytRuRsl2iiQS9yMnWU03mEWeyyBr2xf2OQvv2BW1vAWaYrF4d9DX3u//s2sVC\nID9uL3z33ps8G+bmm4EHH2RB0XPPcdGxsZGff+cd+uEXXUT/vKODGSvXX89JZOdOLrh+4hO0XQCO\nOSyIqU6WicR906apXaBOteGZmFlI0IucRHtwThQhhkm0MPnwwxS6ZcvSX2QNR40+d3xwMCjnP3CA\nPVLMaE+sW8f3/EYP6e6x2dHBplxnztB+GB4O2t/OmjVeiDs7mQHjr+NL9nt6gnPiBTHdyXI6F6gL\n7clM5AYJusj4MT7RY7v3mNeuDY75cye7RzhqjMUoln199Mh37+ZiZlMTj3mb5cYbgxzxVCchP3kd\nPswI/+hRTkJDQxR3L/AHD/L6f/qnjJjDEW0qrQZaW2kRPf44s2EWLgTuvjv5OKfTBlG2SzTRoqjI\nmEQLqmfPUhjDpNMk6/BhLnR2dNBG8RGw71J4xRVBk6sDB9LL9IjPOjl2LMgaGRxkVO77q5SUBD75\nk0/SYtmyJViATaXVQEcH/f7Vq4HPf56/t29PvvA4nQvUuVpgFYWFInSRMYke231b2TDhyHWiIqZl\nyxgJ19czYq6p4WfLy+ltx2JM71u3jtkn775LG2T16tR85/gIeP58LoLOm8do+9Qp2ig1NVxE/ehH\nuVA6PMxKzWefpSDfdBN9+3CrgURWlb/f0BDw619T9MvLWVj04IOpfZ9TaYPkaoFVFA4SdJExiR7b\nGxspxr668vBhlqwvWUIR85tQJPKIDx2iQB85Qrvjooso8n19tCt88Y+vsvzIR4IxxPvO69fzeuGJ\nI34hsK0N+NWvKLhz51LQfV/yefNYzh+LUdAvvpjX3LULePFF4PbbacVMJIhdXYz6X3ghKFAaGKBv\n/7nPXfhZ2SAiW2S5iIxJ9Nh+//0UOl+4sncvUw9XrmS0+/rrFNBERUxdXRT7G29kfvYll3CCmD2b\nQtjbS+EN2xuJcr3Pn+fibHxBj+9DHqakhNk0FRW89ty5jIhbWngfsyBibmoCbrmF6wP33jt5dNvS\nQu+/upo/ZrxfU1Piwi3ZICJbFKGLrEj22O63dlu0KBDE8CYUPg0xnBniLYehIf7ev58WzuLF9M73\n7mXEXFMT2Bvxi5VA0J88fnFxcJATgb/vrl1BtN/QwM+8+ipL+a+/nmLsuy960rFANmwAvvtdWjvO\n0Zs/c4bFTMl8cdkgIhsk6GLKiLc4SkpowZw5w9dtbUExEEABDG/WvGRJUMxz9dXAt751odjNmsXF\nyqEhRtoAnwwaGoJ+4b29FPVYDLjzziDrpLeXYt3QEDThuvxyivbICJ8qjh7lGEdHL7RAkq0HhI/H\nYs
HeonV1bNPrfXohco0EXUwZ4UU+v5PPqVOMjgcGuMB42WW0aACKYUsLRXh4mHbLRz4SCGC8mHd0\nUDR7eynsr73G4iC/j+cPfsDIvqGB5/hiodWraets2cKof/16vgaCFEjfvTBZFWeynPFwz/QFC4IF\n0RtuCCpe5YuLqUKCLqaM8CLf/v0UypaWoHthby9TD72f3NrKUvpbbrmwReyvf80FzHA+96FDXDRd\nuBD4/vc5WVRW0nPv7+e13nqLE0R/Pz30N94APvhBXnfRIk4CTzzBylPfFiAstn4S8aLux5osZ9xX\nsvrX/vOdnRxbOoVbQqSLFkXFlOEX+QYHuXv9229T6JYupbgtXUqRDzeGit9aDuAuPc8/T0EdGuK1\n7rkH+NnPgrYAzlGgFy7ktWtrmSLY2cn3fHn+4cN8WvC9zhctotAfO8ZslIEB+vK+GViyjom+JUGY\n2lpOOPHHly6l9dLcHEwKaoIlpgIJuphyBgfZKvayyyiyzzzDqBug0IWzXRK1iH3+eaYRHj9Oy6S+\nntd5+WUKtMc5RuXV1TzHjBPGlVcye8Vnm7S382f2bL6/ahVzzsvKKMhh4d68OXHHxN7eCyce30s9\n/vjhw/zRPp5iqpGgiynFWxNr1lDYnaMoHzsWWCE/+hEzSl59NXHqnu+sWF7OH78/Z2kpPfC+Pkba\n/f1Bg625c7n46ouUBgfZe7yujsVEJ05wwhgY4OJse3uQcx5Of/zXf03cpjcWS9yb/O67LzzuUzfT\n6TmfrHWwEBMhD11MKT7TpaQkaKZVVsaS+poaCnRlJSPe48cpXPGpe7/6FaPx+vrg2MAAi32WLaNY\nnzrF6Ly8nMJeX89rlJcz/x3gxsqNjcGmzADH1NTEYqGKisD77u6mEI+McHzxbXr9htOJFkwXLx5/\nfOlS/oSZqLPhTN1FSuQfCbqYUsKZLk1N/GluZs9xH30PDjKaXbAA+OpXKX7hNMC776bgnjzJyHtg\ngD/XXRds33brrfy9ezej/1WrWM357W+zQKmujn66b4F77bUUVJ+SWF4+Pue8vZ1PAJddxnslatM7\nUQ5++Hi6W8epV7nIFFkuYkpJ5ImXljLH21d9VlVRJLu6KMbxPvPNN1PoS0tp01RWAh/7GFsDOBeI\n38UXAx//OAW7vZ2ZLzU1/FxXF/PZS0sp+rt3U1QHBvjeypUcgxf4o0f5+5prGJVXVnJBdmgo/Ug5\n3a3jCnkXKVHYKEIXWTNRw61k/da3baNo++jzl7+kLTN/fuAzA0FU+oUvAB/+8IUbVWzaNF78uruZ\ngeLtlzlzGLmXlPBn9my+V1JCL/3yy4PNmMN/R1MTJxZf0erb9k7WpjfZd5FOz3n1KheZIkEXWTFR\ngU18c6x4AXv4YUbkZ88yP7x51Ni1AAANyUlEQVSxkWXxnnifOZHFES9+7e3BxNDfzzz0wUFG3L7h\n1vnztF9qa4Ff/IL3TVTt+dhjQZOxVAqCJvO+U43q1aRLZIoEvUiZKKpOh0R+78GDwH33MaJsbKRN\n8ZWvMOvk5Emm9m3cSLsECPbtjO+jHh+VJhpzvPi99x4XXRsamIJ45gxTFYeHKerV1fzsuXO83vnz\nQTphvACnszlFsu/CH0/nu81mFylR3GQt6GZWCmAHgC7n3K3ZD0lMNbnMoojv19Ldzbzxc+cYJQ8O\nMgo+coQR8ooVjHofeog++Mc/Hnxu+3YuOt5yS+K+KcnGHBa/xka+f+QIRT0WY8l/WRnF22fVnDsX\npDV64R0aYjfI++7jounRo6z6bGuj5/61rzHTJVHr20TfBZD5Pp1q0iUyIRcR+pcBHABQO9mJojDI\nZRZFIstjeJhCasaI2Odv+/axZWXMF//Zz2i3tLXRo77pJmazJIpK4zeL6O3lAuaTT1Jow5ssP/YY\nxdh3OWxq4j17exmNDwzwdUVF4JF3d1Osq6o4zj17eH5VVVCE1NTE
42fPJp785H2LfJNVlouZLQDw\nSQDfyc1wxHQwWRZFOkUt8RkcR48yCo7FgnOGhhgdV1dT4A4fppgODwf53d3dFM/bb+ciZXy/8a6u\n4NyzZ4MJY+vW8ePzdkVTU3DNdeuYftjczBTGRYuAD3yAE8m8efxcezvHV1ISFBjV1/N+PmXRWzfJ\nioLSzWYRItdkm7b4dwD+DMBoshPM7B4z22FmO3rCW6SLvJGoX4qPJJP1Lkkm6vGVnU1NXNgsKQkq\nQ8vKKHB+y7fycgqkr/r0+d0TiZ/fLMKLqxmPNzZeKK6trbR0rr6a6YgNDfzbhodZ4n/bbTy+ZAk/\n39fHaBygDdPWxijbOR732+oNDvJ4shRCbVAh8k3GlouZ3QrgmHNup5l9NNl5zrlHATwKAGvXrnWZ\n3k/kjomyKDKxY7zf29HB3idbt1J4h4YolhdfTDH0vjXAiH3jRlZiejGdSPz8ZhG+EdfgIK2TZJtF\nxC8sLlnC+/nMm+bmoG3vtm28pnNB5ShAT3/2bN4L4P1Wr57YRpH3LfJJNh76OgC3m9knAFQCqDWz\nJ5xzn8/N0MRUMVEWRaIdgFJZ2AsvWt5yS1CxuXEjcNddTB98/HHaJXPm0FrxHRBTye9ubeW19uzh\n+XV1FNeJNotIJK4335z4PL+5xq5dnIgqKhi9r1gB7NwZpFRWVCiFUBQuGQu6c+4vAfwlAIxF6PdL\nzPNPqumIySLJTBf2wpG9r9oMC3VrK8U0LPyJdgGaiM99jhNCfX32+dnx39OyZYGN44nFgD/8w+Dv\nO3Jk/PZ3QhQaykOPELlIR8y0qCXVlL1sc6yrqoBnn+W/ly1j+f+mTenl0if6nh5+mNG43/wC4IS0\nbdvkG0LnKqdfiGzJSS8X59xzykHPP+EoOdU2rfFkurA30UJronts2JDehg8dHRTd3bv5emSE6Yun\nT6ffYzzR9zQycuHkk0r/lHQXkYWYShShR4hMCluSRZfpRpjpRPaZPEk8+SSLfubN4+f27+fn2tv5\nmXRy6RN9T42NQb68J12rCVBnRJFf1G0xQqQTJQO5jS7TiewzeZJ48UWe41MWR0Y4cbz5ZnBOqh0J\nE31PCxYwvTLdHHJ1RhSFhCL0CJGu/53r6DLVyD7VJ4nw08M773Chtbqa75kxoh4ZYafGtjZmoKRS\nlZnoeyotBe6/f3xaYyrevqpDRSGhCD1CpOt/5yu6TOVJIv7pobmZTb+6u3muL+GfN4+/t2+nGKdS\nlRn/PQ0MMPL/5S/5/u/93uQLoR5Vh4pCQhF6xEjH/56u6DJRiuD27Xwv2ZNE+Omhu5vVmiMj9M59\nGX59fVABGovx2qn+7eFiKO/nNzWlnxmkzoiikJCgFzHT0Xc70QLo9u3j+6UnEkFvy/imWdXV3Gi6\noyMoz7/22qCqc3Q0s66GubCdVB0qCgUJehEwUSbLVESX4fsdPkxhjhfMQ4doayTDPz34plnV1fTN\n166lRQIEYg5k/mSRy5a3QuQbCXrESWUXnVQFPJUCmvj7Pf88e7XU1gYCnIpg+qeH997j5wYHGZmv\nWsVofOvW9HYTSoYWNUWUkKBHnGwshbCAl5WxNe7SpRPnjsff7+KLKejt7YGgpyKY/unh7bdpu8yf\nTzH3e3tu3BgsambzZKHt3kSUkKBHnEwthXCkPWsW8NRTFNIVK4BrrgnEOTwxdHQATz9NayQWo8/d\n1gb85jeMtNPt3eLb4Ppx1NYGWSS5akurRU0RJSToESfbZltDQ8ALL/B3LMbNlc+dY5vZhoZgYvAT\nQGUlBd1vRrFuHSeBzs7MBHM6BFeLmiIqSNAjTrbNtn79ay5Izp3L9MCREb5ubx9fyOMngDVrxu/y\ns2sXsHw5I+1MRVOCK0RqSNAjzmQRbrKFTh/Z+97j8+cDBw5QqCsraaEsXBhMDH4CKClhVN7ennzj\nCnUnFGJqkKAXAcki3IkyYHxk
X14ebKrc1BQU+jQ2jhfqsLXT1BQsXsZvXJGLFr9CiMSo9L+ImahJ\nlo/sV66kgI+OUuRvuIF7dcZbKKmWwOeixa8QIjES9CJmsl4ura3Agw9yAwm//2ey/jCp9pFRd0Ih\npg5ZLkVMqhkwqS5KpnKeCnmEmDoUoRcx+egUqO6EQkwdEvQiprWVTbJeeQV44gn+Xr9+ahcnM93i\nTggxObJc8ki+0/c6Otj5cPVq4MYbg06IixdPvahLwIXIPYrQ80QhbC6sjBMhooUEPU8Ugpgq40SI\naCFBzxOFIKbpbiothChs5KHniWzT93Lhv6t1rBDRQhF6nsgmfS9X/rsyToSIForQ80Q2bWFzsQ9m\neBy5EPB8Z+wIISToeSVTMS20fTDVcEuIwkCWywyk0BYzCyFjRwihCH3aieJiZqE9MQhRrChCn0ai\nuphZaE8MQhQritCnkUJczMwFhfbEIESxogh9GimEYqKpoNCeGIQoVhShTyNR7gVeSE8MQhQritCn\nEfUCF0JMJRkLupktNLNnzWy/mb1mZl/O5cCiiKwJIcRUko3lMgLgT5xzu8ysBsBOM3vGObc/R2OL\nJLImhBBTRcYRunPuqHNu19i/TwE4AKAlVwMTQgiRHjnx0M1sMYDVAF5K8N49ZrbDzHb09PTk4nZC\nCCESkLWgm9kcAD8E8EfOuf74951zjzrn1jrn1jY0NGR7OyGEEEnIStDNbBYo5pudc0/lZkhCCCEy\nIZssFwPwGIADzrlv5m5IQgghMiGbCH0dgN8F8DEz2z3284kcjUsIIUSaZJy26Jz7DQDL4ViEEEJk\ngSpFhRAiIkjQhRAiIkjQhRAiIkjQhRAiIkjQhRAiIkjQhRAiIkjQhRAiIkjQhRAiIkjQhRAiIkjQ\nhRAiIkjQhRAiIkjQhRAiIkjQhRAiIkjQhRAiIkjQhRAiIkjQhRAiIkjQhRAiIkjQhRAiIkjQhRAi\nIkjQhRAiIkjQhRAiIkjQhRAiIkjQhRAiIkjQhRAiIkjQhRAiIkjQhRAiIkjQhRAiIkjQhRAiIkjQ\nhRAiIkjQhRAiIkjQhRAiIkjQhRAiIkjQhRAiIkjQhRAiIkjQhRAiImQl6Gb2O2b2upkdMrO/yNWg\nhBBCpE/Ggm5mpQD+AcDHAVwB4LNmdkWuBiaEECI9sonQrwFwyDn3hnNuGMD3ANyRm2EJIYRIl7Is\nPtsC4N3Q604A18afZGb3ALhn7OVpM3s9i3vmgosAvJ/nMRQK+i4C9F0E6LsIKJTvYlEqJ2Uj6Cnh\nnHsUwKNTfZ9UMbMdzrm1+R5HIaDvIkDfRYC+i4CZ9l1kY7l0AVgYer1g7JgQQog8kI2gvwyg1cwu\nNbNyAJ8B8HRuhiWEECJdMrZcnHMjZvYlAFsAlAJ43Dn3Ws5GNnUUjP1TAOi7CNB3EaDvImBGfRfm\nnMv3GIQQQuQAVYoKIUREkKALIUREKGpBN7M/MTNnZhfleyz5wsy+YWbtZrbXzP6vmdXle0zTjVpY\nEDNbaGbPmtl+M3vNzL6c7zHlEzMrNbNXzOwn+R5LqhStoJvZQgAbAbyT77HkmWcAXOmcWwHgIIC/\nzPN4phW1sBjHCIA/cc5dAeA6AP+1iL8LAPgygAP5HkQ6FK2gA/hbAH8GoKhXhZ1zW51zI2MvXwTr\nCYoJtbAYwzl31Dm3a+zfp0Axa8nvqPKDmS0A8EkA38n3WNKhKAXdzO4A0OWc25PvsRQYdwP4eb4H\nMc0kamFRlCIWxswWA1gN4KX8jiRv/B0Y8I3meyDpMOWl//nCzLYBmJ/gra8A+G+g3VIUTPRdOOd+\nPHbOV8BH7s3TOTZReJjZHAA/BPBHzrn+fI9nujGzWwEcc87tNLOP5ns86RBZQXfObUh03MyuAnAp\ngD1mBtBi2GVm1zjn3pvGIU4byb4Lj5n9AYBbAax3xVeYoBYWIcxsFijmm51zT+V7PHliHYDbze
wT\nACoB1JrZE865z+d5XJNS9IVFZvYWgLXOuULoqDbtmNnvAPgmgBudcz35Hs90Y2Zl4GLwelDIXwZw\n1wypes4pxghnE4ATzrk/yvd4CoGxCP1+59yt+R5LKhSlhy7G8S0ANQCeMbPdZvaP+R7QdDK2IOxb\nWBwA8P1iFPMx1gH4XQAfG/t/YfdYlCpmCEUfoQshRFRQhC6EEBFBgi6EEBFBgi6EEBFBgi6EEBFB\ngi6EEBFBgi6EEBFBgi6EEBHh/wOSGrCTQqVtEgAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"tags": []
}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "gOOfu2-qKvPH",
"colab_type": "code",
"outputId": "730d03dc-d860-4b29-93b5-533fa940d506",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"# Batch of independent Bernouilli distributions\n",
"b3 = tfd.Bernoulli(probs=[.3, .5, .7])\n",
"b3"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<tfp.distributions.Bernoulli 'Bernoulli/' batch_shape=[3] event_shape=[] dtype=int32>"
]
},
"metadata": {
"tags": []
},
"execution_count": 15
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TzcwgjIVLNUA",
"colab_type": "text"
},
"source": [
"Batch shape [3] means the distribution object is a batch of three distributions and event shape [ ] means each distribution is scalar. "
]
},
{
"cell_type": "code",
"metadata": {
"id": "2oeUmnCoL66L",
"colab_type": "code",
"outputId": "f29ca200-0335-430e-f3d9-9f24a123879a",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"b3.sample()"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<tf.Tensor: id=460, shape=(3,), dtype=int32, numpy=array([0, 0, 1], dtype=int32)>"
]
},
"metadata": {
"tags": []
},
"execution_count": 16
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "SBXCpSsCMyWP",
"colab_type": "code",
"outputId": "2c5b3750-bd54-4d06-f3c0-c3f810ff7cae",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"b3.prob([1, 1, 0])"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<tf.Tensor: id=476, shape=(3,), dtype=float32, numpy=array([0.29999998, 0.5 , 0.29999998], dtype=float32)>"
]
},
"metadata": {
"tags": []
},
"execution_count": 17
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kvs823cVNaXW",
"colab_type": "text"
},
"source": [
"The function prob(v) returns a vector where the ith entry is the probability of the ith coin being v[i]. "
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MPfAVkh4p5YK",
"colab_type": "text"
},
"source": [
"## tfp.distributions.Independent\n",
"\n",
"In order to create a joint distribution you can use Independent function. The prob function applied to a vector for such a distribution will be a number unlike the b3 example above. The Independent function convert a batch of distributions into a single joint distribution of independent distributions. The batch shape becomes the event shape. "
]
},
{
"cell_type": "code",
"metadata": {
"id": "LOqWcHT0OuoB",
"colab_type": "code",
"outputId": "26edee47-ee76-41d3-df20-cb96801da604",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"b3_joint = tfd.Independent(b3, reinterpreted_batch_ndims=1)\n",
"b3_joint"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<tfp.distributions.Independent 'IndependentBernoulli/' batch_shape=[] event_shape=[3] dtype=int32>"
]
},
"metadata": {
"tags": []
},
"execution_count": 20
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "0gnU4S5OP9NH",
"colab_type": "code",
"outputId": "a77e60fd-7dd9-4a18-9266-e9183f611157",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"# Prob(1,1,1) = 0.3 x 0.5 x 0.7\n",
"b3_joint.prob([1,1,1])"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<tf.Tensor: id=516, shape=(), dtype=float32, numpy=0.105000004>"
]
},
"metadata": {
"tags": []
},
"execution_count": 21
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "n0ppRSAUKDgm",
"colab_type": "code",
"outputId": "b54d3cec-7c0f-4e50-c20f-d0ed95bcf734",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"# Make independent distribution from a 2-batch Normal.\n",
"n = tfd.Normal(loc=[-1., 1], scale=[0.1, 0.5])\n",
"ind = tfd.Independent(\n",
" distribution=tfd.Normal(loc=[-1., 1], scale=[0.1, 0.5]),\n",
" reinterpreted_batch_ndims=1)\n",
"ind"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<tfp.distributions.Independent 'IndependentNormal/' batch_shape=[] event_shape=[2] dtype=float32>"
]
},
"metadata": {
"tags": []
},
"execution_count": 22
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "25jav7oUKfE9",
"colab_type": "code",
"outputId": "120dc37a-9955-48ef-87de-2e3df4f847f6",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"n"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<tfp.distributions.Normal 'Normal/' batch_shape=[2] event_shape=[] dtype=float32>"
]
},
"metadata": {
"tags": []
},
"execution_count": 23
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "hn3bOgaZKgff",
"colab_type": "code",
"outputId": "3b256463-a20e-47aa-a45b-9d00838bc047",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"n.sample()"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<tf.Tensor: id=545, shape=(2,), dtype=float32, numpy=array([-0.9780394, 0.8697878], dtype=float32)>"
]
},
"metadata": {
"tags": []
},
"execution_count": 24
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "B2stBhNFKR8F",
"colab_type": "code",
"outputId": "8a0577c3-886a-455c-a0df-99d9bd9e5b8c",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"ind.sample()"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<tf.Tensor: id=581, shape=(2,), dtype=float32, numpy=array([-0.9398609, 1.255456 ], dtype=float32)>"
]
},
"metadata": {
"tags": []
},
"execution_count": 25
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "Nl9EteMfKR4a",
"colab_type": "code",
"outputId": "941feb40-d33e-4c21-a090-240364e84012",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"# Make independent distribution from a 2-batch bivariate Normal.\n",
"ind = tfd.Independent(\n",
" distribution=tfd.MultivariateNormalDiag(\n",
" loc=[[-1., 1], [1, -1]],\n",
" scale_identity_multiplier=[1., 0.5]),\n",
" reinterpreted_batch_ndims=1)\n",
"\n",
"ind"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<tfp.distributions.Independent 'IndependentMultivariateNormalDiag/' batch_shape=[] event_shape=[2, 2] dtype=float32>"
]
},
"metadata": {
"tags": []
},
"execution_count": 26
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "OrMkbXdyKRxM",
"colab_type": "code",
"outputId": "001c382a-d5c4-46c0-a68a-279aa9f7b9cb",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 68
}
},
"source": [
"ind.sample()"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<tf.Tensor: id=700, shape=(2, 2), dtype=float32, numpy=\n",
"array([[-2.382836 , -0.77324414],\n",
" [ 1.5233893 , -1.3743069 ]], dtype=float32)>"
]
},
"metadata": {
"tags": []
},
"execution_count": 27
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "E76sKBRhUqfJ",
"colab_type": "text"
},
"source": [
"## Batches of multivariate distributions"
]
},
{
"cell_type": "code",
"metadata": {
"id": "hVImLaBjQI9C",
"colab_type": "code",
"outputId": "a8cb9a06-40a6-4012-9cff-95372e4e071f",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"nd_batch = tfd.MultivariateNormalFullCovariance(\n",
" loc = [[0., 0.], [1., 1.], [2., 2.]],\n",
" covariance_matrix = [[[1., .1], [.1, 1.]], \n",
" [[1., .3], [.3, 1.]],\n",
" [[1., .5], [.5, 1.]]])\n",
"nd_batch"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<tfp.distributions.MultivariateNormalFullCovariance 'MultivariateNormalFullCovariance/' batch_shape=[3] event_shape=[2] dtype=float32>"
]
},
"metadata": {
"tags": []
},
"execution_count": 28
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wttOT_o0U7Ps",
"colab_type": "text"
},
"source": [
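"`nd_batch` is a batch of 3 independent bivariate normal distributions (`batch_shape=[3]`, `event_shape=[2]`). Within each distribution the two components are correlated (each covariance matrix has non-zero off-diagonal entries), so they are not independent of each other."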
]
},
{
"cell_type": "code",
"metadata": {
"id": "eLoUi6KNVRiq",
"colab_type": "code",
"outputId": "ee0d5a03-efcd-408c-b6f4-c4312d99f50b",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 289
}
},
"source": [
"nd_batch.sample(4)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<tf.Tensor: id=800, shape=(4, 3, 2), dtype=float32, numpy=\n",
"array([[[-4.9703220e-01, 1.1425320e+00],\n",
" [ 2.5317080e+00, 1.3575239e+00],\n",
" [ 2.6210568e+00, 2.3021798e+00]],\n",
"\n",
" [[-1.1117331e+00, 9.0224606e-01],\n",
" [ 1.2685717e+00, 2.8281617e-01],\n",
" [ 3.5086212e+00, 1.8443767e+00]],\n",
"\n",
" [[ 8.9322239e-01, -7.6503921e-01],\n",
" [ 9.1081977e-01, 1.4396188e+00],\n",
" [ 1.3879120e-01, 6.2994027e-01]],\n",
"\n",
" [[ 6.8450880e-01, 8.3123893e-04],\n",
" [ 1.4884772e+00, 1.5386418e+00],\n",
" [ 1.6536089e+00, 1.0512571e+00]]], dtype=float32)>"
]
},
"metadata": {
"tags": []
},
"execution_count": 29
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BryrnlWMIhlt",
"colab_type": "text"
},
"source": [
"## Broadcasting"
]
},
{
"cell_type": "code",
"metadata": {
"id": "iERqIocoIEkX",
"colab_type": "code",
"outputId": "45b99ddc-62ac-41b2-ca06-6d1b53ba0fb8",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"n = tfd.Normal(loc=0., scale=1.)\n",
"n"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<tfp.distributions.Normal 'Normal/' batch_shape=[] event_shape=[] dtype=float32>"
]
},
"metadata": {
"tags": []
},
"execution_count": 30
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "oTSzey2BImS3",
"colab_type": "code",
"outputId": "0afddd24-cd6f-476c-b5fd-22375a59cf1a",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"n.log_prob(0.)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<tf.Tensor: id=818, shape=(), dtype=float32, numpy=-0.9189385>"
]
},
"metadata": {
"tags": []
},
"execution_count": 31
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "EGhcrNiVInMD",
"colab_type": "code",
"outputId": "f05d10ba-5db6-41d2-c487-2277fb569dfe",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"n.log_prob([0.])"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<tf.Tensor: id=829, shape=(1,), dtype=float32, numpy=array([-0.9189385], dtype=float32)>"
]
},
"metadata": {
"tags": []
},
"execution_count": 32
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "Z1fUyS3rInUB",
"colab_type": "code",
"outputId": "d6f3f4a1-d45e-4f29-e909-b0c4d5e074ff",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 68
}
},
"source": [
"n.log_prob([[0., 1.], [-1., 2.]])"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<tf.Tensor: id=840, shape=(2, 2), dtype=float32, numpy=\n",
"array([[-0.9189385, -1.4189385],\n",
" [-1.4189385, -2.9189386]], dtype=float32)>"
]
},
"metadata": {
"tags": []
},
"execution_count": 33
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "b8Dphj03InXX",
"colab_type": "code",
"outputId": "d87823f1-7638-4bd3-9c79-0faf8dcf6348",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"nd = tfd.MultivariateNormalDiag(loc=[0., 1.], scale_diag=[1., 1.])\n",
"nd"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<tfp.distributions.MultivariateNormalDiag 'MultivariateNormalDiag/' batch_shape=[] event_shape=[2] dtype=float32>"
]
},
"metadata": {
"tags": []
},
"execution_count": 34
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "BSzqUiKLIndC",
"colab_type": "code",
"outputId": "c0d0bad3-3226-4781-be3e-aa6727bb7db6",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"nd.log_prob([0., 0.])"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<tf.Tensor: id=931, shape=(), dtype=float32, numpy=-2.337877>"
]
},
"metadata": {
"tags": []
},
"execution_count": 35
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "GGeQwQkLIna4",
"colab_type": "code",
"outputId": "5a651846-5522-4b88-a177-9bb52e8fdced",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"nd.log_prob([[0., 0.],\n",
" [1., 1.],\n",
" [2., 2.]])"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<tf.Tensor: id=982, shape=(3,), dtype=float32, numpy=array([-2.337877 , -2.337877 , -4.3378773], dtype=float32)>"
]
},
"metadata": {
"tags": []
},
"execution_count": 36
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DCh8mkk2iUIj",
"colab_type": "text"
},
"source": [
"# Distribution Lambda\n",
"\n",
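"`tfp.layers.DistributionLambda` is a Keras layer that creates a distribution from the output of the previous layer: it applies the function it was given to its input tensor and returns the resulting distribution."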
]
},
{
"cell_type": "code",
"metadata": {
"id": "G-LSLeUuoLuT",
"colab_type": "code",
"outputId": "cfd6e53c-8ecd-4feb-d442-4ad0aa19d561",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"layer = tfp.layers.DistributionLambda(lambda t: tfd.Normal(t, 1.))\n",
"distribution = layer(2.)\n",
"assert isinstance(distribution, tfd.Normal)\n",
"distribution.loc\n",
"# ==> 2.\n",
"distribution.stddev()"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<tf.Tensor: id=1013, shape=(), dtype=float32, numpy=1.0>"
]
},
"metadata": {
"tags": []
},
"execution_count": 37
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "JU5GnItuiK1C",
"colab_type": "code",
"colab": {}
},
"source": [
"model = tf.keras.Sequential([\n",
" tf.keras.layers.Dense(2, input_dim=2),\n",
" tfp.layers.DistributionLambda(\n",
" make_distribution_fn=lambda t: tfp.distributions.Normal(\n",
" loc=t[..., 0], scale=tf.exp(t[..., 1]))\n",
" )\n",
"])"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "-897JKTvifhX",
"colab_type": "code",
"outputId": "cc97598f-87e0-4016-f3d2-f2cfbdd6bb73",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"x = np.array([[0,1], [2,3]])\n",
"model(x)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<tfp.distributions.Normal 'sequential/distribution_lambda_1/Normal/' batch_shape=[2] event_shape=[] dtype=float32>"
]
},
"metadata": {
"tags": []
},
"execution_count": 39
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "ODOkx1N_iiWR",
"colab_type": "code",
"outputId": "89f2b4e7-27d3-4b88-aba0-dbfd12837876",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 119
}
},
"source": [
"model(x).sample(5)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<tf.Tensor: id=1176, shape=(5, 2), dtype=float32, numpy=\n",
"array([[ 1.2495792 , -7.5877295 ],\n",
" [ 2.3519795 , -26.219402 ],\n",
" [ 1.673234 , -7.4577127 ],\n",
" [ -5.0821986 , 17.405872 ],\n",
" [ 0.36446828, -9.862875 ]], dtype=float32)>"
]
},
"metadata": {
"tags": []
},
"execution_count": 40
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "y1xFwwxZjlyO",
"colab_type": "code",
"outputId": "d272d62c-9752-4a04-a53a-592e4616e130",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 221
}
},
"source": [
"model.summary()"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"Model: \"sequential\"\n",
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"dense (Dense) (None, 2) 6 \n",
"_________________________________________________________________\n",
"distribution_lambda_1 (Distr ((None,), (None,)) 0 \n",
"=================================================================\n",
"Total params: 6\n",
"Trainable params: 6\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CzeZ_QXrnbZF",
"colab_type": "text"
},
"source": [
"# Variable Layer\n",
"\n",
"A `VariableLayer` is a layer whose output is a trainable variable that does not depend on the input. Think of it as a usual Dense layer whose weights are forced to be zero, so that the (trainable) bias is the output no matter what the input is."
]
},
{
"cell_type": "code",
"metadata": {
"id": "YUVzDmDZo2rh",
"colab_type": "code",
"colab": {}
},
"source": [
"vl = tfp.layers.VariableLayer(\n",
" shape=[3, 4, 2],\n",
" dtype=tf.float64\n",
" )"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "BOe4KNn5o-eb",
"colab_type": "code",
"outputId": "3d60b1b6-9cdd-414b-f48d-8cad2085a4b2",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 272
}
},
"source": [
"vl(4)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<tf.Tensor: id=1189, shape=(3, 4, 2), dtype=float64, numpy=\n",
"array([[[0., 0.],\n",
" [0., 0.],\n",
" [0., 0.],\n",
" [0., 0.]],\n",
"\n",
" [[0., 0.],\n",
" [0., 0.],\n",
" [0., 0.],\n",
" [0., 0.]],\n",
"\n",
" [[0., 0.],\n",
" [0., 0.],\n",
" [0., 0.],\n",
" [0., 0.]]])>"
]
},
"metadata": {
"tags": []
},
"execution_count": 43
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "G-hOFPvdpEbk",
"colab_type": "code",
"outputId": "ac1baf84-29f6-44fe-f854-f070d8c74c49",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"vl = tfp.layers.VariableLayer(\n",
" shape=[3, 4, 4],\n",
" dtype=tf.float64,\n",
" initializer=tfp.layers.BlockwiseInitializer([\n",
" 'zeros',\n",
" tf.keras.initializers.Constant(np.log(np.expm1(1.))),\n",
" 'ones'\n",
" ], sizes=[1, 2, 1])\n",
" )\n",
"vl"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<tensorflow_probability.python.layers.variable_input.VariableLayer at 0x7f8347c6c710>"
]
},
"metadata": {
"tags": []
},
"execution_count": 44
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "hvgag8vRpO4b",
"colab_type": "code",
"outputId": "551fcce0-33b8-47e4-dfb4-784f41856615",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 272
}
},
"source": [
"vl(1)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<tf.Tensor: id=1211, shape=(3, 4, 4), dtype=float64, numpy=\n",
"array([[[0. , 0.54132485, 0.54132485, 1. ],\n",
" [0. , 0.54132485, 0.54132485, 1. ],\n",
" [0. , 0.54132485, 0.54132485, 1. ],\n",
" [0. , 0.54132485, 0.54132485, 1. ]],\n",
"\n",
" [[0. , 0.54132485, 0.54132485, 1. ],\n",
" [0. , 0.54132485, 0.54132485, 1. ],\n",
" [0. , 0.54132485, 0.54132485, 1. ],\n",
" [0. , 0.54132485, 0.54132485, 1. ]],\n",
"\n",
" [[0. , 0.54132485, 0.54132485, 1. ],\n",
" [0. , 0.54132485, 0.54132485, 1. ],\n",
" [0. , 0.54132485, 0.54132485, 1. ],\n",
" [0. , 0.54132485, 0.54132485, 1. ]]])>"
]
},
"metadata": {
"tags": []
},
"execution_count": 45
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "IHPYpyKIle8q",
"colab_type": "code",
"colab": {}
},
"source": [
"trainable_normal = tf.keras.models.Sequential([\n",
" tfp.layers.VariableLayer(\n",
" shape=[3, 4, 2],\n",
" dtype=tf.float64,\n",
" initializer=tfp.layers.BlockwiseInitializer([\n",
" 'zeros',\n",
" tf.keras.initializers.Constant(np.log(np.expm1(1.))),\n",
" ], sizes=[1, 1])),\n",
" tfp.layers.DistributionLambda(lambda t: tfp.distributions.Independent(\n",
" tfp.distributions.Normal(loc=t[..., 0], scale=tf.math.softplus(t[..., 1])),\n",
" reinterpreted_batch_ndims=1)),\n",
"])\n",
"\n",
"# The output of the model is a distribution and so rv_x below will be a distribution\n",
"# and rv_x.log_prob(x) makes sense.\n",
"\n",
"negloglik = lambda x, rv_x: -rv_x.log_prob(x)\n",
"trainable_normal.compile(optimizer='adam', loss=negloglik)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "3kLrnFr_RHcw",
"colab_type": "code",
"outputId": "ab7a72d2-4033-4c3f-bcd3-acd427fb02a9",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"x = trainable_normal(0.) # `0.` ignored; like conditioning on emptyset.\n",
"x"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<tfp.distributions.Independent 'Independentsequential_3/distribution_lambda_4/Normal/' batch_shape=[3] event_shape=[4] dtype=float64>"
]
},
"metadata": {
"tags": []
},
"execution_count": 54
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "ObcWuDmYRObs",
"colab_type": "code",
"outputId": "8011f7dc-349b-49d6-941f-23b6c28d1dbe",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"x.dtype"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tf.float64"
]
},
"metadata": {
"tags": []
},
"execution_count": 48
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "31T3hz40RUQ3",
"colab_type": "code",
"outputId": "ba35b397-2613-4da1-de62-810fdb4c0d76",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"x.batch_shape"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"TensorShape([3])"
]
},
"metadata": {
"tags": []
},
"execution_count": 50
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "cYK7TH5NRWhh",
"colab_type": "code",
"outputId": "362b10d6-e0c6-47b6-a4a4-e0bfdfc2b687",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"x.event_shape"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"TensorShape([4])"
]
},
"metadata": {
"tags": []
},
"execution_count": 105
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "O71JWgrtRfBy",
"colab_type": "code",
"outputId": "897427cb-df68-47e2-c996-3746c84a1144",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 85
}
},
"source": [
"x.mean()"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<tf.Tensor: id=1420, shape=(3, 4), dtype=float64, numpy=\n",
"array([[0., 0., 0., 0.],\n",
" [0., 0., 0., 0.],\n",
" [0., 0., 0., 0.]])>"
]
},
"metadata": {
"tags": []
},
"execution_count": 55
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "2PUq0_FKRhKj",
"colab_type": "code",
"outputId": "f71d2a6b-b0ed-4678-cb85-33bbadd6b903",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 85
}
},
"source": [
"x.variance()"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<tf.Tensor: id=1426, shape=(3, 4), dtype=float64, numpy=\n",
"array([[1., 1., 1., 1.],\n",
" [1., 1., 1., 1.],\n",
" [1., 1., 1., 1.]])>"
]
},
"metadata": {
"tags": []
},
"execution_count": 56
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "QbwBr_G3RoPK",
"colab_type": "code",
"colab": {}
},
"source": [
"m = np.zeros((3,4)) + 2\n",
"s = np.zeros((3,4)) + 0.5"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "WuGa6GF_S4R0",
"colab_type": "code",
"colab": {}
},
"source": [
"gen = tfp.distributions.Independent(\n",
" tfp.distributions.Normal(loc=m, scale=s),\n",
" reinterpreted_batch_ndims=1)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "EVWrrMPsTGwc",
"colab_type": "code",
"colab": {}
},
"source": [
"x = gen.sample(100).numpy()\n",
"dataset = tf.data.Dataset.from_tensor_slices((x,x))"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "jfLnMXBNTJXL",
"colab_type": "code",
"colab": {}
},
"source": [
"dataset = dataset.shuffle(50)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "j5CyzevpTkYh",
"colab_type": "code",
"outputId": "e3d0e99e-6ed9-4a67-c14d-e4fdc9adaca5",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
}
},
"source": [
"trainable_normal.fit(dataset, epochs = 50)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"Epoch 1/50\n",
"100/100 [==============================] - 1s 5ms/step - loss: 11.3457\n",
"Epoch 2/50\n",
"100/100 [==============================] - 0s 2ms/step - loss: 10.2659\n",
"Epoch 3/50\n",
"100/100 [==============================] - 0s 1ms/step - loss: 9.4590\n",
"Epoch 4/50\n",
"100/100 [==============================] - 0s 2ms/step - loss: 8.8378\n",
"Epoch 5/50\n",
"100/100 [==============================] - 0s 1ms/step - loss: 8.3467\n",
"Epoch 6/50\n",
"100/100 [==============================] - 0s 2ms/step - loss: 7.9486\n",
"Epoch 7/50\n",
"100/100 [==============================] - 0s 1ms/step - loss: 7.6181\n",
"Epoch 8/50\n",
"100/100 [==============================] - 0s 1ms/step - loss: 7.3374\n",
"Epoch 9/50\n",
"100/100 [==============================] - 0s 1ms/step - loss: 7.0935\n",
"Epoch 10/50\n",
"100/100 [==============================] - 0s 1ms/step - loss: 6.8768\n",
"Epoch 11/50\n",
"100/100 [==============================] - 0s 1ms/step - loss: 6.6801\n",
"Epoch 12/50\n",
"100/100 [==============================] - 0s 1ms/step - loss: 6.4977\n",
"Epoch 13/50\n",
"100/100 [==============================] - 0s 2ms/step - loss: 6.3252\n",
"Epoch 14/50\n",
"100/100 [==============================] - 0s 2ms/step - loss: 6.1592\n",
"Epoch 15/50\n",
"100/100 [==============================] - 0s 1ms/step - loss: 5.9967\n",
"Epoch 16/50\n",
"100/100 [==============================] - 0s 1ms/step - loss: 5.8353\n",
"Epoch 17/50\n",
"100/100 [==============================] - 0s 1ms/step - loss: 5.6733\n",
"Epoch 18/50\n",
"100/100 [==============================] - 0s 1ms/step - loss: 5.5091\n",
"Epoch 19/50\n",
"100/100 [==============================] - 0s 1ms/step - loss: 5.3416\n",
"Epoch 20/50\n",
"100/100 [==============================] - 0s 1ms/step - loss: 5.1701\n",
"Epoch 21/50\n",
"100/100 [==============================] - 0s 1ms/step - loss: 4.9943\n",
"Epoch 22/50\n",
"100/100 [==============================] - 0s 1ms/step - loss: 4.8144\n",
"Epoch 23/50\n",
"100/100 [==============================] - 0s 1ms/step - loss: 4.6312\n",
"Epoch 24/50\n",
"100/100 [==============================] - 0s 1ms/step - loss: 4.4461\n",
"Epoch 25/50\n",
"100/100 [==============================] - 0s 1ms/step - loss: 4.2611\n",
"Epoch 26/50\n",
"100/100 [==============================] - 0s 1ms/step - loss: 4.0788\n",
"Epoch 27/50\n",
"100/100 [==============================] - 0s 2ms/step - loss: 3.9024\n",
"Epoch 28/50\n",
"100/100 [==============================] - 0s 1ms/step - loss: 3.7351\n",
"Epoch 29/50\n",
"100/100 [==============================] - 0s 1ms/step - loss: 3.5803\n",
"Epoch 30/50\n",
"100/100 [==============================] - 0s 1ms/step - loss: 3.4405\n",
"Epoch 31/50\n",
"100/100 [==============================] - 0s 1ms/step - loss: 3.3177\n",
"Epoch 32/50\n",
"100/100 [==============================] - 0s 1ms/step - loss: 3.2126\n",
"Epoch 33/50\n",
"100/100 [==============================] - 0s 1ms/step - loss: 3.1252\n",
"Epoch 34/50\n",
"100/100 [==============================] - 0s 1ms/step - loss: 3.0542\n",
"Epoch 35/50\n",
"100/100 [==============================] - 0s 2ms/step - loss: 2.9980\n",
"Epoch 36/50\n",
"100/100 [==============================] - 0s 1ms/step - loss: 2.9543\n",
"Epoch 37/50\n",
"100/100 [==============================] - 0s 2ms/step - loss: 2.9210\n",
"Epoch 38/50\n",
"100/100 [==============================] - 0s 1ms/step - loss: 2.8960\n",
"Epoch 39/50\n",
"100/100 [==============================] - 0s 1ms/step - loss: 2.8774\n",
"Epoch 40/50\n",
"100/100 [==============================] - 0s 1ms/step - loss: 2.8636\n",
"Epoch 41/50\n",
"100/100 [==============================] - 0s 2ms/step - loss: 2.8535\n",
"Epoch 42/50\n",
"100/100 [==============================] - 0s 1ms/step - loss: 2.8461\n",
"Epoch 43/50\n",
"100/100 [==============================] - 0s 2ms/step - loss: 2.8407\n",
"Epoch 44/50\n",
"100/100 [==============================] - 0s 2ms/step - loss: 2.8367\n",
"Epoch 45/50\n",
"100/100 [==============================] - 0s 2ms/step - loss: 2.8337\n",
"Epoch 46/50\n",
"100/100 [==============================] - 0s 1ms/step - loss: 2.8315\n",
"Epoch 47/50\n",
"100/100 [==============================] - 0s 1ms/step - loss: 2.8298\n",
"Epoch 48/50\n",
"100/100 [==============================] - 0s 2ms/step - loss: 2.8286\n",
"Epoch 49/50\n",
"100/100 [==============================] - 0s 1ms/step - loss: 2.8277\n",
"Epoch 50/50\n",
"100/100 [==============================] - 0s 2ms/step - loss: 2.8270\n"
],
"name": "stdout"
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<tensorflow.python.keras.callbacks.History at 0x7f834749c5c0>"
]
},
"metadata": {
"tags": []
},
"execution_count": 61
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "3mBUt-55Tvsf",
"colab_type": "code",
"outputId": "40de7c6c-e22f-4a08-a2a0-abe4861a1de7",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 85
}
},
"source": [
"trainable_normal(0.).mean()"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<tf.Tensor: id=32211, shape=(3, 4), dtype=float64, numpy=\n",
"array([[2.01612234, 1.93767194, 1.97366429, 1.97893472],\n",
" [1.93390999, 1.97119836, 2.02611989, 1.95313365],\n",
" [2.06840718, 1.92072485, 1.99626747, 1.97700717]])>"
]
},
"metadata": {
"tags": []
},
"execution_count": 62
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "6ZZFrFHsUqrh",
"colab_type": "code",
"outputId": "250a6a2a-2d5e-4fdd-8ee2-b694c384eccd",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 85
}
},
"source": [
"trainable_normal(0.).stddev()"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<tf.Tensor: id=32265, shape=(3, 4), dtype=float64, numpy=\n",
"array([[0.48401529, 0.5119396 , 0.51694257, 0.40709828],\n",
" [0.45146548, 0.49992078, 0.51552555, 0.56620307],\n",
" [0.51147882, 0.51625875, 0.51407853, 0.53042485]])>"
]
},
"metadata": {
"tags": []
},
"execution_count": 63
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "NL5_9zm955DP",
"colab_type": "code",
"colab": {}
},
"source": [
""
],
"execution_count": 0,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment