Last active
October 19, 2024 23:02
-
-
Save CarloLucibello/c3f3196f3ed89bbc0f296151f32dba0e to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using LinearAlgebra, Random, Statistics
using SpecialFunctions: erf, erfc, erfcinv, erfinv
using Primes: Primes
using BenchmarkTools
using MvNormalCDF: MvNormalCDF
using Test
# Implements the algorithm of https://www.math.wsu.edu/faculty/genz/papers/mvn.pdf
# with the addition of quasi-Monte Carlo point sampling,
# as discussed in Genz's book and implemented in MvNormalCDF.jl
# Standard normal CDF. Written with `erfc` instead of `1 + erf` so that lower-tail
# probabilities keep their relative precision rather than cancelling to exactly zero
# (e.g. for x = -10 the `1 + erf` form returns 0.0).
ϕ(x) = 0.5 * erfc(-x / √2)

# Standard normal quantile function (inverse of `ϕ`). Using `erfcinv(2y)` avoids forming
# `2y - 1`, which rounds to -1 (and hence yields -Inf) for very small `y`.
ϕinv(y) = -√2 * erfcinv(2 * y)
"""
    KahanSum(val)
    KahanSum(val, c)

Immutable accumulator for Kahan (compensated) summation.

`val` is the running total and `c` is the correction term, i.e. the rounding error
committed by the most recent addition. `k + x` returns a new `KahanSum` in which `x`
has been added to the total with that error compensated.
Both `val` and `x` may be any `Real`; they are converted to `Float64`.
"""
struct KahanSum
    val::Float64    # the running total
    c::Float64      # the correction term
end

# Start a new sum at `val` with no accumulated rounding error.
KahanSum(val::Real) = KahanSum(Float64(val), 0.0)

function Base.:+(k::KahanSum, x::Real)
    y = Float64(x) - k.c    # re-inject the error lost in the previous addition
    t = k.val + y           # new total; low-order bits of `y` may be lost here
    c = (t - k.val) - y     # recover what was lost (algebraically zero)
    return KahanSum(t, c)
end
""" | |
∫D(f, Σ, a, b; [rtol, atol, maxevals, warntol]) | |
∫D(f, a, b; [rtol, atol, maxevals, warntol]) | |
∫D(f, n; [rtol, atol, maxevals, warntol]) | |
Compute the integral of a function `f` weighted by a multivariate normal distribution | |
with covariance matrix `Σ`. The integral is computed over the hyper-rectangle defined by | |
the lower bounds `a` and upper bounds `b`. The function uses quasi-monte carlo sampling | |
to estimate the integral. | |
The iteration are stopped early if the relative error is the estimated error `err` | |
is `err < max(rtol * abs(value), atol)`. In any case, no more than `maxevals` samples are used. | |
If `Σ` is not given, it is assumed to be the identity matrix. | |
If `a` and `b` are not given, standard Gaussian integration in dimension `n` is performed. | |
Returns a tuple `(value, error)` with the estimated value of the integral and the estimated error. | |
# Arguments | |
- `f::Function`: function to integrate. The function should accept a vector of length `d` and return a scalar. | |
- `Σ::AbstractMatrix`: d x d covariance matrix for the multivariate normal distribution. | |
- `a::AbstractVector`: lower bounds for the integration domain. A vector of length `d`. | |
- `b::AbstractVector`: upper bounds for the integration domain. A vector of length `d`. | |
- `n::Int`: dimension of the integration space. | |
- `rtol::Float64`: relative tolerance for the integral. Default is `1e-6`. | |
- `atol::Float64`: absolute tolerance for the integral. Default is `0.0`. | |
- `maxevals::Int`: maximum number of samples to use. Default is `10^7 * length(a)`. | |
- `warntol::Bool`: if true, a warning is printed if the required tolerance is not reached. Default is `true`. | |
""" | |
function ∫D( | |
g::F, # function to integrate | |
Σ::AbstractMatrix, # covariance matrix | |
a::AbstractVector, # lower bounds | |
b::AbstractVector; # upper bounds | |
rtol::Float64 = 1e-5, # relative tolerance | |
atol::Float64 = 0., # absolute tolerance | |
maxevals::Int = 10^7 * length(a), # max number of MC samples | |
warntol::Bool = true, | |
) where F | |
α = 3.0 # monte carlo confidence level | |
dim = length(a) | |
@assert rtol > 0 || atol > 0 | |
@assert length(a) == length(b) | |
@assert size(Σ) == (dim, dim) | |
@assert all(a .<= b) | |
if any(a .== b) | |
return 0.0, 0.0 | |
end | |
n_outer = 100 # num outer loops | |
n_inner = max(maxevals ÷ n_outer, 1) # num inner loops | |
intsum = 0.0 | |
varsum = 0.0 | |
C = cholesky(Σ).U # use upper instead of lower for friendly indexing | |
# preallocate | |
y, w, Δ, x = zeros(dim), zeros(dim), zeros(dim), zeros(dim) | |
primes = Primes.primes(Int(floor(5 * dim * log(dim + 1) / 4))) # Richtmyer generators | |
q = sqrt.(primes[1:dim]) | |
d1 = ϕ(a[1] / C[1,1]) | |
e1 = ϕ(b[1] / C[1,1]) | |
f1 = e1 - d1 | |
err = 0.0 | |
value = 0.0 | |
for i in 1:n_outer | |
rand!(Δ) | |
intsum = KahanSum(0.0) | |
@inbounds for j in 1:n_inner | |
@. w = abs(2*((j * q + Δ) % 1) - 1) | |
for k=1:dim # without this we have numerical issues | |
if w[k] == 0 | |
w[k] = 1e-10 | |
elseif w[k] == 1 | |
w[k] = 1 - 1e-10 | |
end | |
end | |
d, e, f = d1, e1, f1 | |
x[1] = 0 | |
for k in 2:dim | |
y[k-1] = ϕinv(d + w[k-1] * (e - d)) | |
cy = 0.0 | |
for j=1:k-1 | |
cy += C[j,k] * y[j] | |
end | |
d = ϕ((a[k] - cy) / C[k,k]) | |
e = ϕ((b[k] - cy) / C[k,k]) | |
f = (e - d) * f | |
# last step for the computation of x[i] | |
x[k] = cy | |
end | |
# last iteration for the computation of x[i] | |
y[dim] = ϕinv(d + w[dim] * (e - d)) | |
for k=1:dim | |
x[k] += C[k,k] * y[k] | |
end | |
gx = g(x) | |
intsum += (f * gx - intsum.val) / j | |
@assert isfinite(intsum.val) | |
end | |
δ = (intsum.val - value) / i | |
value += δ | |
varsum = (i-2) * varsum / i + δ^2 | |
err = α * √varsum | |
if abs(err) < atol || abs(err) < rtol * abs(value) | |
break | |
end | |
end | |
if warntol && !(abs(err) < atol || abs(err) < rtol * abs(value)) | |
@warn("Required tolerance atol=$atol rtol=$rtol not reached. Current error: $err. Consider increasing maxevals.") | |
end | |
return value, err | |
end | |
# Convenience method: standard Gaussian integration over the whole n-dimensional space.
# Builds infinite lower/upper bounds and forwards all keyword arguments unchanged.
function ∫D(f, n::Int; kws...)
    lower = fill(-Inf, n)
    upper = fill(Inf, n)
    return ∫D(f, lower, upper; kws...)
end
# Specialization for the standard Gaussian (identity covariance).
#
# Integrates `g` against a standard normal density over the box [a, b]. With identity
# covariance the coordinates are independent, so the box limits are mapped through `ϕ`
# once and every sample is obtained by a single broadcast, with no sequential conditioning.
# Keyword arguments and the returned `(value, err)` tuple have the same meaning as in the
# method taking a covariance matrix. `g` may be any callable; the vector passed to it is
# reused between evaluations.
function ∫D(
        g::F,                               # function to integrate
        a::AbstractVector,                  # lower bounds
        b::AbstractVector;                  # upper bounds
        rtol::Float64 = 1e-5,               # relative tolerance
        atol::Float64 = 0.,                 # absolute tolerance
        maxevals::Int = 10^7 * length(a),   # max number of MC samples
        warntol::Bool = true,
    ) where F

    α = 3.0     # monte carlo confidence level: `err` is α times the standard error
    dim = length(a)
    @assert length(a) == length(b) "a and b must have the same length"
    @assert all(a .<= b) "every lower bound must not exceed its upper bound"
    if any(a .== b)
        return 0.0, 0.0     # degenerate box: zero volume
    end

    n_outer = 100                           # number of randomly shifted batches
    n_inner = max(maxevals ÷ n_outer, 1)    # quasi-random points per batch

    varsum = 0.0

    # preallocate
    y, w, Δ = zeros(dim), zeros(dim), zeros(dim)

    # Richtmyer generators: square roots of the first `dim` primes.
    # The usual bound 5*dim*log(dim+1)/4 contains too few primes when dim ≤ 2,
    # so it is enlarged until enough primes are available.
    bound = max(floor(Int, 5 * dim * log(dim + 1) / 4), 2)
    ps = Primes.primes(bound)
    while length(ps) < dim
        bound *= 2
        ps = Primes.primes(bound)
    end
    q = sqrt.(ps[1:dim])

    # Box limits in probability space and the total probability mass of the box.
    d = ϕ.(a)
    e = ϕ.(b)
    f = prod(e .- d)

    err = 0.0
    value = 0.0
    converged = false
    for i in 1:n_outer
        rand!(Δ)                    # fresh random shift for this batch
        intsum = KahanSum(0.0)      # compensated running mean over the batch
        @inbounds for j in 1:n_inner
            # shifted Richtmyer lattice point, folded into [0, 1] (tent periodization)
            @. w = abs(2 * ((j * q + Δ) % 1) - 1)
            for k in 1:dim  # without this we have numerical issues
                if w[k] == 0
                    w[k] = 1e-10
                elseif w[k] == 1
                    w[k] = 1 - 1e-10
                end
            end
            @. y = ϕinv(d + w * (e - d))
            intsum += (f * g(y) - intsum.val) / j   # running-mean update
            @assert isfinite(intsum.val)
        end

        # Fold the batch mean into the overall mean and update the error estimate.
        δ = (intsum.val - value) / i
        value += δ
        varsum = (i - 2) * varsum / i + δ^2
        err = α * √varsum
        converged = abs(err) < atol || abs(err) < rtol * abs(value)
        converged && break
    end

    if warntol && !converged
        @warn("""
        Required tolerance atol=$atol rtol=$rtol not reached.
        Currently val = $value ± $err.
        Consider increasing maxevals.
        """)
    end
    return value, err
end
# Cross-check ∫D with a constant integrand against MvNormalCDF.mvnormcdf:
# with f ≡ 1 the integral is the Gaussian probability of the box [a, b].
@testset "compare to MvNormalCDF" begin
    # correlated covariance
    Σ = [4  3  2  1
         3  5 -1  1
         2 -1  4  2
         1  1  2  5]
    # identity covariance, also used to exercise the standard Gaussian method
    D = [1 0 0 0
         0 1 0 0
         0 0 1 0
         0 0 0 1]
    a = [-2, -3, 0, 0]
    b = [-1, -2, 4, 1]
    maxevals = 10^7
    fone(x) = 1

    res1 = MvNormalCDF.mvnormcdf(Σ, a, b; m = maxevals)[1]
    res2 = ∫D(fone, Σ, a, b; maxevals)[1]
    @test res1 ≈ res2 atol=1e-6

    res1 = MvNormalCDF.mvnormcdf(D, a, b; m = maxevals)[1]
    res2 = ∫D(fone, D, a, b; maxevals)[1]
    res3 = ∫D(fone, a, b; maxevals)[1]
    @test res1 ≈ res2 atol=1e-6
    @test res1 ≈ res3 atol=1e-6
end
# The identity-covariance method must agree with the general method called with an
# explicit identity matrix, and the dimension-only method must agree with explicit
# infinite bounds.
@testset "Standard Gaussian specialization" begin
    eye = Matrix{Int}(I, 4, 4)
    lower = [-2, -3, 0, 0]
    upper = [-1, -2, 4, 1]
    budget = 10^7
    integrand = x -> log(1 + sum(x .^ 2)) * exp(0.2 * x[1])

    # bounded box: specialised method vs. general method with identity covariance
    boxed = ∫D(integrand, lower, upper; maxevals = budget)[1]
    boxed_general = ∫D(integrand, eye, lower, upper; maxevals = budget)[1]
    @test boxed ≈ boxed_general rtol=1e-5

    # whole space: dimension shortcut vs. explicit infinite bounds
    whole = ∫D(integrand, 4; maxevals = budget)[1]
    whole_explicit = ∫D(integrand, fill(-Inf, 4), fill(Inf, 4); maxevals = budget)[1]
    @test whole ≈ whole_explicit rtol=1e-5
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment