# implement bursted sampling, but unbursted warmup using TFP adaptation
# use greta to get the log probability function for a knotty model (its
# parameters have very different scales)
library(greta)
#>
#> Attaching package: 'greta'
#> The following objects are masked from 'package:stats':
#>
#> binomial, cov2cor, poisson
#> The following objects are masked from 'package:base':
#>
#> %*%, apply, backsolve, beta, chol2inv, colMeans, colSums, diag,
#> eigen, forwardsolve, gamma, identity, rowMeans, rowSums, sweep,
#> tapply
x <- normal(0, c(0.1, 1, 10, 100))
#> ℹ Initialising python and checking dependencies, this may take a moment.
#> ✔ Initialising python and checking dependencies ... done!
#>
m <- model(x)
f <- m$dag$generate_log_prob_function()
log_prob <- function(free_state) {
  results <- f(free_state)
  results$adjusted
}
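# quick sanity check of the log probability function (a sketch; check_state is
# a hypothetical all-zero free state with one row per chain, matching the
# n_chains x 4 layout used below): the result should be one finite value per
# chain
check_state <- greta:::fl(array(0, c(4, 4)))
as.array(log_prob(check_state))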
# set up sampling scheme
n_chains <- 4
n_samples <- 1000
n_warmup <- 1000
# set initial values
n_free <- prod(dim(x))
init_raw <- array(0, c(n_chains, n_free))
init <- greta:::fl(init_raw)
library(tensorflow)
tfp <- greta:::tfp
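# tfp is the tensorflow-probability Python module that greta imports via
# reticulate, so its MCMC kernels can be driven directly from R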
# create a kernel for adaptive HMC using the SNAPER criterion (adapts the
# number of leapfrog steps, the step size, and a reduced-rank mass matrix). If
# used for warmup, n_adapt must be set to the number of warmup steps; it
# defaults to 0 so that a kernel built for post-warmup sampling does no
# further adaptation
create_kernel <- function(log_prob, n_adapt = 0) {
  kernel_base <- tfp$experimental$mcmc$SNAPERHamiltonianMonteCarlo(
    target_log_prob_fn = log_prob,
    step_size = 1,
    num_adaptation_steps = as.integer(n_adapt)
  )
  kernel <- tfp$mcmc$DualAveragingStepSizeAdaptation(
    inner_kernel = kernel_base,
    num_adaptation_steps = as.integer(n_adapt)
  )
  kernel
}
# get a TensorSpec matching a tensor
as_tensorspec <- function(tensor) {
  tf$TensorSpec(
    shape = tensor$get_shape(),
    dtype = tensor$dtype
  )
}
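# e.g. the spec for the initial state tensor, recording its shape and dtype
as_tensorspec(init)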
# get the final model parameter state from a chain as returned in the all_states
# object from tfp$mcmc$sample_chain
get_last_state <- function(all_states) {
  n_iter <- dim(all_states)[1]
  # gather the slice for the final iteration; axis must be passed by name, as
  # the third positional argument of tf$gather is the deprecated
  # validate_indices, not axis
  tf$gather(all_states, n_iter - 1L, axis = 0L)
}
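# toy check of get_last_state() (a sketch on a hypothetical dummy tensor of 3
# iterations x 2 chains x 1 parameter): the result should be the 2 x 1 slice
# from the final iteration
dummy_states <- tf$constant(array(1:6, c(3, 2, 1)))
as.array(get_last_state(dummy_states))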
# find out if MCMC steps had non-finite acceptance probabilities
bad_steps <- function(kernel_results) {
  log_accept_ratios <- recursive_get_log_accept_ratio(kernel_results)
  !is.finite(log_accept_ratios)
}
# recursively extract the log acceptance ratio from the MCMC kernel results
recursive_get_log_accept_ratio <- function(kernel_results) {
  nm <- names(kernel_results)
  if ("log_accept_ratio" %in% nm) {
    log_accept_ratios <- kernel_results$log_accept_ratio
  } else if ("inner_results" %in% nm) {
    log_accept_ratios <- recursive_get_log_accept_ratio(
      kernel_results$inner_results
    )
  } else {
    stop("non-standard kernel structure")
  }
  as.array(log_accept_ratios)
}
# given an MCMC kernel `kernel` and initial model parameter state `init`,
# adapt the kernel tuning parameters whilst simultaneously burning in the
# model parameter state. Return both the finalised kernel tuning parameters
# and the burned-in model parameter state
warm_up_sampler <- function(kernel, init) {
  # get the predetermined adaptation period of the kernel
  n_adapt <- kernel$num_adaptation_steps
  # make the uncompiled function (with curried arguments)
  warmup_raw <- function() {
    tfp$mcmc$sample_chain(
      num_results = n_adapt,
      current_state = init,
      kernel = kernel,
      return_final_kernel_results = TRUE,
      # trace only the step counter; the full kernel results are not needed
      # during warmup
      trace_fn = function(current_state, kernel_results) {
        kernel_results$step
      }
    )
  }
  # compile it into a concrete function
  warmup <- tf_function(warmup_raw)
  # execute it
  result <- warmup()
  # return the last (burned-in) state of the model parameters and the final
  # (tuned) kernel parameters
  list(
    kernel = kernel,
    kernel_results = result$final_kernel_results,
    current_state = get_last_state(result$all_states)
  )
}
# given a warmed-up sampler object, return a compiled TF function that
# generates a new burst of samples from it
make_sampler_function <- function(warm_sampler) {
  # make the uncompiled function (with curried arguments)
  sample_raw <- function(current_state, n_samples) {
    results <- tfp$mcmc$sample_chain(
      # how many iterations
      num_results = n_samples,
      # where to start from
      current_state = current_state,
      # kernel
      kernel = warm_sampler$kernel,
      # tuned sampler settings
      previous_kernel_results = warm_sampler$kernel_results,
      # trace the full kernel results so badness can be computed afterwards
      # (it could instead be computed here to save memory; see the sketch
      # after this function)
      trace_fn = function(current_state, kernel_results) {
        kernel_results
      }
    )
    # return the parameter states and the kernel results
    list(
      all_states = results$all_states,
      kernel_results = results$trace
    )
  }
  # compile it into a concrete function with a fixed input signature (so
  # different burst lengths reuse the same traced graph) and return it
  sample <- tf_function(
    sample_raw,
    list(
      as_tensorspec(warm_sampler$current_state),
      tf$TensorSpec(shape = c(), dtype = tf$int32)
    )
  )
  sample
}
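# a minimal sketch of the memory-saving idea flagged in the trace_fn comment
# above: trace only a logical marking non-finite acceptance ratios instead of
# the full kernel results. The triple inner_results nesting assumes this
# particular kernel composition; a recursive in-graph lookup would be more
# robust
trace_badness <- function(current_state, kernel_results) {
  log_ratio <-
    kernel_results$inner_results$inner_results$inner_results$log_accept_ratio
  !tf$math$is_finite(log_ratio)
}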
# create the kernel
kernel <- create_kernel(log_prob = log_prob, n_adapt = n_warmup)
# NOTE: could skip setting the adaptation step here, and overwrite it in the
# kernel later
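# a sketch of that alternative (untested; assumes num_adaptation_steps can be
# reassigned on the underlying Python kernel object):
# kernel <- create_kernel(log_prob = log_prob)
# kernel$num_adaptation_steps <- as.integer(n_warmup)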
# adapt and warm up
cat("Tuning sampler and burning in samples...")
#> Tuning sampler and burning in samples...
system.time(
warm_results <- warm_up_sampler(kernel, init)
)
#> user system elapsed
#> 2.413 0.059 2.392
cat("...done.")
#> ...done.
# use the warmed-up kernel and state to compile the sampling function
sample <- make_sampler_function(warm_results)
# now repeatedly run the sampler in bursts to generate the retained draws
burst_size <- 50L
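# number of bursts (this assumes n_samples is a multiple of burst_size)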
n_bursts <- n_samples / burst_size
current_state <- warm_results$current_state
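# array to hold the retained draws: iterations x chains x parameters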
trace <- array(NA, dim = c(n_samples, dim(current_state)))
# track numerical rejections
n_bad <- 0
print("Sampling parameters")
#> [1] "Sampling parameters"
for (burst in seq_len(n_bursts)) {
  burst_result <- sample(
    current_state = current_state,
    n_samples = burst_size
  )
  # store the draws from this burst in the right slice of the trace array
  burst_idx <- (burst - 1) * burst_size + seq_len(burst_size)
  trace[burst_idx, , ] <- as.array(burst_result$all_states)
  # overwrite the current state
  current_state <- get_last_state(burst_result$all_states)
  # accumulate and report on the badness
  new_badness <- sum(bad_steps(burst_result$kernel_results))
  n_bad <- n_bad + new_badness
  n_evaluations <- burst * burst_size * n_chains
  perc_badness <- round(100 * n_bad / n_evaluations)
  # report on progress
  print(sprintf(
    "burst %i of %i (%i%% bad)",
    burst,
    n_bursts,
    perc_badness
  ))
}
#> [1] "burst 1 of 20 (0% bad)"
#> [1] "burst 2 of 20 (0% bad)"
#> [1] "burst 3 of 20 (0% bad)"
#> [1] "burst 4 of 20 (0% bad)"
#> [1] "burst 5 of 20 (0% bad)"
#> [1] "burst 6 of 20 (0% bad)"
#> [1] "burst 7 of 20 (0% bad)"
#> [1] "burst 8 of 20 (0% bad)"
#> [1] "burst 9 of 20 (0% bad)"
#> [1] "burst 10 of 20 (0% bad)"
#> [1] "burst 11 of 20 (0% bad)"
#> [1] "burst 12 of 20 (0% bad)"
#> [1] "burst 13 of 20 (0% bad)"
#> [1] "burst 14 of 20 (0% bad)"
#> [1] "burst 15 of 20 (0% bad)"
#> [1] "burst 16 of 20 (0% bad)"
#> [1] "burst 17 of 20 (0% bad)"
#> [1] "burst 18 of 20 (0% bad)"
#> [1] "burst 19 of 20 (0% bad)"
#> [1] "burst 20 of 20 (0% bad)"
# plot the trace
chains <- apply(trace, 2, coda::as.mcmc, simplify = FALSE)
draws <- coda::as.mcmc.list(chains)
par(mfrow = c(2, 4))
plot(draws, auto.layout = FALSE)
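# optional numerical checks (a sketch using coda's standard diagnostics):
# potential scale reduction factors and effective sample sizes across chains
coda::gelman.diag(draws)
coda::effectiveSize(draws)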
Created on 2025-03-06 with reprex v2.1.1
Session info
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#> setting value
#> version R version 4.4.2 (2024-10-31)
#> os macOS Sequoia 15.1
#> system aarch64, darwin20
#> ui X11
#> language (EN)
#> collate en_US.UTF-8
#> ctype en_US.UTF-8
#> tz Australia/Hobart
#> date 2025-03-06
#> pandoc 3.2.1 @ /opt/homebrew/bin/ (via rmarkdown)
#>
#> ─ Packages ───────────────────────────────────────────────────────────────────
#> package * version date (UTC) lib source
#> backports 1.5.0 2024-05-23 [1] CRAN (R 4.4.0)
#> base64enc 0.1-3 2015-07-28 [1] CRAN (R 4.4.0)
#> callr 3.7.6 2024-03-25 [1] CRAN (R 4.4.0)
#> cli 3.6.4 2025-02-13 [1] CRAN (R 4.4.1)
#> coda 0.19-4.1 2024-01-31 [1] CRAN (R 4.4.0)
#> codetools 0.2-20 2024-03-31 [2] CRAN (R 4.4.2)
#> crayon 1.5.3 2024-06-20 [1] CRAN (R 4.4.0)
#> curl 6.2.0 2025-01-23 [1] CRAN (R 4.4.1)
#> digest 0.6.37 2024-08-19 [1] CRAN (R 4.4.1)
#> evaluate 1.0.1 2024-10-10 [1] CRAN (R 4.4.1)
#> fastmap 1.2.0 2024-05-15 [1] CRAN (R 4.4.0)
#> fs 1.6.5 2024-10-30 [1] CRAN (R 4.4.1)
#> future 1.34.0 2024-07-29 [1] CRAN (R 4.4.0)
#> globals 0.16.3 2024-03-08 [1] CRAN (R 4.4.0)
#> glue 1.8.0 2024-09-30 [1] CRAN (R 4.4.1)
#> greta * 0.5.0.9000 2025-01-24 [1] local
#> hms 1.1.3 2023-03-21 [1] CRAN (R 4.4.0)
#> htmltools 0.5.8.1 2024-04-04 [1] CRAN (R 4.4.0)
#> jsonlite 1.8.9 2024-09-20 [1] CRAN (R 4.4.1)
#> knitr 1.49 2024-11-08 [1] CRAN (R 4.4.1)
#> lattice 0.22-6 2024-03-20 [2] CRAN (R 4.4.2)
#> lifecycle 1.0.4 2023-11-07 [1] CRAN (R 4.4.0)
#> listenv 0.9.1 2024-01-29 [1] CRAN (R 4.4.0)
#> magrittr 2.0.3 2022-03-30 [1] CRAN (R 4.4.0)
#> Matrix 1.7-1 2024-10-18 [2] CRAN (R 4.4.2)
#> parallelly 1.41.0 2024-12-18 [1] CRAN (R 4.4.1)
#> pkgconfig 2.0.3 2019-09-22 [1] CRAN (R 4.4.0)
#> png 0.1-8 2022-11-29 [1] CRAN (R 4.4.0)
#> prettyunits 1.2.0 2023-09-24 [1] CRAN (R 4.4.0)
#> processx 3.8.5 2025-01-08 [1] CRAN (R 4.4.1)
#> progress 1.2.3 2023-12-06 [1] CRAN (R 4.4.0)
#> ps 1.8.1 2024-10-28 [1] CRAN (R 4.4.1)
#> R6 2.6.1 2025-02-15 [1] CRAN (R 4.4.1)
#> Rcpp 1.0.14 2025-01-12 [1] CRAN (R 4.4.1)
#> reprex 2.1.1 2024-07-06 [1] CRAN (R 4.4.0)
#> reticulate 1.40.0 2024-11-15 [1] CRAN (R 4.4.1)
#> rlang 1.1.5 2025-01-17 [1] CRAN (R 4.4.1)
#> rmarkdown 2.29 2024-11-04 [1] CRAN (R 4.4.1)
#> rstudioapi 0.17.1 2024-10-22 [1] CRAN (R 4.4.1)
#> sessioninfo 1.2.2 2021-12-06 [1] CRAN (R 4.4.0)
#> tensorflow * 2.16.0 2024-04-15 [1] CRAN (R 4.4.0)
#> tfautograph 0.3.2 2021-09-17 [1] CRAN (R 4.4.0)
#> tfruns 1.5.3 2024-04-19 [1] CRAN (R 4.4.0)
#> vctrs 0.6.5 2023-12-01 [1] CRAN (R 4.4.0)
#> whisker 0.4.1 2022-12-05 [1] CRAN (R 4.4.0)
#> withr 3.0.2 2024-10-28 [1] CRAN (R 4.4.1)
#> xfun 0.50.5 2025-01-15 [1] Github (yihui/xfun@116d689)
#> xml2 1.3.6 2023-12-04 [1] CRAN (R 4.4.0)
#> yaml 2.3.10 2024-07-26 [1] CRAN (R 4.4.0)
#>
#> [1] /Users/nick/Library/R/arm64/4.4/library
#> [2] /Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library
#>
#> ─ Python configuration ───────────────────────────────────────────────────────
#> python: /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/bin/python
#> libpython: /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/libpython3.10.dylib
#> pythonhome: /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2:/Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2
#> version: 3.10.14 | packaged by conda-forge | (main, Mar 20 2024, 12:51:49) [Clang 16.0.6 ]
#> numpy: /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/python3.10/site-packages/numpy
#> numpy_version: 1.26.4
#> tensorflow: /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/python3.10/site-packages/tensorflow
#>
#> NOTE: Python version was forced by use_python() function
#>
#> ──────────────────────────────────────────────────────────────────────────────