Last active
November 8, 2022 18:37
-
-
Save ricardoV94/0cf8fd0f69a09d7eff0a5b41cb111965 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "id": "7205a740", | |
| "metadata": {}, | |
| "source": [ | |
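| "# Marginalizing discrete RVs\n", | |
| "\n", | |
| "A discrete RV $x$ with a finite domain can be summed out of a model's joint log-probability by evaluating the joint logp at every value $k$ that $x$ can take and combining the results with `logsumexp`: $\\log p(\\theta) = \\log \\sum_k \\exp\\left(\\log p(\\theta, x = k)\\right)$, where $\\theta$ stands for the remaining variables." | |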
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "id": "6a238994", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import aeppl\n", | |
| "import aesara\n", | |
| "import aesara.tensor as at\n", | |
| "from aesara.graph import FunctionGraph\n", | |
| "from aesara.compile.builders import OpFromGraph\n", | |
| "import numpy as np\n", | |
| "\n", | |
| "import pymc as pm" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "e05a3823", | |
| "metadata": {}, | |
| "source": [ | |
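| "## Marginalizing a single RV\n", | |
| "\n", | |
| "Here `x` is a categorical index choosing between two Normal means; summing it out should reproduce the log-probability of an equivalent `pm.NormalMixture` model." | |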
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "id": "cf25b739", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "with pm.Model() as m:\n", | |
| " p = pm.Dirichlet(\"p\", [1, 1])\n", | |
| " x = pm.Categorical(\"x\", p=p)\n", | |
| " y = pm.Normal(\"y\", pm.math.stack([-1, 1])[x], 1, observed=1) " | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "id": "3f94e5a6", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "p_vv = m.rvs_to_values[p]\n", | |
| "x_vv = m.rvs_to_values[x]\n", | |
| "logp = m.logp()\n", | |
| "logp_op = OpFromGraph([p_vv, x_vv], [logp], inline=True)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "id": "d1a4e4f7", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "OpFromGraph{inline=True}.0" | |
| ] | |
| }, | |
| "execution_count": 4, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "logp_op(p_vv, x_vv)  # applying the Op returns a symbolic variable, not a value" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "id": "6f626312", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "{'p_simplex__': array([0.]), 'x': array(0)}" | |
| ] | |
| }, | |
| "execution_count": 5, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "m.initial_point()  # keys are value-variable names; `p` appears on its transformed `p_simplex__` scale" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "id": "872648d2", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(array(-4.30523289), array(-2.30523289))" | |
| ] | |
| }, | |
| "execution_count": 6, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "logp_op(np.array([0]), 0).eval(), logp_op(np.array([0]), 1).eval()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "id": "7367bbe3", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "x_domain = range(2) # Possible values of the categorical\n", | |
| "marginal_logp = at.logsumexp([logp_op(p_vv, x_vv_const) for x_vv_const in x_domain])" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "id": "7df0efd4", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array(-2.17830488)" | |
| ] | |
| }, | |
| "execution_count": 8, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "marginal_logp.eval({p_vv: np.array([0])})  # same point as m.initial_point() above" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "id": "35e97418", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array(-2.17830488)" | |
| ] | |
| }, | |
| "execution_count": 9, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "with pm.Model() as m_ref:\n", | |
| " p = pm.Dirichlet(\"p\", [1, 1])\n", | |
| " y = pm.NormalMixture(\"y\", w=p, mu=[-1, 1], sigma=1, observed=1) \n", | |
| "m_ref.compile_logp()({\"p_simplex__\": np.array([0])})" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "id": "c4846826", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "f = aesara.function([p_vv], marginal_logp)\n", | |
| "# aesara.dprint(f)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 11, | |
| "id": "c5f15010", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "52" | |
| ] | |
| }, | |
| "execution_count": 11, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "len(f.maker.fgraph.apply_nodes)  # number of Apply nodes in the compiled function's graph" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "37ef8d2b", | |
| "metadata": {}, | |
| "source": [ | |
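| "## Marginalizing multiple RVs\n", | |
| "\n", | |
| "The same trick is applied repeatedly: each categorical is summed out in turn, starting from the last one in topological order, and the result is compared with an equivalent model built from nested `pm.NormalMixture`s." | |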
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 12, | |
| "id": "81d63419", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def explicit_mixture(name, categorical_idx, components):\n", | |
| " return pm.Normal(name, pm.math.stack(components)[categorical_idx], 1)\n", | |
| " \n", | |
| "with pm.Model() as m:\n", | |
| " p1 = pm.Dirichlet(\"p1\", [1, 1])\n", | |
| " mix_comp1 = pm.Categorical(\"mix_comp1\", p=p1) \n", | |
| " y1 = explicit_mixture(\"y1\", mix_comp1, [-1, 1])\n", | |
| " \n", | |
| " p2 = pm.Dirichlet(\"p2\", [1, 1])\n", | |
| " mix_comp2 = pm.Categorical(\"mix_comp2\", p=p2) \n", | |
| " y2 = explicit_mixture(\"y2\", mix_comp2, [-2, 2])\n", | |
| " \n", | |
| " p3 = pm.Dirichlet(\"p3\", [1, 1])\n", | |
| " mix_comp3 = pm.Categorical(\"mix_comp3\", p=p3)\n", | |
| " y3 = explicit_mixture(\"y3\", mix_comp3, [y1, y2])\n", | |
| " \n", | |
| " pm.Normal(\"llike\", y3, 1, observed=9)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 13, | |
| "id": "bf922bc3", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "mix_comp3\n", | |
| "mix_comp2\n", | |
| "mix_comp1\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "logp_graph = m.logp()\n", | |
| "rvs = list(m.free_RVs)\n", | |
| "marginalize_rvs = {mix_comp1, mix_comp2, mix_comp3}\n", | |
| "fg = FunctionGraph(outputs=rvs, clone=False)\n", | |
| "order = fg.toposort()\n", | |
| "for rv in sorted(marginalize_rvs, key=lambda x: order.index(x.owner), reverse=True):\n", | |
| " print(rv)\n", | |
| " rvs.remove(rv)\n", | |
| " vv = m.rvs_to_values[rv]\n", | |
| " vvs = [m.rvs_to_values[rv] for rv in rvs]\n", | |
| " logp_op = OpFromGraph([vv, *vvs], [logp_graph], inline=True)\n", | |
| " rv_domain = range(2) # Hard-coded\n", | |
| " logp_graph = at.logsumexp([logp_op(vv_const, *vvs) for vv_const in rv_domain])" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 14, | |
| "id": "5c59cafa", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "{'p1_simplex__': array([0.]),\n", | |
| " 'y1': array(-1.),\n", | |
| " 'p2_simplex__': array([0.]),\n", | |
| " 'y2': array(-2.),\n", | |
| " 'p3_simplex__': array([0.]),\n", | |
| " 'y3': array(-1.)}" | |
| ] | |
| }, | |
| "execution_count": 14, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "ip = m.initial_point()\n", | |
| "ip.pop(\"mix_comp3\", None)\n", | |
| "ip.pop(\"mix_comp2\", None)\n", | |
| "ip.pop(\"mix_comp1\", None)\n", | |
| "ip" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 15, | |
| "id": "8a2f1de8", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array(-57.23329681)" | |
| ] | |
| }, | |
| "execution_count": 15, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# Compile the marginalized logp; `f` is called with a point dict such as `ip`\n", | |
| "f = m.compile_fn(logp_graph)\n", | |
| "f(ip)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 16, | |
| "id": "2c8ba2ac", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "193" | |
| ] | |
| }, | |
| "execution_count": 16, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "len(f.f.maker.fgraph.apply_nodes)  # size of the compiled graph wrapped by `f`" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 17, | |
| "id": "c5ef7a5c", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array(-57.23329681)" | |
| ] | |
| }, | |
| "execution_count": 17, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "with pm.Model() as m_ref:\n", | |
| " p1 = pm.Dirichlet(\"p1\", [1, 1])\n", | |
| " y1 = pm.NormalMixture(\"y1\", p1, [-1, 1])\n", | |
| " \n", | |
| " p2 = pm.Dirichlet(\"p2\", [1, 1])\n", | |
| " y2 = pm.NormalMixture(\"y2\", p2, [-2, 2])\n", | |
| " \n", | |
| " p3 = pm.Dirichlet(\"p3\", [1, 1])\n", | |
| " y3 = pm.NormalMixture(\"y3\", p3, [y1, y2])\n", | |
| " \n", | |
| " pm.Normal(\"llike\", y3, 1, observed=9)\n", | |
| " \n", | |
| "m_ref.compile_logp()(ip)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "98072f1c", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| } | |
| ], | |
| "metadata": { | |
| "hide_input": false, | |
| "kernelspec": { | |
| "display_name": "pymc", | |
| "language": "python", | |
| "name": "pymc" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.10.4" | |
| }, | |
| "toc": { | |
| "base_numbering": 1, | |
| "nav_menu": {}, | |
| "number_sections": true, | |
| "sideBar": true, | |
| "skip_h1_title": false, | |
| "title_cell": "Table of Contents", | |
| "title_sidebar": "Contents", | |
| "toc_cell": false, | |
| "toc_position": {}, | |
| "toc_section_display": true, | |
| "toc_window_display": false | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 5 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment