Last active: March 28, 2024 20:03
-
-
Save torfjelde/37be5a672d29e473983b8e82b45c2e41 to your computer and use it in GitHub Desktop.
Converting output from `generated_quantities(model, chain)` into a `MCMCChains.Chains` object
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
julia> using Turing
julia> include("utils.jl")
julia> @model function demo(xs)
           s ~ InverseGamma(2, 3)
           m ~ Normal(0, √s)
           for i in eachindex(xs)
               xs[i] ~ Normal(m, √s)
           end
           return (m = m, s = s)
       end
demo (generic function with 1 method)
julia> xs = randn(100) .+ 1;
julia> m = demo(xs);
julia> chain = sample(m, MH(), MCMCThreads(), 100, 2);
┌ Warning: Only a single thread available: MCMC chains are not sampled in parallel
└ @ AbstractMCMC ~/.julia/packages/AbstractMCMC/iOkTf/src/sample.jl:197
Sampling (1 threads) 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| Time: 0:00:00
julia> res = DynamicPPL.generated_quantities(m, chain);
julia> size(res)
(100, 2)
julia> Chains(res)
Chains MCMC chain (100×2×2 Array{Float64,3}):
Iterations = 1:100
Thinning interval = 1
Chains = 1, 2
Samples per chain = 100
parameters = m, s
Summary Statistics
parameters mean std naive_se mcse ess rhat
Symbol Float64 Float64 Float64 Float64 Float64 Float64
m 0.8961 0.2833 0.0200 0.0538 5.4693 1.2846
s 1.4348 1.4394 0.1018 0.1462 62.8920 1.0039
Quantiles
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
Symbol Float64 Float64 Float64 Float64 Float64
m -0.0137 0.7715 0.9439 1.0597 1.1702
s 0.8233 1.0051 1.3981 1.4125 1.6766
julia> # Or creating a chain by hand:
       res = [(x1 = randn(), x2 = randn(2), x3 = randn(2, 2)) for i = 1:100];
julia> Chains(res)
Chains MCMC chain (100×7×1 Array{Float64,3}):
Iterations = 1:100
Thinning interval = 1
Chains = 1
Samples per chain = 100
parameters = x1, x2[1], x2[2], x3[1,1], x3[2,1], x3[1,2], x3[2,2]
Summary Statistics
parameters mean std naive_se mcse ess rhat
Symbol Float64 Float64 Float64 Missing Float64 Float64
x1 0.1048 1.0150 0.1015 missing 83.5494 0.9905
x2[1] -0.0370 1.0645 0.1064 missing 52.3812 0.9975
x2[2] -0.0174 1.1423 0.1142 missing 109.7089 0.9903
x3[1,1] -0.1262 0.9917 0.0992 missing 215.2624 0.9927
x3[2,1] -0.1030 0.8943 0.0894 missing 115.0757 0.9949
x3[1,2] 0.1921 0.9276 0.0928 missing 126.8230 1.0107
x3[2,2] 0.0725 1.0082 0.1008 missing 92.7707 0.9946
Quantiles
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
Symbol Float64 Float64 Float64 Float64 Float64
x1 -1.7972 -0.6828 0.2499 0.7517 1.9684
x2[1] -2.1026 -0.8716 -0.0407 0.7425 2.3251
x2[2] -2.2431 -0.8199 -0.0915 0.7274 2.1686
x3[1,1] -1.9870 -0.8212 -0.2149 0.5845 1.7913
x3[2,1] -1.6532 -0.8203 -0.1102 0.5373 1.5690
x3[1,2] -1.5999 -0.3648 0.2584 0.8105 1.8125
x3[2,2] -2.0730 -0.6014 0.0374 0.7459 2.0825
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Entry point: begin the recursive name generation with an empty prefix. Every
# other method receives the prefix accumulated so far as `vn_str`.
generate_names(val) = generate_names("", val)
# Base case: a scalar contributes exactly one name, the prefix itself.
generate_names(vn_str::String, val::Real) = String[vn_str]
"""
    generate_names(vn_str::String, val::NamedTuple)

Return a flat `Vector{String}` with one name per scalar leaf of `val`, visiting
fields in `keys(val)` order. A field `k` under prefix `p` is named `p.k`, or just
`k` when the prefix is empty. Array-valued fields get their index suffixes from
the other `generate_names` methods.
"""
function generate_names(vn_str::String, val::NamedTuple)
    names = String[]
    for k in keys(val)
        # Separate nested field names with "." (like a property access) rather
        # than gluing them onto the prefix, which produced ambiguous names ("xa").
        prefix = isempty(vn_str) ? string(k) : string(vn_str, ".", k)
        # Append instead of `map`: `map` over `keys` returns a Tuple of vectors,
        # which callers cannot `vcat` into a single list of names.
        append!(names, generate_names(prefix, val[k]))
    end
    return names
end
# Name every element of a real-valued array `vn_str[i,j,...]`. Indices are
# enumerated in column-major (`CartesianIndices`) order, the same order `flatten`
# walks an array, so names and values line up.
function generate_names(vn_str::String, val::AbstractArray{<:Real})
    index_strings = (join(Tuple(idx), ",") for idx in vec(CartesianIndices(val)))
    return [string(vn_str, "[", s, "]") for s in index_strings]
