Skip to content

Instantly share code, notes, and snippets.

@smsharma
Created June 22, 2023 07:51
Show Gist options
  • Save smsharma/d853c86f3954a893727f91083598f14c to your computer and use it in GitHub Desktop.
Save smsharma/d853c86f3954a893727f91083598f14c to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import numpyro\n",
"import numpyro.distributions as dist\n",
"from numpyro.infer import MCMC, NUTS, Predictive\n",
"\n",
"numpyro.set_host_device_count(4)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Simulate mock data"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x15f67a0a0>"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"n_xy = 64  # Number of grid points in each direction\n",
"\n",
"# Deterministic JAX simulator: places point sources at positions (x, y) with fluxes s\n",
"# on an n_xy x n_xy grid spanning [-1, 1] x [-1, 1], on top of a flat background.\n",
"@jax.jit\n",
"def simulate(x, y, s, psf_std=0.1, mu_bkg=1.):\n",
"    \"\"\"Simulate the expected-counts map for a set of point sources.\n",
"\n",
"    Args:\n",
"        x, y: 1-D arrays of source positions; must have the same length as s.\n",
"        s: 1-D array of source fluxes (the value of each profile at its centre).\n",
"        psf_std: PSF width. The profile is s * exp(-r**2 / psf_std**2), i.e. a\n",
"            Gaussian whose standard deviation is psf_std / sqrt(2).\n",
"        mu_bkg: Constant background level added to every pixel.\n",
"\n",
"    Returns:\n",
"        (n_xy, n_xy) array of expected counts; rows index y, columns index x.\n",
"    \"\"\"\n",
"    grid = jnp.linspace(-1, 1, n_xy)\n",
"    xx, yy = jnp.meshgrid(grid, grid)  # 'xy' indexing: xx varies along columns\n",
"\n",
"    # Broadcast all sources against the grid in one shot. A Python loop over sources\n",
"    # is unrolled by jit into len(x) copies of the kernel, which slows compilation.\n",
"    x = jnp.asarray(x)[:, None, None]\n",
"    y = jnp.asarray(y)[:, None, None]\n",
"    s = jnp.asarray(s)[:, None, None]\n",
"    r2 = (xx - x) ** 2 + (yy - y) ** 2  # shape (len(x), n_xy, n_xy)\n",
"    mu = jnp.sum(s * jnp.exp(-r2 / psf_std ** 2), axis=0)\n",
"    return mu + mu_bkg\n",
"\n",
"# Simulate data\n",
"mu, sigma = 10, 2  # True mean / std of the flux distribution (parameters of interest)\n",
"N = 20  # Number of point sources\n",
"\n",
"# Derive an independent key for each random draw: reusing one JAX PRNG key for\n",
"# several draws (e.g. fluxes and Poisson noise) makes those draws dependent.\n",
"key_pos, key_flux, key_noise = jax.random.split(jax.random.PRNGKey(0), 3)\n",
"\n",
"x, y = jax.random.uniform(key_pos, (2, N), minval=-1, maxval=1)  # Random positions\n",
"s = jax.random.normal(key_flux, (N,)) * sigma + mu  # Fluxes ~ Normal(mu, sigma)\n",
"\n",
"# Simulate the expected-counts map and Poisson fluctuate it\n",
"mu_map = simulate(x, y, s)\n",
"data = jax.random.poisson(key_noise, lam=mu_map, shape=(n_xy, n_xy))\n",
"\n",
"# Plot (origin='lower' so row 0, i.e. y = -1, is at the bottom)\n",
"fig, ax = plt.subplots()\n",
"im = ax.imshow(data, origin='lower')\n",
"ax.set(title='Simulated counts map', xlabel='x pixel', ylabel='y pixel')\n",
"fig.colorbar(im, ax=ax, label='counts');"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## NumPyro model for a fixed number of sources $N$"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def model(data, n_src=None):\n",
"    \"\"\"Fixed-N point-source model with a Poisson pixel likelihood.\n",
"\n",
"    Args:\n",
"        data: Observed counts map matching the shape of simulate()'s output, or None\n",
"            to sample the 'obs' site instead of conditioning on it (e.g. for\n",
"            posterior predictive draws).\n",
"        n_src: Number of point sources; None (default) uses the global N at call time.\n",
"\n",
"    Returns:\n",
"        The value of the 'obs' sample site.\n",
"    \"\"\"\n",
"    n_src = N if n_src is None else n_src\n",
"\n",
"    # Population-level parameters of interest\n",
"    mu = numpyro.sample('mu', dist.Uniform(3, 20))\n",
"    sigma = numpyro.sample('sigma', dist.Uniform(0.1, 4.))\n",
"\n",
"    # Per-source positions and fluxes, one named site per source ('x_0', 'y_0', 's_0', ...)\n",
"    x_ary, y_ary, s_ary = [], [], []\n",
"    for i in range(n_src):\n",
"        x_ary.append(numpyro.sample(f'x_{i}', dist.Uniform(-1, 1)))\n",
"        y_ary.append(numpyro.sample(f'y_{i}', dist.Uniform(-1, 1)))\n",
"        s_prior = dist.LeftTruncatedDistribution(dist.Normal(mu, sigma), low=0.)\n",
"        s_ary.append(numpyro.sample(f's_{i}', s_prior))\n",
"\n",
"    # NOTE: the simulated map is a sum over sources, so it does not depend on source\n",
"    # order. Sorting the arrays by flux before simulating would change neither the\n",
"    # likelihood nor the recorded x_i / y_i / s_i samples, so it cannot fix label\n",
"    # switching. This model does NOT address label switching: relabel posterior\n",
"    # samples post hoc or impose an ordering constraint on the sampled sites.\n",
"    mu_map = simulate(jnp.array(x_ary), jnp.array(y_ary), jnp.array(s_ary))\n",
"\n",
"    # Poisson likelihood over all pixels\n",
"    return numpyro.sample('obs', dist.Poisson(mu_map), obs=data)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"sample: 100%|██████████| 2500/2500 [01:52<00:00, 22.15it/s, 127 steps of size 2.69e-02. acc. prob=0.90] \n"
]
}
],
"source": [
"# One chain per host device (4 requested via numpyro.set_host_device_count), so that\n",
"# r_hat compares independent chains. The seed differs from PRNGKey(0), which the\n",
"# data-simulation cell already uses.\n",
"mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=2000, num_chains=4)\n",
"mcmc.run(jax.random.PRNGKey(1), data)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
" mean std median 5.0% 95.0% n_eff r_hat\n",
" mu 9.52 0.75 9.55 8.36 10.76 64.69 1.07\n",
" s_0 7.51 2.49 8.24 1.87 10.07 39.77 1.14\n",
" s_1 8.02 1.76 8.52 5.16 10.21 9.03 1.17\n",
" s_10 17.34 0.81 17.33 16.04 18.71 2013.46 1.00\n",
" s_11 11.61 0.89 11.59 10.02 12.89 1319.41 1.00\n",
" s_12 9.27 0.60 9.28 8.27 10.26 2175.57 1.00\n",
" s_13 7.08 3.62 6.79 2.21 13.46 26.22 1.15\n",
" s_14 10.11 1.81 10.53 8.23 12.70 27.99 1.10\n",
" s_15 6.84 1.89 7.26 3.42 9.48 9.43 1.06\n",
" s_16 9.06 0.62 9.05 8.04 10.03 3643.52 1.00\n",
" s_17 8.25 1.19 8.25 6.43 10.09 291.92 1.01\n",
" s_18 9.71 1.36 9.83 7.99 11.79 82.62 1.02\n",
" s_19 7.67 1.61 7.90 5.87 10.27 10.18 1.32\n",
" s_2 10.53 0.71 10.50 9.29 11.64 2131.21 1.00\n",
" s_3 8.72 2.42 9.64 4.14 11.45 22.05 1.02\n",
" s_4 8.47 2.26 8.81 5.27 13.33 27.34 1.04\n",
" s_5 9.56 0.65 9.56 8.53 10.67 2433.03 1.00\n",
" s_6 7.57 0.75 7.56 6.32 8.71 796.15 1.00\n",
" s_7 10.05 0.63 10.04 8.92 11.00 4038.04 1.00\n",
" s_8 14.72 0.96 14.69 13.14 16.23 1200.87 1.01\n",
" s_9 8.70 2.63 8.96 4.57 13.44 27.20 1.03\n",
" sigma 3.11 0.48 3.12 2.41 3.96 20.73 1.14\n",
" x_0 0.46 0.19 0.37 0.35 0.82 8.41 1.26\n",
" x_1 0.52 0.07 0.55 0.36 0.57 5.87 1.22\n",
" x_10 -0.90 0.00 -0.90 -0.90 -0.89 2704.19 1.00\n",
" x_11 -0.58 0.01 -0.58 -0.59 -0.57 1361.44 1.00\n",
" x_12 0.18 0.01 0.18 0.17 0.19 2660.53 1.00\n",
" x_13 0.55 0.19 0.57 0.40 0.79 18.40 1.06\n",
" x_14 0.13 0.01 0.13 0.12 0.14 821.34 1.00\n",
" x_15 0.18 0.22 0.05 0.03 0.56 3.84 1.45\n",
" x_16 -0.22 0.01 -0.22 -0.23 -0.21 2513.32 1.00\n",
" x_17 -0.22 0.01 -0.22 -0.24 -0.21 1819.21 1.00\n",
" x_18 -0.30 0.01 -0.30 -0.31 -0.28 1015.93 1.00\n",
" x_19 0.54 0.40 0.88 0.03 0.90 2.63 2.74\n",
" x_2 -0.01 0.01 -0.01 -0.02 -0.00 2737.65 1.00\n",
" x_3 0.78 0.03 0.78 0.75 0.81 159.30 1.00\n",
" x_4 0.41 0.06 0.41 0.34 0.46 15.45 1.02\n",
" x_5 -0.56 0.01 -0.56 -0.57 -0.55 4061.68 1.00\n",
" x_6 0.97 0.01 0.97 0.95 0.99 634.05 1.00\n",
" x_7 -0.80 0.00 -0.80 -0.81 -0.80 2136.28 1.00\n",
" x_8 -0.74 0.01 -0.74 -0.75 -0.73 1489.21 1.00\n",
" x_9 0.51 0.08 0.55 0.41 0.60 3.16 1.98\n",
" y_0 -0.07 0.22 0.10 -0.38 0.12 3.37 2.11\n",
" y_1 0.23 0.05 0.26 0.11 0.27 6.04 1.21\n",
" y_10 0.25 0.00 0.25 0.24 0.25 3522.02 1.00\n",
" y_11 -0.45 0.01 -0.45 -0.46 -0.44 2208.75 1.00\n",
" y_12 -0.61 0.01 -0.61 -0.62 -0.60 3457.60 1.00\n",
" y_13 -0.30 0.22 -0.37 -0.46 -0.20 23.03 1.01\n",
" y_14 -0.12 0.01 -0.12 -0.14 -0.11 562.40 1.00\n",
" y_15 0.15 0.06 0.12 0.10 0.26 3.38 1.58\n",
" y_16 -0.87 0.01 -0.87 -0.88 -0.86 2446.15 1.00\n",
" y_17 0.59 0.01 0.59 0.57 0.61 1284.09 1.00\n",
" y_18 0.72 0.01 0.72 0.70 0.73 809.63 1.00\n",
" y_19 0.58 0.45 0.96 0.09 0.99 2.80 2.79\n",
" y_2 0.93 0.01 0.93 0.92 0.94 3052.27 1.00\n",
" y_3 -0.31 0.01 -0.31 -0.32 -0.29 319.35 1.00\n",
" y_4 -0.14 0.17 -0.23 -0.28 0.13 5.32 1.42\n",
" y_5 0.25 0.01 0.25 0.24 0.26 2794.06 1.00\n",
" y_6 -0.53 0.01 -0.53 -0.55 -0.52 1148.03 1.00\n",
" y_7 0.76 0.01 0.76 0.75 0.77 2265.14 1.00\n",
" y_8 -0.49 0.00 -0.49 -0.50 -0.49 2578.49 1.00\n",
" y_9 -0.32 0.07 -0.36 -0.40 -0.22 2.74 2.70\n",
"\n",
"Number of divergences: 0\n"
]
}
],
"source": [
"# Posterior summary per sample site: mean, std, median, 90% interval, n_eff, r_hat\n",
"mcmc.print_summary(prob=0.9)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Posterior predictive maps"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Posterior predictive: re-run the model once per posterior draw with data=None,\n",
"# so the 'obs' site is sampled rather than conditioned on the observed map.\n",
"predictive = Predictive(model, posterior_samples=mcmc.get_samples())\n",
"samples_predictive = predictive(jax.random.PRNGKey(321), data=None)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment