Created October 19, 2021 16:10
Save cscherrer/f2788b7ab62a232eb42a3068fe2fe56c to your computer and use it in GitHub Desktop.
Modified Soss tests from AbstractGPs.jl
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using Soss, MeasureTheory, AbstractGPs, Test, SampleChainsDynamicHMC
using Random
@testset "Soss" begin
    # Seed the global RNG so the synthetic data below are reproducible (and the sampler's
    # initialization too, presumably, if it draws from the global RNG — TODO confirm).
    Random.seed!(20211019)

    # Number of posterior draws requested from a single DynamicHMC chain in every test.
    nsamples = 5

    # Condition `model` on the observations `y`, sample `nsamples` draws from one
    # DynamicHMC chain, and return how many draws the sampler actually produced.
    ndraws(model, y) = length(Soss.sample(model | (y=y,), dynamichmc(), nsamples, 1))

    @testset "GP regression" begin
        y = randn(3)
        X = randn(3, 1)              # 3 one-dimensional inputs stored as matrix rows
        x = [rand(1) for _ in 1:3]   # 3 one-dimensional inputs as a vector of vectors

        # GP regression with log-scale priors on the kernel scale `α`, the input length
        # scale `ρ` (inputs are divided by `ρ` via `ScaleTransform(1 / ρ)`), and the noise
        # variance `σ²` passed to the finite GP `f(X, ...)` that generates `y`.
        gp_regression = Soss.@model X begin
            # Priors.
            logα ~ Normal(0.0, 0.1)
            logρ ~ Normal(0.0, 1.0)
            logσ² ~ Normal(0.0, 1.0)
            α = exp(logα)
            ρ = exp(logρ)
            σ² = exp(logσ²)

            # Realized covariance function.
            kernel = α * (SqExponentialKernel() ∘ ScaleTransform(1 / ρ))
            f = GP(kernel)

            # Sampling distribution; the 1e-9 added to σ² is a small jitter
            # (presumably for numerical stability of the covariance).
            y ~ f(X, σ² + 1e-9)
        end

        # Inputs given as the rows of a matrix.
        @test ndraws(gp_regression(; X=RowVecs(X)), y) == nsamples

        # Inputs given as a vector of vectors.
        @test ndraws(gp_regression(; X=x), y) == nsamples
    end

    @testset "latent GP regression" begin
        X = randn(3, 1)              # 3 one-dimensional inputs stored as matrix rows
        x = [rand(1) for _ in 1:3]   # 3 one-dimensional inputs as a vector of vectors
        y = rand.(Poisson.(exp.(randn(3))))

        # Latent GP with Poisson observations: `u` is the latent GP draw at the inputs,
        # `exp.(u)` gives the per-observation Poisson rates.
        latent_gp_regression = Soss.@model X begin
            f = GP(Matern32Kernel())
            u ~ f(X)
            λ = exp.(u)
            y ~ For(eachindex(λ)) do i
                Poisson(λ[i])
            end
        end

        # Inputs given as the rows of a matrix.
        @test ndraws(latent_gp_regression(; X=RowVecs(X)), y) == nsamples

        # Inputs given as a vector of vectors.
        @test ndraws(latent_gp_regression(; X=x), y) == nsamples
    end
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.