end
# Name the elements of an array of arrays `vn_str[outer][inner]`, e.g. "x[1][2]".
# Outer indices are visited in column-major (`CartesianIndices`) order. Each inner
# array is named recursively with an empty prefix and appended to its outer index.
function generate_names(vn_str::String, val::AbstractArray{<:AbstractArray})
    results = String[]
    for idx in CartesianIndices(val)
        outer = join(Tuple(idx), ",")
        # Recurse into `generate_names`; the original called an undefined `f`,
        # so this method always threw an UndefVarError.
        for inner in generate_names("", val[idx])
            push!(results, "$vn_str[$outer]$inner")
        end
    end
    return results
end
# A scalar flattens to a one-element vector.
flatten(val::Real) = [val;]

"""
    flatten(val::NamedTuple)

Concatenate the flattened values of all fields of `val`, in `keys(val)` order.
This is the same order `generate_names` walks a NamedTuple, so NamedTuple-valued
quantities now have values as well as names; previously this threw a
`MethodError`. An empty NamedTuple flattens to `Float64[]`.
"""
function flatten(val::NamedTuple)
    isempty(val) && return Float64[]
    return reduce(vcat, map(flatten, values(val)))
end
"""
    flatten(val::AbstractArray{<:Real})

Return the elements of `val` as a new `Vector`, in column-major
(`CartesianIndices`) order. This matches the order `generate_names` uses for
array element names.
"""
function flatten(val::AbstractArray{<:Real})
    # A single copy. The previous `mapreduce(vcat, ...)` re-allocated the partial
    # result for every element, threw on an empty array, and returned a bare
    # scalar (not a vector) for a one-element array.
    out = vec(collect(val))
    # Keep `vcat`'s numeric promotion for abstractly-typed containers such as
    # `Real[1, 2.5]`, so those still flatten to a `Vector{Float64}`.
    if !isconcretetype(eltype(out)) && !isempty(out)
        return convert(Vector{mapreduce(typeof, promote_type, out)}, out)
    end
    return out
end
"""
    flatten(val::AbstractArray{<:AbstractArray})

Flatten an array of arrays into one `Vector`. Outer elements are visited in
column-major (`CartesianIndices`) order, matching `generate_names`, and each inner
array is flattened recursively. An empty outer array yields `Float64[]`.
"""
function flatten(val::AbstractArray{<:AbstractArray})
    parts = [flatten(val[idx]) for idx in vec(CartesianIndices(val))]
    # Reducing over an empty collection with `vcat` throws (there is no identity
    # element), so return an empty vector explicitly.
    isempty(parts) && return Float64[]
    # One `reduce(vcat, ...)` over the collected parts can use Base's optimized
    # method for a vector of vectors, instead of re-concatenating
    # element-by-element as `mapreduce(vcat, ...)` did.
    return reduce(vcat, parts)
end
"""
    vectup2chainargs(ts::AbstractVector{<:NamedTuple})

Convert a vector of NamedTuples (one per sample, treated as a single chain) into
the `(values, names)` pair that `MCMCChains.Chains` is called with. `values` is a
3-d array of size `(num_samples, num_flattened_values, 1)`; `names` is the flat
list of variable names. Field names and order are taken from the first element.
"""
function vectup2chainargs(ts::AbstractVector{<:NamedTuple})
    template = first(ts)
    fields = keys(template)
    names = mapreduce(vcat, fields) do field
        generate_names(string(field), template[field])
    end
    # One flat value vector per sample, with fields concatenated in `fields` order.
    rows = [mapreduce(field -> flatten(sample[field]), vcat, fields) for sample in ts]
    # `hcat` stacks samples as columns; transpose so each row is one sample.
    table = permutedims(reduce(hcat, rows))
    # Add a singleton third axis: a single chain.
    return reshape(table, size(table, 1), size(table, 2), 1), names
end
"""
    vectup2chainargs(ts::AbstractMatrix{<:NamedTuple})

Convert a `(num_samples, num_chains)` matrix of NamedTuples into the
`(values, names)` pair for `MCMCChains.Chains`. Each column is converted as one
chain by the vector method, and the results are stacked along the third axis.

Throws `ArgumentError` if `ts` has no columns, or if the chains do not all
produce the same variable names in the same order.
"""
function vectup2chainargs(ts::AbstractMatrix{<:NamedTuple})
    size(ts, 2) > 0 || throw(ArgumentError("expected at least one chain (column), got none"))
    per_chain = [vectup2chainargs(column) for column in eachcol(ts)]
    vals = first.(per_chain)
    vns = last.(per_chain)
    names = first(vns)
    # The stacked arrays are matched by column position, so every chain must yield
    # exactly the same names in the same order. The previous check compared each
    # chain with the union of all names, which can never fail. It also used
    # `@assert`, which is not meant for input validation.
    for (chain_idx, chain_names) in enumerate(vns)
        chain_names == names || throw(ArgumentError(
            "variable names differ between chains: chain 1 has $(names), " *
            "chain $(chain_idx) has $(chain_names)"))
    end
    return cat(vals...; dims = 3), names
end
# Build a `Chains` object from an array of NamedTuples, e.g. the output of
# `generated_quantities(model, chain)`. The conversion is delegated to
# `vectup2chainargs`, which in this file has methods for a vector (one chain)
# and a matrix (samples × chains); other array ranks presumably hit a MethodError.
# NOTE(review): this adds a method to `MCMCChains.Chains` for `NamedTuple` arrays,
# neither of which is owned here (type piracy). That is acceptable in a script,
# but the method would belong in MCMCChains/Turing proper.
function MCMCChains.Chains(ts::AbstractArray{<:NamedTuple})
    vals, names = vectup2chainargs(ts)
    return MCMCChains.Chains(vals, names)
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.
Hey Tor, this is great! Just curious if there's a reason why this isn't included in Turing.jl?