Created
December 13, 2025 19:47
| { | |
| "cells": [ | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "id": "02d68a5a", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "from collections import defaultdict\n", | |
| "import astropy.table as at\n", | |
| "import numpy as np\n", | |
| "import jax\n", | |
| "import jax.numpy as jnp\n", | |
| "import matplotlib.pyplot as plt\n", | |
| "import numpyro.distributions as dist\n", | |
| "from numpyro_ext import distributions as distx\n", | |
| "\n", | |
| "%matplotlib inline\n", | |
| "jax.config.update(\"jax_enable_x64\", True)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "a5a2bca6", | |
| "metadata": {}, | |
| "source": [ | |
| "## Setup\n", | |
| "\n", | |
| "We observe radial velocity data $v$ for a set of $N$ sources at distinct observation times $t$. Let's assume that each observed source is the luminous star in a binary system with a nonluminous companion (i.e., an SB1 system). For simplicity, we enforce that the orbits are circular (zero eccentricity), the systemic velocities are zero, and the orbital angles are all zero. For this simplified model, the radial velocity data for a given source are generated by the model:\n", | |
| "$$\n", | |
| "v(t) = K \\, \\cos(2\\pi t / P)\n", | |
| "$$\n", | |
| "where $K$ is the semi-amplitude and $P$ is the period of the orbit. The data are also noisy: the uncertainties vary from observation to observation, but the noise is always Gaussian. Let's simulate some data.\n", | |
| "\n", | |
| "First, we generate the \"true\" parameters for each star:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "id": "fcbc872d", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "n_stars = 1024\n", | |
| "\n", | |
| "rng = np.random.default_rng(seed=12345)\n", | |
| "\n", | |
| "# number of observations per star\n", | |
| "n_times = np.exp(rng.uniform(np.log(3), np.log(100), size=n_stars)).astype(int)\n", | |
| "time_baseline = 100.0\n", | |
| "\n", | |
| "# period values per star\n", | |
| "max_period = 500.0\n", | |
| "true_periods = np.exp(rng.uniform(np.log(1.0), np.log(max_period), n_stars))\n", | |
| "\n", | |
| "# amplitudes\n", | |
| "true_amps = rng.choice([-1, 1], size=n_stars) * np.exp(\n", | |
| "    rng.uniform(-1, 1.0, size=n_stars)\n", | |
| ")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "54845ba8", | |
| "metadata": {}, | |
| "source": [ | |
| "We now have to define what we mean by \"binary fraction\" to get the true binary fraction of the sample. This amounts to a selection in the physical parameters. For real systems, this might be a selection on period (to get \"close binaries\") and a selection on companion mass (e.g., to distinguish stellar-mass companions from others). Here, we will just define a selection on the period and semi-amplitude, counting a star as a binary if its period is shorter than the time baseline and its absolute semi-amplitude exceeds 1:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "id": "4478d1b1", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "True binary fraction: 0.341\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "binary_max_period = time_baseline\n", | |
| "binary_min_amp = 1.0\n", | |
| "true_binary_frac = (\n", | |
| "    np.sum((true_periods < binary_max_period) & (np.abs(true_amps) > binary_min_amp))\n", | |
| "    / n_stars\n", | |
| ")\n", | |
| "print(f\"True binary fraction: {true_binary_frac:.3f}\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "8c014dc6", | |
| "metadata": {}, | |
| "source": [ | |
| "Now we can generate the radial velocity data for each star, including noise:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "id": "cf0f9662", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "rv_data = defaultdict(list)\n", | |
| "for i in range(n_stars):\n", | |
| "    # Generate random observation times for each star\n", | |
| "    rv_data[\"time\"].append(np.sort(rng.uniform(0, time_baseline, size=n_times[i])))\n", | |
| "\n", | |
| "    # Generate per-observation uncertainties, then noisy radial velocities\n", | |
| "    rv_data[\"rv_err\"].append(np.exp(rng.uniform(-1, 1.0, size=n_times[i])))\n", | |
| "    rv_data[\"rv\"].append(\n", | |
| "        true_amps[i] * np.cos(2 * np.pi * rv_data[\"time\"][i] / true_periods[i])\n", | |
| "        + rng.normal(scale=rv_data[\"rv_err\"][i], size=n_times[i])\n", | |
| "    )" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "id": "85e2cfb7", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "image/png": "", | |
| "text/plain": [ | |
| "<Figure size 864x576 with 16 Axes>" | |
| ] | |
| }, | |
| "metadata": { | |
| "image/png": { | |
| "height": 584, | |
| "width": 872 | |
| } | |
| }, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "fig, axes = plt.subplots(\n", | |
| "    4, 4, figsize=(12, 8), sharex=True, sharey=True, layout=\"constrained\"\n", | |
| ")\n", | |
| "\n", | |
| "for i, ax in enumerate(axes.flatten()):\n", | |
| "    ax.errorbar(\n", | |
| "        rv_data[\"time\"][i], rv_data[\"rv\"][i], yerr=rv_data[\"rv_err\"][i], fmt=\"o\"\n", | |
| "    )\n", | |
| "\n", | |
| "for ax in axes[:, 0]:\n", | |
| "    ax.set_ylabel(\"radial velocity\")\n", | |
| "for ax in axes[-1, :]:\n", | |
| "    ax.set_xlabel(\"time\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "9f1f4e21", | |
| "metadata": {}, | |
| "source": [ | |
| "## The Joker-style rejection sampler\n", | |
| "\n", | |
| "We'll implement a rejection sampler that:\n", | |
| "1. Samples periods from a prior\n", | |
| "2. Analytically marginalizes over the amplitude (linear parameter)\n", | |
| "3. Uses rejection sampling based on the marginalized likelihood\n", | |
| "4. Samples amplitudes from their conditional distribution for accepted samples\n", | |
| "\n", | |
| "The key trick: for a linear parameter $K$, the likelihood is Gaussian in $K$, so we can analytically compute the marginal likelihood and then sample $K$ from its posterior distribution." | |
| ] | |
| }, | |
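| { | |
| "cell_type": "markdown", | |
| "id": "marg-demo-md", | |
| "metadata": {}, | |
| "source": [ | |
| "To make the marginalization step concrete, the next cell is a minimal NumPy sketch of the standard Gaussian identity for a single linear parameter. It is not the `numpyro_ext` implementation (whose internals are not shown in this notebook), and every name ending in `_demo` and every parameter value is an illustrative assumption. With a prior $K \\sim \\mathcal{N}(0, s^2)$ and design vector $c_i = \\cos(2\\pi t_i / P)$, the marginal distribution of the data is a zero-mean Gaussian with covariance $C = \\mathrm{diag}(\\sigma_i^2) + s^2 \\, c \\, c^T$. The cell evaluates that density directly and through the equivalent closed form, then computes the mean and standard deviation of the Gaussian conditional posterior of $K$." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "marg-demo-code", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import numpy as np\n", | |
| "\n", | |
| "# Assumed toy data for one star (all values are illustrative)\n", | |
| "rng_demo = np.random.default_rng(0)\n", | |
| "n_demo = 12\n", | |
| "t_demo = np.sort(rng_demo.uniform(0, 100.0, size=n_demo))\n", | |
| "err_demo = np.exp(rng_demo.uniform(-1, 1.0, size=n_demo))\n", | |
| "P_demo, K_demo, s_demo = 17.0, 1.5, 100.0\n", | |
| "c_demo = np.cos(2 * np.pi * t_demo / P_demo)\n", | |
| "rv_demo = K_demo * c_demo + rng_demo.normal(scale=err_demo)\n", | |
| "\n", | |
| "# Route 1: dense multivariate normal with C = diag(err^2) + s^2 c c^T\n", | |
| "C_demo = np.diag(err_demo**2) + s_demo**2 * np.outer(c_demo, c_demo)\n", | |
| "_, logdet_demo = np.linalg.slogdet(C_demo)\n", | |
| "lnL_dense_demo = -0.5 * (\n", | |
| "    rv_demo @ np.linalg.solve(C_demo, rv_demo)\n", | |
| "    + logdet_demo\n", | |
| "    + n_demo * np.log(2 * np.pi)\n", | |
| ")\n", | |
| "\n", | |
| "# Route 2: closed form from the Woodbury and matrix determinant identities\n", | |
| "A_demo = np.sum(c_demo**2 / err_demo**2) + 1 / s_demo**2\n", | |
| "b_demo = np.sum(c_demo * rv_demo / err_demo**2)\n", | |
| "lnL_closed_demo = -0.5 * (\n", | |
| "    np.sum(rv_demo**2 / err_demo**2)\n", | |
| "    - b_demo**2 / A_demo\n", | |
| "    + np.sum(np.log(err_demo**2))\n", | |
| "    + np.log(s_demo**2 * A_demo)\n", | |
| "    + n_demo * np.log(2 * np.pi)\n", | |
| ")\n", | |
| "\n", | |
| "# Conditional posterior of K given the period is Gaussian\n", | |
| "K_mean_demo = b_demo / A_demo\n", | |
| "K_std_demo = A_demo**-0.5\n", | |
| "\n", | |
| "print(f\"lnL (dense) = {lnL_dense_demo:.6f}, lnL (closed form) = {lnL_closed_demo:.6f}\")\n", | |
| "print(f\"K | P, data ~ Normal({K_mean_demo:.3f}, {K_std_demo:.3f})\")" | |
| ] | |
| }, | |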
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "id": "f09e8839", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Max observations: 99\n", | |
| "Padded shapes: t=(1024, 99), rv=(1024, 99), rv_err=(1024, 99)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Pad all data to the same length (max number of observations)\n", | |
| "# This is a hack so that JIT doesn't recompile for each star of different n_obs...\n", | |
| "max_n_obs = np.max(n_times)\n", | |
| "print(f\"Max observations: {max_n_obs}\")\n", | |
| "\n", | |
| "\n", | |
| "def pad_to_length(arr, length, pad_value):\n", | |
| "    \"\"\"Pad array to specified length.\"\"\"\n", | |
| "    padded = np.full(length, pad_value)\n", | |
| "    padded[: len(arr)] = arr\n", | |
| "    return padded\n", | |
| "\n", | |
| "\n", | |
| "# Create padded arrays for all stars\n", | |
| "t_padded = jnp.array([pad_to_length(t, max_n_obs, 0.0) for t in rv_data[\"time\"]])\n", | |
| "rv_padded = jnp.array([pad_to_length(rv, max_n_obs, 0.0) for rv in rv_data[\"rv\"]])\n", | |
| "rv_err_padded = jnp.array(\n", | |
| "    [pad_to_length(err, max_n_obs, jnp.inf) for err in rv_data[\"rv_err\"]]\n", | |
| ")\n", | |
| "\n", | |
| "print(\n", | |
| "    f\"Padded shapes: t={t_padded.shape}, rv={rv_padded.shape}, rv_err={rv_err_padded.shape}\"\n", | |
| ")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "id": "1bd0f97f", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "linear_prior_scale = 100.0  # NOTE: unused; the likelihood and K-sampling functions below hard-code 1000.0\n", | |
| "\n", | |
| "\n", | |
| "@jax.jit\n", | |
| "def compute_all_log_likelihoods(periods, t_all, rv_all, rv_err_all):\n", | |
| "    \"\"\"Compute log likelihoods for all periods and all stars using scan.\n", | |
| "\n", | |
| "    Parameters\n", | |
| "    ----------\n", | |
| "    periods : array (n_periods,)\n", | |
| "    t_all : array (n_stars, max_n_obs)\n", | |
| "    rv_all : array (n_stars, max_n_obs)\n", | |
| "    rv_err_all : array (n_stars, max_n_obs)\n", | |
| "\n", | |
| "    Returns\n", | |
| "    -------\n", | |
| "    lnL : array (n_stars, n_periods)\n", | |
| "    \"\"\"\n", | |
| "\n", | |
| "    def compute_star_likelihoods(t, rv, rv_err):\n", | |
| "        \"\"\"Compute likelihoods for one star, all periods.\"\"\"\n", | |
| "        # Create mask for valid (non-padded) points\n", | |
| "        # Padded points have rv_err = inf\n", | |
| "        valid_mask = jnp.isfinite(rv_err)\n", | |
| "\n", | |
| "        # Replace the infinite errors of padded points with a huge finite value\n", | |
| "        # so they carry negligible weight. They effectively add only a\n", | |
| "        # period-independent constant to lnL, which cancels when the weights\n", | |
| "        # are normalized per star in the rejection step.\n", | |
| "        rv_err_safe = jnp.where(valid_mask, rv_err, 1e10)\n", | |
| "\n", | |
| "        def single_period_likelihood(period):\n", | |
| "            design_matrix = jnp.cos(2 * jnp.pi * t / period)[:, None]\n", | |
| "            linear_prior = dist.Normal(0.0, 1000.0)\n", | |
| "            marginalized = distx.MarginalizedLinear(\n", | |
| "                design_matrix,\n", | |
| "                linear_prior,\n", | |
| "                dist.Normal(jnp.zeros(len(rv)), rv_err_safe),\n", | |
| "            )\n", | |
| "            return marginalized.log_prob(rv)\n", | |
| "\n", | |
| "        return jax.vmap(single_period_likelihood)(periods)\n", | |
| "\n", | |
| "    # Use scan to loop over stars instead of vmapping, to limit memory use\n", | |
| "    def scan_fn(carry, x):\n", | |
| "        t, rv, rv_err = x\n", | |
| "        lnL = compute_star_likelihoods(t, rv, rv_err)\n", | |
| "        return carry, lnL\n", | |
| "\n", | |
| "    _, lnL_all = jax.lax.scan(scan_fn, None, (t_all, rv_all, rv_err_all))\n", | |
| "\n", | |
| "    return lnL_all" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "id": "e8eae702", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "@jax.jit\n", | |
| "def rejection_step(key, lnL_all):\n", | |
| "    \"\"\"Perform rejection sampling for all stars\n", | |
| "\n", | |
| "    Parameters\n", | |
| "    ----------\n", | |
| "    key : PRNGKey\n", | |
| "    lnL_all : array (n_stars, n_periods)\n", | |
| "\n", | |
| "    Returns\n", | |
| "    -------\n", | |
| "    weights : array (n_stars, n_periods)\n", | |
| "    accepted : array (n_stars, n_periods) boolean\n", | |
| "    \"\"\"\n", | |
| "    # Normalize weights per star\n", | |
| "    lnL_max = jnp.max(lnL_all, axis=1, keepdims=True)\n", | |
| "    weights = jnp.exp(lnL_all - lnL_max)\n", | |
| "\n", | |
| "    # Rejection sampling\n", | |
| "    uniform_draws = jax.random.uniform(key, shape=lnL_all.shape)\n", | |
| "    accepted = uniform_draws < weights\n", | |
| "\n", | |
| "    return weights, accepted" | |
| ] | |
| }, | |
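| { | |
| "cell_type": "markdown", | |
| "id": "rej-demo-md", | |
| "metadata": {}, | |
| "source": [ | |
| "The acceptance rule used in `rejection_step` can be illustrated on a toy one-dimensional problem. The next cell is a sketch under assumed values (a uniform prior on $[-5, 5]$ and a Gaussian likelihood centered at 1 with width 0.5) and is not part of the pipeline; the `_toy` names are hypothetical. Prior draws are accepted with probability $L / L_{\\rm max}$, so the accepted draws should be distributed like the posterior, with a sample mean near 1 and a standard deviation near 0.5. Note that the maximum is taken over the finite set of prior draws, so it only approximates the true maximum of the likelihood." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "rej-demo-code", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import numpy as np\n", | |
| "\n", | |
| "rng_toy = np.random.default_rng(42)\n", | |
| "\n", | |
| "# Assumed toy setup: uniform prior draws and a Gaussian log-likelihood\n", | |
| "x_toy = rng_toy.uniform(-5.0, 5.0, size=200_000)\n", | |
| "lnL_toy = -0.5 * ((x_toy - 1.0) / 0.5) ** 2\n", | |
| "\n", | |
| "# Same rule as rejection_step: accept with probability L / max(L)\n", | |
| "weights_toy = np.exp(lnL_toy - lnL_toy.max())\n", | |
| "accepted_toy = rng_toy.uniform(size=x_toy.size) < weights_toy\n", | |
| "\n", | |
| "print(f\"Accepted {accepted_toy.sum()} of {x_toy.size} prior draws\")\n", | |
| "print(\n", | |
| "    f\"mean = {x_toy[accepted_toy].mean():.3f}, std = {x_toy[accepted_toy].std():.3f}\"\n", | |
| ")" | |
| ] | |
| }, | |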
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "id": "feeabace", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "@jax.jit\n", | |
| "def sample_K_given_period(key, periods, t, rv, rv_err):\n", | |
| "    \"\"\"Sample amplitude K given period from the conditional posterior.\"\"\"\n", | |
| "\n", | |
| "    def sample_single_K(period, subkey):\n", | |
| "        # Design matrix for this period\n", | |
| "        design_matrix = jnp.cos(2 * jnp.pi * t / period)[:, None]\n", | |
| "\n", | |
| "        # Linear prior on K\n", | |
| "        linear_prior = dist.Normal(0.0, 1000.0)\n", | |
| "\n", | |
| "        # Create the marginalized distribution\n", | |
| "        marginalized = distx.MarginalizedLinear(\n", | |
| "            design_matrix,\n", | |
| "            linear_prior,\n", | |
| "            dist.Normal(jnp.zeros(len(rv)), rv_err),\n", | |
| "        )\n", | |
| "\n", | |
| "        # Sample K from the conditional posterior\n", | |
| "        K_sample = marginalized.conditional(rv).sample(subkey)\n", | |
| "\n", | |
| "        return K_sample[0]  # Return a scalar (drop the length-1 axis)\n", | |
| "\n", | |
| "    # Generate keys for each period\n", | |
| "    keys = jax.random.split(key, len(periods))\n", | |
| "\n", | |
| "    # Sample K for each period\n", | |
| "    K_samples = jax.vmap(sample_single_K)(periods, keys)\n", | |
| "\n", | |
| "    return K_samples" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "id": "7ab672de", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def run_batched_rejection_sampler(\n", | |
| "    key,\n", | |
| "    t_padded,\n", | |
| "    rv_padded,\n", | |
| "    rv_err_padded,\n", | |
| "    n_prior_samples=100_000,\n", | |
| "    n_posterior_samples=256,\n", | |
| "    period_min=1.0,\n", | |
| "    period_max=500.0,\n", | |
| "):\n", | |
| "    \"\"\"\n", | |
| "    Run rejection sampler on all stars in a single batched computation.\n", | |
| "    \"\"\"\n", | |
| "    key1, key2, key3 = jax.random.split(key, 3)\n", | |
| "\n", | |
| "    # Generate shared period samples\n", | |
| "    lnP_samples = jax.random.uniform(\n", | |
| "        key1,\n", | |
| "        shape=(n_prior_samples,),\n", | |
| "        minval=jnp.log(period_min),\n", | |
| "        maxval=jnp.log(period_max),\n", | |
| "    )\n", | |
| "    P_samples = jnp.exp(lnP_samples)\n", | |
| "\n", | |
| "    # Compute all likelihoods at once (n_stars, n_periods)\n", | |
| "    print(\"Computing likelihoods for all stars...\")\n", | |
| "    lnL_all_stars = compute_all_log_likelihoods(\n", | |
| "        P_samples, t_padded, rv_padded, rv_err_padded\n", | |
| "    )\n", | |
| "\n", | |
| "    # Rejection step for all stars\n", | |
| "    print(\"Running rejection step...\")\n", | |
| "    weights, accepted_mask = rejection_step(key2, lnL_all_stars)\n", | |
| "\n", | |
| "    # Extract accepted samples for each star\n", | |
| "    print(\"Extracting accepted samples...\")\n", | |
| "    n_stars = t_padded.shape[0]\n", | |
| "    results = {}\n", | |
| "\n", | |
| "    keys = jax.random.split(key3, n_stars)\n", | |
| "    for n in range(n_stars):\n", | |
| "        k1, k2 = jax.random.split(keys[n], 2)\n", | |
| "        tmp = P_samples[accepted_mask[n]]\n", | |
| "        accepted_periods_n = jax.random.choice(\n", | |
| "            k1,\n", | |
| "            tmp,\n", | |
| "            shape=(min(tmp.shape[0], n_posterior_samples),),\n", | |
| "            replace=False,\n", | |
| "        )\n", | |
| "\n", | |
| "        _mask = jnp.isfinite(rv_err_padded[n])\n", | |
| "        t_orig = t_padded[n][_mask]\n", | |
| "        rv_orig = rv_padded[n][_mask]\n", | |
| "        rv_err_orig = rv_err_padded[n][_mask]\n", | |
| "\n", | |
| "        amplitudes = sample_K_given_period(\n", | |
| "            k2,\n", | |
| "            accepted_periods_n,\n", | |
| "            t_orig,\n", | |
| "            rv_orig,\n", | |
| "            rv_err_orig,\n", | |
| "        )\n", | |
| "\n", | |
| "        results[n] = {\n", | |
| "            \"periods\": accepted_periods_n,\n", | |
| "            \"amplitudes\": amplitudes,\n", | |
| "        }\n", | |
| "\n", | |
| "    return results" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 21, | |
| "id": "b7e2cb03", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Computing likelihoods for all stars...\n", | |
| "Running rejection step...\n", | |
| "Extracting accepted samples...\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "results = run_batched_rejection_sampler(\n", | |
| "    jax.random.PRNGKey(74),\n", | |
| "    t_padded,\n", | |
| "    rv_padded,\n", | |
| "    rv_err_padded,\n", | |
| "    n_prior_samples=1_000_000,\n", | |
| "    n_posterior_samples=256,\n", | |
| "    period_min=1.0,\n", | |
| "    period_max=max_period,\n", | |
| ")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 22, | |
| "id": "0018d3fe", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "np.int64(69)" | |
| ] | |
| }, | |
| "execution_count": 22, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "n_posterior_samples_returned = np.array(\n", | |
| "    [results[n][\"periods\"].shape[0] for n in range(n_stars)]\n", | |
| ")\n", | |
| "n_posterior_samples_returned.min()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 23, | |
| "id": "6e9e9555", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { |