Skip to content

Instantly share code, notes, and snippets.

@njtierney
Created March 6, 2025 01:47
Show Gist options
  • Save njtierney/5a2145b910ac6e354107e94d5352bf24 to your computer and use it in GitHub Desktop.
Save njtierney/5a2145b910ac6e354107e94d5352bf24 to your computer and use it in GitHub Desktop.
# implement bursted sampling, but unbursted warmup, using TFP adaptation

# use greta to get log prob for a knotty model
library(greta)
#> 
#> Attaching package: 'greta'
#> The following objects are masked from 'package:stats':
#> 
#>     binomial, cov2cor, poisson
#> The following objects are masked from 'package:base':
#> 
#>     %*%, apply, backsolve, beta, chol2inv, colMeans, colSums, diag,
#>     eigen, forwardsolve, gamma, identity, rowMeans, rowSums, sweep,
#>     tapply
# a toy target: four independent normals with standard deviations spanning
# four orders of magnitude (presumably chosen to stress the sampler's
# adaptation -- the "knotty model" above)
x <- normal(mean = 0, sd = c(0.1, 1, 10, 100))
#> ℹ Initialising python and checking dependencies, this may take a moment.
#> ✔ Initialising python and checking dependencies ... done!
#> 
m <- model(x)

# greta's internal log-density function over the model's free state
f <- m$dag$generate_log_prob_function()
# Target log density for TFP: evaluate greta's log-prob function `f` at
# `free_state` and keep only its `adjusted` component (presumably the log
# density including the change-of-variables adjustment -- confirm in greta).
log_prob <- function(free_state) {
  f(free_state)$adjusted
}

# set up sampling scheme
n_chains <- 4
n_samples <- 1000
n_warmup <- 1000

# set initial values: every chain starts at zero on the free-state scale that
# `log_prob` takes, giving an (n_chains x n_params) initial state.
# `n_params` counts the elements of `x` (assumes `x` is the model's only,
# unconstrained, greta array -- TODO confirm for other models); it is named
# `n_params` rather than `dim` so that base::dim() is not masked.
n_params <- prod(dim(x))
init_raw <- array(0, c(n_chains, n_params))
# presumably converts to a float tensor of greta's working dtype -- confirm
init <- greta:::fl(init_raw)

library(tensorflow)
tfp <- greta:::tfp

# Build an adaptive HMC kernel: SNAPER-HMC (which, per the TFP docs, adapts the
# trajectory length / number of leapfrog steps and the mass matrix, but not the
# step size) wrapped in dual-averaging step-size adaptation.
#
# `n_adapt` is the number of steps over which both layers adapt. It defaults to
# 0, so by default neither layer adapts. A kernel intended for warmup must be
# created with `n_adapt` equal to the number of warmup steps, because
# warm_up_sampler() runs exactly `kernel$num_adaptation_steps` iterations.
create_kernel <- function(log_prob, n_adapt = 0) {
  adapt_steps <- as.integer(n_adapt)

  tfp$mcmc$DualAveragingStepSizeAdaptation(
    inner_kernel = tfp$experimental$mcmc$SNAPERHamiltonianMonteCarlo(
      target_log_prob_fn = log_prob,
      step_size = 1,
      num_adaptation_steps = adapt_steps
    ),
    num_adaptation_steps = adapt_steps
  )
}

# Describe `tensor` as a tf$TensorSpec with the same static shape and dtype,
# e.g. for use in an input signature passed to tf_function().
as_tensorspec <- function(tensor) {
  spec_shape <- tensor$get_shape()
  spec_dtype <- tensor$dtype
  tf$TensorSpec(shape = spec_shape, dtype = spec_dtype)
}

# Return the final model parameter state from `all_states` as returned by
# tfp$mcmc$sample_chain, whose first axis indexes iterations.
#
# Errors if `all_states` has no (known) iterations.
get_last_state <- function(all_states) {
  n_iter <- dim(all_states)[1]
  if (length(n_iter) != 1L || is.na(n_iter) || n_iter < 1L) {
    stop("`all_states` has no iterations to take the last state from",
         call. = FALSE)
  }
  # tf$gather()'s third positional parameter is `validate_indices`, so `axis`
  # must be passed by name. The index is 0-based on the Python side and must
  # be an integer.
  tf$gather(all_states, as.integer(n_iter) - 1L, axis = 0L)
}

# Flag MCMC steps whose log acceptance ratio is not finite (NA, NaN or +/-Inf),
# i.e. numerically failed proposals. Returns a logical array with the same
# shape as the extracted log acceptance ratios.
bad_steps <- function(kernel_results) {
  ratios <- recursive_get_log_accept_ratio(kernel_results)
  finite <- is.finite(ratios)
  !finite
}

# Extract the log acceptance ratio from (possibly nested) MCMC kernel results,
# descending through `inner_results` until a `log_accept_ratio` field is found.
# Returns it as an R array; errors if neither field is present at some level.
recursive_get_log_accept_ratio <- function(kernel_results) {
  node <- kernel_results
  repeat {
    fields <- names(node)
    if ("log_accept_ratio" %in% fields) {
      return(as.array(node$log_accept_ratio))
    }
    if (!("inner_results" %in% fields)) {
      stop("non-standard kernel structure")
    }
    node <- node$inner_results
  }
}

# Given an MCMC kernel `kernel` (built with create_kernel() and a positive
# `n_adapt`) and an initial model parameter state `init`, run the kernel for
# its whole adaptation period. This tunes the kernel's parameters while
# simultaneously burning in the parameter state.
#
# Returns a list with:
#   kernel         - the kernel object itself
#   kernel_results - the final (tuned) kernel results after warmup
#   current_state  - the last (burned-in) model parameter state
#
# Errors if the kernel has no adaptation period (e.g. it was created with the
# default n_adapt = 0), since there would then be no warmup state to return.
warm_up_sampler <- function(kernel, init) {

  # the warmup length is the kernel's predetermined adaptation period
  n_adapt <- kernel$num_adaptation_steps

  # zero warmup iterations would leave no last state to extract, so fail
  # early with an actionable message instead of an opaque TensorFlow error
  if (is.null(n_adapt) ||
      (is.numeric(n_adapt) &&
       (length(n_adapt) != 1L || is.na(n_adapt) || n_adapt < 1))) {
    stop(
      "`kernel` has no adaptation period; create it with ",
      "`create_kernel(..., n_adapt = <number of warmup steps>)`",
      call. = FALSE
    )
  }

  # uncompiled warmup run, with `kernel`, `init` and `n_adapt` captured from
  # this call. Only the step counter is traced; the trace itself is not used,
  # the tuned values come from the final kernel results.
  warmup_raw <- function() {
    tfp$mcmc$sample_chain(
      num_results = n_adapt,
      current_state = init,
      kernel = kernel,
      return_final_kernel_results = TRUE,
      trace_fn = function(current_state, kernel_results) {
        kernel_results$step
      }
    )
  }

  # compile into a TF function, then execute it
  warmup <- tf_function(warmup_raw)
  result <- warmup()

  list(
    kernel = kernel,
    kernel_results = result$final_kernel_results,
    current_state = get_last_state(result$all_states)
  )
}

# Given a warmed-up sampler (as returned by warm_up_sampler()), return a
# function `f(current_state, n_samples)` that draws one burst of `n_samples`
# MCMC iterations starting from `current_state`. It returns a list with:
#   all_states           - the per-iteration parameter states
#   kernel_results       - the per-iteration kernel results (e.g. for
#                          bad_steps())
#   final_kernel_results - the kernel results after the last iteration
#
# The kernel results fed into a burst must correspond to the state the burst
# starts from. In TFP's HMC implementation they cache the target log density
# and its gradient at the current state, which the next step reuses. Feeding
# the warmup's final kernel results into every burst would pair later bursts'
# starting states with stale values. The returned function therefore remembers
# each burst's final kernel results and passes them to the next burst. It
# assumes sequential use: each call starts from the last state of the previous
# call, and the first call starts from `warm_sampler$current_state`.
make_sampler_function <- function(warm_sampler) {

  # uncompiled burst sampler; the previous kernel results are an explicit
  # argument so each burst can resume from the matching kernel state
  sample_raw <- function(current_state, n_samples, previous_kernel_results) {
    results <- tfp$mcmc$sample_chain(
      # how many iterations
      num_results = n_samples,
      # where to start from
      current_state = current_state,
      # kernel (tuning has finished, so the tuned values are reused)
      kernel = warm_sampler$kernel,
      # kernel state matching `current_state`
      previous_kernel_results = previous_kernel_results,
      # needed to chain into the next burst
      return_final_kernel_results = TRUE,
      # trace the full kernel results (could compute badness here instead to
      # save memory)
      trace_fn = function(current_state, kernel_results) {
        kernel_results
      }
    )
    list(
      all_states = results$all_states,
      kernel_results = results$trace,
      final_kernel_results = results$final_kernel_results
    )
  }

  # compile once; argument structures are identical across bursts
  sample_compiled <- tf_function(sample_raw)

  # mutable holder for the kernel results to start the next burst from
  carry <- new.env(parent = emptyenv())
  carry$kernel_results <- warm_sampler$kernel_results

  function(current_state, n_samples) {
    result <- sample_compiled(
      current_state,
      # pass as an int32 tensor so differing burst sizes don't force retracing
      tf$constant(as.integer(n_samples), dtype = tf$int32),
      carry$kernel_results
    )
    carry$kernel_results <- result$final_kernel_results
    result
  }
}

# create the kernel, adapting over the whole warmup period
kernel <- create_kernel(log_prob = log_prob, n_adapt = n_warmup)

# NOTE: could skip setting the adaptation step here, and overwrite it in the
# kernel later

# adapt and warm up
cat("Tuning sampler and burning in samples...\n")
#> Tuning sampler and burning in samples...
system.time(
  warm_results <- warm_up_sampler(kernel, init)
)
#>    user  system elapsed 
#>   2.413   0.059   2.392
cat("...done.\n")
#> ...done.

# compile the burst sampler from the warmed-up kernel (named `sample_burst`
# so that base::sample() is not masked)
sample_burst <- make_sampler_function(warm_results)

# now repeatedly sample the actual results, in bursts
burst_size <- 50L
if (n_samples %% burst_size != 0) {
  stop("`n_samples` must be a multiple of `burst_size`", call. = FALSE)
}
n_bursts <- n_samples %/% burst_size

# storage for the draws, indexed [iteration, chain, parameter]
current_state <- warm_results$current_state
trace_array <- array(NA_real_, dim = c(n_samples, dim(current_state)))

# track numerical rejections
n_bad <- 0

print("Sampling parameters")
#> [1] "Sampling parameters"
for (burst in seq_len(n_bursts)) {
  burst_result <- sample_burst(
    current_state = current_state,
    n_samples = burst_size
  )

  # store the MCMC draws from this burst
  burst_idx <- (burst - 1) * burst_size + seq_len(burst_size)
  trace_array[burst_idx, , ] <- as.array(burst_result$all_states)

  # the next burst starts where this one ended
  current_state <- get_last_state(burst_result$all_states)

  # accumulate and report on the badness
  n_bad <- n_bad + sum(bad_steps(burst_result$kernel_results))
  n_evaluations <- burst * burst_size * n_chains
  perc_badness <- round(100 * n_bad / n_evaluations)

  # report on progress
  print(sprintf("burst %i of %i (%i%% bad)", burst, n_bursts, perc_badness))
}
#> [1] "burst 1 of 20 (0% bad)"
#> [1] "burst 2 of 20 (0% bad)"
#> [1] "burst 3 of 20 (0% bad)"
#> [1] "burst 4 of 20 (0% bad)"
#> [1] "burst 5 of 20 (0% bad)"
#> [1] "burst 6 of 20 (0% bad)"
#> [1] "burst 7 of 20 (0% bad)"
#> [1] "burst 8 of 20 (0% bad)"
#> [1] "burst 9 of 20 (0% bad)"
#> [1] "burst 10 of 20 (0% bad)"
#> [1] "burst 11 of 20 (0% bad)"
#> [1] "burst 12 of 20 (0% bad)"
#> [1] "burst 13 of 20 (0% bad)"
#> [1] "burst 14 of 20 (0% bad)"
#> [1] "burst 15 of 20 (0% bad)"
#> [1] "burst 16 of 20 (0% bad)"
#> [1] "burst 17 of 20 (0% bad)"
#> [1] "burst 18 of 20 (0% bad)"
#> [1] "burst 19 of 20 (0% bad)"
#> [1] "burst 20 of 20 (0% bad)"

# plot the trace: one coda::mcmc object per chain (slicing over the chain axis)
chain_list <- apply(trace_array, 2, coda::as.mcmc, simplify = FALSE)
draws <- coda::as.mcmc.list(chain_list)
old_par <- par(mfrow = c(2, 4))
plot(draws, auto.layout = FALSE)
par(old_par)

Created on 2025-03-06 with reprex v2.1.1

Session info

# record the R and Python environments this reprex was run in
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.4.2 (2024-10-31)
#>  os       macOS Sequoia 15.1
#>  system   aarch64, darwin20
#>  ui       X11
#>  language (EN)
#>  collate  en_US.UTF-8
#>  ctype    en_US.UTF-8
#>  tz       Australia/Hobart
#>  date     2025-03-06
#>  pandoc   3.2.1 @ /opt/homebrew/bin/ (via rmarkdown)
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package     * version    date (UTC) lib source
#>  backports     1.5.0      2024-05-23 [1] CRAN (R 4.4.0)
#>  base64enc     0.1-3      2015-07-28 [1] CRAN (R 4.4.0)
#>  callr         3.7.6      2024-03-25 [1] CRAN (R 4.4.0)
#>  cli           3.6.4      2025-02-13 [1] CRAN (R 4.4.1)
#>  coda          0.19-4.1   2024-01-31 [1] CRAN (R 4.4.0)
#>  codetools     0.2-20     2024-03-31 [2] CRAN (R 4.4.2)
#>  crayon        1.5.3      2024-06-20 [1] CRAN (R 4.4.0)
#>  curl          6.2.0      2025-01-23 [1] CRAN (R 4.4.1)
#>  digest        0.6.37     2024-08-19 [1] CRAN (R 4.4.1)
#>  evaluate      1.0.1      2024-10-10 [1] CRAN (R 4.4.1)
#>  fastmap       1.2.0      2024-05-15 [1] CRAN (R 4.4.0)
#>  fs            1.6.5      2024-10-30 [1] CRAN (R 4.4.1)
#>  future        1.34.0     2024-07-29 [1] CRAN (R 4.4.0)
#>  globals       0.16.3     2024-03-08 [1] CRAN (R 4.4.0)
#>  glue          1.8.0      2024-09-30 [1] CRAN (R 4.4.1)
#>  greta       * 0.5.0.9000 2025-01-24 [1] local
#>  hms           1.1.3      2023-03-21 [1] CRAN (R 4.4.0)
#>  htmltools     0.5.8.1    2024-04-04 [1] CRAN (R 4.4.0)
#>  jsonlite      1.8.9      2024-09-20 [1] CRAN (R 4.4.1)
#>  knitr         1.49       2024-11-08 [1] CRAN (R 4.4.1)
#>  lattice       0.22-6     2024-03-20 [2] CRAN (R 4.4.2)
#>  lifecycle     1.0.4      2023-11-07 [1] CRAN (R 4.4.0)
#>  listenv       0.9.1      2024-01-29 [1] CRAN (R 4.4.0)
#>  magrittr      2.0.3      2022-03-30 [1] CRAN (R 4.4.0)
#>  Matrix        1.7-1      2024-10-18 [2] CRAN (R 4.4.2)
#>  parallelly    1.41.0     2024-12-18 [1] CRAN (R 4.4.1)
#>  pkgconfig     2.0.3      2019-09-22 [1] CRAN (R 4.4.0)
#>  png           0.1-8      2022-11-29 [1] CRAN (R 4.4.0)
#>  prettyunits   1.2.0      2023-09-24 [1] CRAN (R 4.4.0)
#>  processx      3.8.5      2025-01-08 [1] CRAN (R 4.4.1)
#>  progress      1.2.3      2023-12-06 [1] CRAN (R 4.4.0)
#>  ps            1.8.1      2024-10-28 [1] CRAN (R 4.4.1)
#>  R6            2.6.1      2025-02-15 [1] CRAN (R 4.4.1)
#>  Rcpp          1.0.14     2025-01-12 [1] CRAN (R 4.4.1)
#>  reprex        2.1.1      2024-07-06 [1] CRAN (R 4.4.0)
#>  reticulate    1.40.0     2024-11-15 [1] CRAN (R 4.4.1)
#>  rlang         1.1.5      2025-01-17 [1] CRAN (R 4.4.1)
#>  rmarkdown     2.29       2024-11-04 [1] CRAN (R 4.4.1)
#>  rstudioapi    0.17.1     2024-10-22 [1] CRAN (R 4.4.1)
#>  sessioninfo   1.2.2      2021-12-06 [1] CRAN (R 4.4.0)
#>  tensorflow  * 2.16.0     2024-04-15 [1] CRAN (R 4.4.0)
#>  tfautograph   0.3.2      2021-09-17 [1] CRAN (R 4.4.0)
#>  tfruns        1.5.3      2024-04-19 [1] CRAN (R 4.4.0)
#>  vctrs         0.6.5      2023-12-01 [1] CRAN (R 4.4.0)
#>  whisker       0.4.1      2022-12-05 [1] CRAN (R 4.4.0)
#>  withr         3.0.2      2024-10-28 [1] CRAN (R 4.4.1)
#>  xfun          0.50.5     2025-01-15 [1] Github (yihui/xfun@116d689)
#>  xml2          1.3.6      2023-12-04 [1] CRAN (R 4.4.0)
#>  yaml          2.3.10     2024-07-26 [1] CRAN (R 4.4.0)
#> 
#>  [1] /Users/nick/Library/R/arm64/4.4/library
#>  [2] /Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library
#> 
#> ─ Python configuration ───────────────────────────────────────────────────────
#>  python:         /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/bin/python
#>  libpython:      /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/libpython3.10.dylib
#>  pythonhome:     /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2:/Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2
#>  version:        3.10.14 | packaged by conda-forge | (main, Mar 20 2024, 12:51:49) [Clang 16.0.6 ]
#>  numpy:          /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/python3.10/site-packages/numpy
#>  numpy_version:  1.26.4
#>  tensorflow:     /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/python3.10/site-packages/tensorflow
#>  
#>  NOTE: Python version was forced by use_python() function
#> 
#> ──────────────────────────────────────────────────────────────────────────────
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment