Last active: February 20, 2024 21:41
Save ahwillia/c35bfb2a09add7b2e5745ca7b424a3b1 to your computer and use it in GitHub Desktop.
Elliptical Slice Sampler in JAX
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
NOTE: This code has not been rigorously tested. | |
""" | |
import matplotlib.pyplot as plt | |
import jax.numpy as jnp | |
import jax | |
from tqdm import trange | |
def elliptical_slice_update(x, log_density, sigmas, key):
    """
    Runs one update of Elliptical slice sampling with
    respect to a mean-zero, diagonal Gaussian prior.

    The chain targets a density proportional to
    ``exp(log_density(x)) * N(x | 0, diag(sigmas ** 2))``.
    One auxiliary prior draw ``nu`` defines the ellipse
    ``x * cos(theta) + nu * sin(theta)`` through the current
    state (theta = 0). A slice height is drawn below
    ``log_density(x)``, and an angle bracket is shrunk toward
    theta = 0 until a proposal lies strictly above that height.

    Parameters
    ----------
    x : array-like
        Current state of the MCMC chain, corresponds
        to parameters we are sampling. Must have a
        ``.shape`` attribute matching ``sigmas``.
    log_density : Callable
        Function that computes log density (up to an
        additive constant) that we'd like to sample from
        after multiplication with the prior. Must return
        a scalar.
    sigmas : array-like
        Standard deviations for each element of x. This
        specifies a prior distribution over x, which is
        mean-zero and diagonal covariance with
        elements given by (sigmas ** 2)
    key : jax.random.PRNGKey
        Random number seed, used to generate proposal.
        The key is consumed by this call; pass a fresh
        subkey on every call.

    Returns
    -------
    x_next : array-like
        Next state of the MCMC chain.

    Raises
    ------
    ValueError
        If ``x`` and ``sigmas`` have different shapes.

    Notes
    -----
    If ``log_density(x)`` is ``-inf``, the slice height is
    ``-inf`` and the bracket shrinks toward ``x`` itself, so
    the loop may never terminate. Start the chain inside the
    support of ``log_density``.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if x.shape != sigmas.shape:
        raise ValueError(
            "x and sigmas must have the same shape, "
            f"got {x.shape} and {sigmas.shape}"
        )

    k_nu, k_slice, k_theta, k_loop = jax.random.split(key, num=4)

    # Auxiliary draw from the prior N(0, diag(sigmas ** 2)).
    nu = jax.random.normal(k_nu, shape=x.shape) * sigmas

    # Log slice height: log_density(x) + log(u) with u ~ U[0, 1).
    thres = log_density(x) + jnp.log(jax.random.uniform(k_slice))

    def propose(angle):
        # Point on the ellipse passing through x (angle 0) and nu (angle pi/2).
        return x * jnp.cos(angle) + nu * jnp.sin(angle)

    # Initial proposal angle and the full bracket [theta - 2*pi, theta].
    theta = jax.random.uniform(k_theta, minval=0., maxval=(2 * jnp.pi))
    init_loop_state = (
        propose(theta),
        theta - 2 * jnp.pi,
        theta,
        theta,
        k_loop,
    )

    def while_cond_fun(loop_state):
        # Keep shrinking while the proposal is on or below the slice.
        x_proposed = loop_state[0]
        return log_density(x_proposed) <= thres

    def while_body_fun(loop_state):
        _, theta_min, theta_max, theta0, key0 = loop_state
        # Shrink the bracket toward angle 0 (the current state): a rejected
        # negative angle becomes the new lower edge, otherwise the new upper edge.
        is_negative = theta0 < 0
        theta_min1 = jnp.where(is_negative, theta0, theta_min)
        theta_max1 = jnp.where(is_negative, theta_max, theta0)
        # Split BEFORE drawing so the key used for theta1 is never reused
        # to derive the next iteration's key.
        key1, subkey = jax.random.split(key0)
        theta1 = jax.random.uniform(
            subkey,
            minval=theta_min1,
            maxval=theta_max1,
        )
        return propose(theta1), theta_min1, theta_max1, theta1, key1

    final_loop_state = jax.lax.while_loop(
        while_cond_fun, while_body_fun, init_loop_state
    )
    return final_loop_state[0]
if __name__ == "__main__":
    # Demo: sample the prior N(0, diag([0.25, 4.0] ** 2)) restricted to the
    # box max(|u|) < 1 -- log_density is log(1) = 0 inside, log(0) = -inf outside.
    x = jnp.zeros(2)
    x_samples = [x]
    key = jax.random.PRNGKey(111)
    num_samples = 1000
    sigmas = jnp.array([0.25, 4.0])

    def log_density(u):
        return jnp.log(jnp.max(jnp.abs(u)) < 1)

    # Compile one sampler step once; log_density is a Python callable, so it
    # must be a static argument. Without jit, each call re-traces the loop.
    update = jax.jit(elliptical_slice_update, static_argnums=1)

    for _ in trange(num_samples):
        # Split first and give the sampler the fresh subkey; `key` itself is
        # never consumed, so it can safely be split again next iteration.
        key, subkey = jax.random.split(key)
        x = update(x, log_density, sigmas, subkey)
        x_samples.append(x)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.