Skip to content

Instantly share code, notes, and snippets.

@jsta
Forked from ashiklom/missing.R
Created May 11, 2018 15:32
Show Gist options
  • Save jsta/d1a6abc9632b67dbf78ad8c2657d2db2 to your computer and use it in GitHub Desktop.
Save jsta/d1a6abc9632b67dbf78ad8c2657d2db2 to your computer and use it in GitHub Desktop.
Multivariate normal fit with missing values in Stan
library(rstan)
library(MASS)
# Simulated data ----
# Draw N observations from a trivariate normal with known mean vector `mu`
# and (symmetric) covariance matrix `Sig`.
N <- 100
mu <- c(10, 5, -5)
Sig <- rbind(
  c(1.0, 0.7, 0.8),
  c(0.7, 2.0, 0.2),
  c(0.8, 0.2, 1.5)
)
dat <- mvrnorm(n = N, mu = mu, Sigma = Sig)

# Missingness ----
# Blank out `nmiss` cells chosen uniformly at random (without replacement)
# from all cells of the matrix, addressed by their linear index.
nmiss <- 115
miss <- sample.int(length(dat), size = nmiss)
dat[miss] <- NA

# Bookkeeping ----
# (row, col) index MATRICES of the missing and of the observed cells.
ind_miss <- which(is.na(dat), arr.ind = TRUE)
ind_pres <- which(!is.na(dat), arr.ind = TRUE)
# Observed (non-missing) values as a plain VECTOR; element k is the value of
# the cell addressed by row k of `ind_pres`.
dat_complete <- dat[ind_pres]
# Stan model ----
# Multivariate normal likelihood in which every missing cell of the data
# matrix is declared as a parameter and sampled together with `mu` and
# `Sigma`.
#
# Data expected by the program:
#   Nrow, Ncol    dimensions of the data matrix
#   Ncomp, Nmiss  number of observed / missing cells
#   dat_complete  observed values, ordered like the rows of ind_pres
#   ind_pres      (row, col) index pairs of the observed cells
#   ind_miss      (row, col) index pairs of the missing cells
#
# Assignments inside the program use `=`; the older `<-` operator has been
# removed from the Stan language and no longer compiles on current releases.
# No priors are stated for `mu` and `Sigma`, so only their declared
# constraints apply.
# NOTE(review): array declarations keep the old `real x[N]` form so the
# program still compiles with older rstan releases; Stan >= 2.33 requires
# `array[N] real x` instead - switch once the minimum rstan version allows.
mod.nomiss <- '
data {
  int<lower=0> Nrow;
  int<lower=0> Ncol;
  int<lower=0> Ncomp;        // Number of non-missing values
  int<lower=0> Nmiss;        // Number of missing values
  real dat_complete[Ncomp];  // Vector of non-missing values
  int ind_pres[Ncomp, 2];    // Matrix (row, col) of non-missing value indices
  int ind_miss[Nmiss, 2];    // Matrix (row, col) of missing value indices
}
parameters {
  // Multivariate normal distribution parameters
  vector[Ncol] mu;
  cov_matrix[Ncol] Sigma;
  // Vector containing "stochastic" nodes (for filling missing values)
  real ymiss[Nmiss];
}
transformed parameters {
  vector[Ncol] y[Nrow];  // The "data" with interpolated missing values
  // Fill y with non-missing values
  for (n in 1:Ncomp) {
    y[ind_pres[n, 1]][ind_pres[n, 2]] = dat_complete[n];
  }
  // Fill the rest of y with missing value "parameters"
  for (n in 1:Nmiss) {
    y[ind_miss[n, 1]][ind_miss[n, 2]] = ymiss[n];
  }
}
model {
  // Vectorised over the rows of y: same log density as looping over rows,
  // but Sigma only has to be factorised once per evaluation.
  y ~ multi_normal(mu, Sigma);
}
'
# Named list handed to Stan; element names must match the declarations in the
# `data` block of the model string.
mod.data <- list(
  Nrow = dim(dat)[1L],
  Ncol = dim(dat)[2L],
  Ncomp = length(dat_complete),
  Nmiss = nrow(ind_miss),
  dat_complete = dat_complete,
  ind_pres = ind_pres,
  ind_miss = ind_miss
)

# Compile the model and sample: 3 chains of 1000 iterations each, with verbose
# progress output. `pars` names the parameters requested in the fitted object
# (`mu` and `Sigma`).
fit <- stan(
  model_code = mod.nomiss,
  data = mod.data,
  pars = c("mu", "Sigma"),
  chains = 3,
  iter = 1000,
  verbose = TRUE
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment