Gist torfjelde/cc5c41e97eb4c97e22a19b8440f6d506 (last active November 18, 2023 22:26)
Simple example of using NUTS with the new iterator interface in AbstractMCMC.jl, available in Turing.jl versions newer than 0.15.
julia> using Turing, Random

julia> @model function gdemo(xs)
           # Assumptions
           σ² ~ InverseGamma(2, 3)
           μ ~ Normal(0, √σ²)
           # Observations
           for i = 1:length(xs)
               xs[i] ~ Normal(μ, √σ²)
           end
       end
gdemo (generic function with 2 methods)
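
To make the model concrete, here is a minimal sketch that writes out the log joint density log p(σ², μ, xs) defined by `gdemo` in plain Julia, without Turing. It assumes Distributions.jl's conventions: `InverseGamma(α, θ)` has shape α and scale θ, and `Normal(μ, σ)` takes a standard deviation. The helper names are hypothetical, and the density is in the original parameterization; Turing's NUTS samples σ² on a transformed, unconstrained scale, so the sampler's internal target differs from this by a change-of-variables term.

```julia
# Hypothetical helpers for illustration (not Turing or Distributions.jl API).

# log pdf of InverseGamma(2, β) at s > 0: α log β - log Γ(α) - (α + 1) log s - β / s.
# With shape α = 2, log Γ(2) = log(1) = 0, so no special-function package is needed.
logpdf_invgamma2(s, β) = 2 * log(β) - 3 * log(s) - β / s

# log pdf of Normal(m, sd) at x, where sd is a standard deviation.
logpdf_normal(x, m, sd) = -log(2π) / 2 - log(sd) - (x - m)^2 / (2 * sd^2)

# Log joint of `gdemo`: one term per `~` statement in the model.
function gdemo_logjoint(σ², μ, xs)
    sd = sqrt(σ²)
    lp = logpdf_invgamma2(σ², 3)       # σ² ~ InverseGamma(2, 3)
    lp += logpdf_normal(μ, 0, sd)      # μ ~ Normal(0, √σ²)
    for x in xs
        lp += logpdf_normal(x, μ, sd)  # xs[i] ~ Normal(μ, √σ²)
    end
    return lp
end
```

As a worked check, `gdemo_logjoint(1.0, 0.0, [0.0])` equals `2log(3) - 3 - log(2π)`: the prior on σ² contributes `2log(3) - 3`, and each of the two normal terms contributes `-log(2π)/2`.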

julia> # Set up.
       xs = randn(100);

julia> model = gdemo(xs);

julia> # Sampler.
       alg = NUTS(0.65);

julia> kwargs = (nadapts=50,);

julia> num_samples = 100;
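
`kwargs` above is a one-element `NamedTuple`; the trailing comma in `(nadapts=50,)` is what makes it one, since `(nadapts=50)` without the comma is a plain assignment. Splatting it after the semicolon, as in `sample(...; kwargs...)`, forwards each field as a keyword argument. A minimal sketch, with a hypothetical `fake_sample` standing in for `sample`:

```julia
# Hypothetical stand-in for `sample`: it just returns what it received.
fake_sample(n; nadapts = 0) = (n = n, nadapts = nadapts)

kwargs = (nadapts = 50,)     # trailing comma: a one-element NamedTuple
not_kwargs = (nadapts = 50)  # no comma: an assignment that evaluates to 50

# Equivalent to fake_sample(100; nadapts = 50).
result = fake_sample(100; kwargs...)
```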

julia> ### The following two methods are equivalent ###
       ## Using `sample` ##
       rng = MersenneTwister(42);

julia> chain = sample(rng, model, alg, num_samples; kwargs...)
┌ Info: Found initial step size
└   ϵ = 0.4
Sampling 100%|█████████████████████████████████████████████████████| Time: 0:00:00
Chains MCMC chain (100×14×1 Array{Float64, 3}):

Iterations        = 51:1:150
Number of chains  = 1
Samples per chain = 100
Wall duration     = 1.12 seconds
Compute duration  = 1.12 seconds
parameters        = σ², μ
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std      mcse  ess_bulk  ess_tail      rhat   ess_per_sec
      Symbol   Float64   Float64   Float64   Float64   Float64   Float64       Float64

          σ²    1.1090    0.1728    0.0591   10.8467   51.5563    1.1136        9.6587
           μ   -0.1753    0.1030    0.0126   66.7404   78.3393    0.9940       59.4304

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

          σ²    0.8499    0.9980    1.1052    1.2155    1.4554
           μ   -0.3526   -0.2485   -0.1774   -0.1074    0.0239

julia> ## Using the iterator-interface ##
       rng = MersenneTwister(42);

julia> spl = DynamicPPL.Sampler(alg);

julia> nadapts = 50;

julia> # Create an iterator we can just step through.
       it = AbstractMCMC.Stepper(rng, model, spl, kwargs);

julia> # Initial sample and state.
       transition, state = iterate(it);
┌ Info: Found initial step size
└   ϵ = 0.4

julia> # Simple container to hold the samples.
       transitions = [];

julia> # Simple condition that says we only want `num_samples` samples.
       condition(spls) = length(spls) < num_samples
condition (generic function with 1 method)

julia> # Sample until `condition` is no longer satisfied.
       while condition(transitions)
           # For an iterator we pass in the previous `state` as the second argument.
           transition, state = iterate(it, state)
           # Save `transition` if we're not adapting anymore.
           if state.i > nadapts
               push!(transitions, transition)
           end
       end

julia> length(transitions), state.i, state.i == length(transitions) + nadapts
(100, 150, true)
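
The loop above relies only on Julia's iteration protocol: `iterate(it)` returns the first `(transition, state)` pair, and `iterate(it, state)` returns the next one. The sketch below implements that protocol for a hypothetical `ToyStepper`, a random-walk Metropolis sampler for a standard normal target. It is a stand-in for NUTS, not how `AbstractMCMC.Stepper` is implemented; its state carries an iteration counter `i` playing the role of `state.i` above. It also shows that skipping the first `nadapts` draws with `Iterators.drop` and then using `Iterators.take` yields the same samples as the hand-written loop when both start from the same seed.

```julia
using Random

# Hypothetical toy sampler (not part of AbstractMCMC): random-walk Metropolis
# targeting Normal(0, 1), used only to illustrate the iterator protocol.
struct ToyStepper{R<:AbstractRNG}
    rng::R
    step_size::Float64
end

# State passed between `iterate` calls; `i` counts iterations like `state.i` above.
struct ToyState
    i::Int
    x::Float64
end

logtarget(x) = -x^2 / 2  # unnormalized log density of Normal(0, 1)

# First call: there is no previous state, so draw an initial point.
function Base.iterate(s::ToyStepper)
    x = randn(s.rng)
    return x, ToyState(1, x)
end

# Later calls: one Metropolis step starting from the previous state.
function Base.iterate(s::ToyStepper, state::ToyState)
    proposal = state.x + s.step_size * randn(s.rng)
    accept = log(rand(s.rng)) < logtarget(proposal) - logtarget(state.x)
    x = accept ? proposal : state.x
    return x, ToyState(state.i + 1, x)
end

# The stepper never runs out, and every transition is a Float64.
Base.IteratorSize(::Type{<:ToyStepper}) = Base.IsInfinite()
Base.eltype(::Type{<:ToyStepper}) = Float64

# Same loop as the REPL session, wrapped in a function to avoid global-scope issues in scripts.
function run_loop(it, nadapts, num_samples)
    transition, state = iterate(it)
    transitions = Float64[]
    while length(transitions) < num_samples
        transition, state = iterate(it, state)
        if state.i > nadapts  # keep only post-adaptation iterations
            push!(transitions, transition)
        end
    end
    return transitions, state
end

nadapts, num_samples = 50, 100
transitions, state = run_loop(ToyStepper(MersenneTwister(42), 1.0), nadapts, num_samples)

# Equivalent: skip the first `nadapts` iterations, then take `num_samples`.
same = collect(Iterators.take(Iterators.drop(ToyStepper(MersenneTwister(42), 1.0), nadapts), num_samples))
```

Because the toy stepper is infinite, `Iterators.take` is what bounds the collected result; the explicit loop keeps the bookkeeping visible instead, which is closer to what the session above does.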

julia> # Finally, if you want to convert the vector of `transitions` into an
       # `MCMCChains.Chains`, as is typically done:
       chain = AbstractMCMC.bundle_samples(
           map(identity, transitions), # trick to concretize the eltype of `transitions`
           model,
           spl,
           state,
           MCMCChains.Chains
       )
Chains MCMC chain (100×14×1 Array{Float64, 3}):

Iterations        = 1:1:100
Number of chains  = 1
Samples per chain = 100
parameters        = σ², μ
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std      mcse  ess_bulk  ess_tail      rhat   ess_per_sec
      Symbol   Float64   Float64   Float64   Float64   Float64   Float64       Missing

          σ²    1.1090    0.1728    0.0591   10.8467   51.5563    1.1136       missing
           μ   -0.1753    0.1030    0.0126   66.7404   78.3393    0.9940       missing

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

          σ²    0.8499    0.9980    1.1052    1.2155    1.4554
           μ   -0.3526   -0.2485   -0.1774   -0.1074    0.0239
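
The `map(identity, transitions)` call works because `transitions = []` is a `Vector{Any}`, while `map` chooses the element type of its result from the values it actually produces. A small stdlib-only illustration, with plain integers standing in for the NUTS transitions:

```julia
# `[]` creates a Vector{Any}, like `transitions` in the session above.
v = []
push!(v, 1); push!(v, 2); push!(v, 3)

# `map(identity, v)` returns the same values in a new Vector whose element
# type is determined from the elements themselves (here Int).
w = map(identity, v)

# Caveat: with no elements there is nothing to infer from,
# so the result stays a Vector{Any}.
empty_case = map(identity, [])
```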