Skip to content

Instantly share code, notes, and snippets.

@abikoushi
Created March 24, 2025 06:23
Show Gist options
  • Save abikoushi/9aa5622a9cd216bf7f43dacb554ae5fd to your computer and use it in GitHub Desktop.
Save abikoushi/9aa5622a9cd216bf7f43dacb554ae5fd to your computer and use it in GitHub Desktop.
An example of the stochastic variational Bayes method: a mixture of Poisson distributions
# Reference:
# Hoffman et al. "Stochastic Variational Inference"
# browseURL("https://arxiv.org/abs/1206.7051")
####
# learning rate
# Step-size schedule: (t + lr_param[1])^(-lr_param[2]).
#   t         iteration counter (a single number or a vector of them).
#   lr_param  numeric vector; element 1 is the offset added to t and
#             element 2 is the decay exponent.
# Returns one step size per element of t.
lr <- function(t, lr_param) {
  offset <- lr_param[1]
  decay <- lr_param[2]
  (t + offset)^(-decay)
}
#curve(lr(x, lr_param = c(10, 0.9)), 0, 12)
#toy data
# Draw n counts from a finite mixture of Poisson distributions.
#   n       number of draws.
#   lambda  Poisson rate of each component.
#   comp    sampling weight of each component (passed to sample.int as prob).
# A component label is sampled for every draw first, then the count is drawn
# from the Poisson distribution with that component's rate.
rmixtpois <- function(n, lambda, comp) {
  n_comp <- length(comp)
  labels <- sample.int(n_comp, size = n, replace = TRUE, prob = comp)
  rates <- lambda[labels]
  rpois(n, rates)
}
# Draw a single vector from a Dirichlet distribution by normalising
# independent Gamma(shape = a[i]) draws so that they sum to one.
#   a  vector of concentration parameters (one per category).
# Returns a vector of the same length as a.
rdirichlet <- function(a) {
  draws <- rgamma(length(a), a)
  draws / sum(draws)
}
# Softmax applied independently to every row of a numeric matrix.
#   x  numeric matrix of unnormalised log-weights.
# Returns a matrix of the same shape whose rows sum to one.
row_softmax <- function(x) {
  row_max <- apply(x, 1, max)
  # Subtract each row's maximum before exponentiating to avoid overflow.
  # row_max has one entry per row, so column-major recycling lines it up
  # with the rows of x.
  shifted <- exp(x - row_max)
  shifted / rowSums(shifted)
}
# Numerically stable log(sum(exp(x))).
#   x  numeric vector of log-scale values.
# Returns a single number.
# Edge cases:
#   * empty x          -> -Inf (the log of an empty sum, i.e. log(0));
#   * max(x) is -Inf   -> -Inf (every term is exp(-Inf) = 0);
#   * max(x) is +Inf   -> Inf;
#   * max(x) is NA/NaN -> that missing value is returned unchanged.
# Without the early return, x - max(x) would evaluate Inf - Inf = NaN in the
# infinite cases.
logsumexp <- function(x) {
  if (length(x) == 0L) {
    return(-Inf)
  }
  mx <- max(x)
  if (!is.finite(mx)) {
    return(mx)
  }
  # Shift by the maximum so the largest exponent is exp(0) = 1.
  mx + log(sum(exp(x - mx)))
}
# Poisson log-likelihood kernel: y * loglambda - lambda.
# No term that depends on y alone is included in the returned value.
#   y          observed counts.
#   lambda     value subtracted as the rate term.
#   loglambda  value multiplying y; supplied separately from lambda so the
#              caller can pass something other than log(lambda).
lp_pois <- function(y, lambda, loglambda) {
  count_term <- y * loglambda
  count_term - lambda
}
# Stochastic variational Bayes for a K-component mixture of Poisson
# distributions.
#
# Variational family maintained by the updates below:
#   component rates    ~ Gamma(shape = alpha_t, rate = beta_t)
#   mixing proportions ~ Dirichlet(gamma_t)
#
# Arguments:
#   y            vector of observed counts.
#   K            number of mixture components.
#   n_batches    number of observations in each minibatch. Despite its name
#                it is used as the row count of the shuffled index matrix, so
#                one epoch processes length(y) / n_batches minibatches.
#                length(y) must be a multiple of it.
#   maxit        number of epochs (passes over the reshuffled data).
#   lr_param     parameters forwarded to lr(); the step size of epoch t is
#                lr(t, lr_param).
#   prior_alpha, prior_beta
#                shape and rate of the Gamma prior on every component rate.
#   prior_gamma  concentration of the symmetric Dirichlet prior.
#
# Returns a list with
#   alpha, beta, gamma  final variational parameters (length K each);
#   logprob             length-maxit vector; element t is the sum, over the
#                       minibatches of epoch t, of the row-wise logsumexp of
#                       the per-component scores, each minibatch scored with
#                       the parameters in force before its own update.
#
# Side effects: draws from the RNG and shows a text progress bar, which is
# closed when the function exits (normally or through an error).
svb_mixpois <- function(y, K, n_batches, maxit, lr_param,
                        prior_alpha = 1, prior_beta = 1, prior_gamma = 1) {
  N <- length(y)
  # matrix() below would silently recycle indices (double-counting some
  # observations) whenever the minibatch size does not divide N, so refuse
  # such input up front.
  if (length(n_batches) != 1L || is.na(n_batches) || n_batches < 1 ||
        n_batches %% 1 != 0 || N %% n_batches != 0) {
    stop("`n_batches` (observations per minibatch) must be a positive ",
         "whole number that divides length(y).", call. = FALSE)
  }

  # randomly initialize global parameters
  lambda <- rgamma(K, prior_alpha, prior_beta)
  loglambda <- log(lambda)
  logphi <- log(rdirichlet(rep(prior_gamma, K)))

  # variational parameters
  alpha_t <- rep(prior_alpha, K)
  beta_t <- rep(prior_beta, K)
  gamma_t <- rep(prior_gamma, K)

  logprob <- numeric(maxit)
  # Scale factor that turns minibatch sufficient statistics into
  # full-data-sized ones: N / (observations per minibatch).
  SN <- N / n_batches

  pb <- txtProgressBar(max = maxit, style = 3)
  on.exit(close(pb), add = TRUE)

  for (t in seq_len(maxit)) {
    # Each column of idx holds the observation indices of one minibatch.
    idx <- matrix(sample.int(N), nrow = n_batches)
    lp <- 0

    # The step size depends only on the epoch, not on the minibatch.
    rho <- lr(t, lr_param = lr_param)
    rho2 <- 1 - rho
    # NOTE(review): the weight on the new estimate is rho / n_batches while
    # the weight on the old value is 1 - rho, so the two do not sum to one.
    # Hoffman et al. use rho and 1 - rho -- confirm this scaling is intended.
    rho <- rho / n_batches

    for (s in seq_len(ncol(idx))) {
      yb <- y[idx[, s]]

      # Local step: unnormalised log responsibilities of this minibatch.
      Z <- matrix(0, nrow = n_batches, ncol = K)
      for (j in seq_len(K)) {
        Z[, j] <- lp_pois(yb, lambda[j], loglambda[j]) + logphi[j]
      }
      lp <- lp + sum(apply(Z, 1, logsumexp))
      Z <- row_softmax(Z)

      # Intermediate global parameters computed from this minibatch alone,
      # scaled up by SN.
      sumZ <- colSums(Z)
      a_new <- SN * colSums(Z * yb) + prior_alpha
      b_new <- SN * sumZ + prior_beta
      gamma_new <- SN * sumZ + prior_gamma

      # Global step: blend the previous values with the new estimates.
      alpha_t <- alpha_t * rho2 + a_new * rho
      beta_t <- beta_t * rho2 + b_new * rho
      gamma_t <- gamma_t * rho2 + gamma_new * rho

      # Expectations under the updated factors, used by the next local step.
      lambda <- alpha_t / beta_t
      loglambda <- digamma(alpha_t) - log(beta_t)
      logphi <- digamma(gamma_t) - digamma(sum(gamma_t))
    }

    logprob[t] <- lp
    setTxtProgressBar(pb, t)
  }

  list(alpha = alpha_t,
       beta = beta_t,
       gamma = gamma_t,
       logprob = logprob)
}
# Mean of the Gamma factors stored in a fit: alpha / beta, element-wise.
#   out  list with elements alpha and beta.
mean_gamma <- function(out) {
  shape <- out$alpha
  rate <- out$beta
  shape / rate
}
# Expected logarithm under the Gamma factors stored in a fit:
# digamma(alpha) - log(beta), element-wise.
#   out  list with elements alpha and beta.
mean_log_gamma <- function(out) {
  shape <- out$alpha
  rate <- out$beta
  digamma(shape) - log(rate)
}
# Mean of the Dirichlet factor stored in a fit: gamma normalised to sum to one.
#   out  list with element gamma.
mean_dirichlet <- function(out) {
  conc <- out$gamma
  conc / sum(conc)
}
# Expected log-proportions under the Dirichlet factor stored in a fit:
# digamma(gamma) - digamma(sum(gamma)), element-wise.
#   out  list with element gamma.
mean_log_dirichlet <- function(out) {
  conc <- out$gamma
  total <- sum(conc)
  digamma(conc) - digamma(total)
}
####
# Demo: simulate counts from a two-component Poisson mixture and fit it.
set.seed(198)
size <- 100
# Rates 1 and 10 with sampling weights 0.4 and 0.6.
y <- rmixtpois(n = size, lambda = c(1, 10), comp = c(0.4, 0.6))
K <- 2 # number of component
# n_batches = 20 gives minibatches of 20 observations each.
out <- svb_mixpois(
  y = y, K = K, n_batches = 20, maxit = 50,
  lr_param = c(10, 0.9)
)
# Per-epoch log-probability trace, leaving out the first epoch.
plot(out$logprob[-1], type = "l")
# Posterior mean of the mixing proportions.
print(mean_dirichlet(out))
# Component membership probabilities for observations under a fit.
#   out  list with elements alpha, beta and gamma, such as the value returned
#        by svb_mixpois().
#   y    vector of counts to classify.
# Returns a length(y) x K matrix, K = length(out$gamma); row i is the
# softmax over components of lp_pois(y[i], ...) plus the expected
# log-proportion, so every row sums to one.
postprob_cluster <- function(out, y) {
  K <- length(out$gamma)
  # Posterior expectations of the rates, log-rates and log-proportions.
  lambda <- mean_gamma(out)
  loglambda <- mean_log_gamma(out)
  logphi <- mean_log_dirichlet(out)
  Z <- matrix(0, nrow = length(y), ncol = K)
  for (j in seq_len(K)) {
    Z[, j] <- lp_pois(y, lambda[j], loglambda[j]) + logphi[j]
  }
  row_softmax(Z)
}
# Posterior mean rate of every component; drawn below as dashed lines.
lambdahat <- mean_gamma(out)
# Membership probabilities; column 1 is used to colour the histogram.
Zhat <- postprob_cluster(out, y)
df <- data.frame(y = y, cluster = Zhat[, 1])
library(ggplot2)
ggplot(df, aes(x = y, colour = cluster, group = cluster)) +
  geom_histogram(bins = 15, fill = "white") +
  geom_vline(xintercept = lambdahat, linetype = 2) +
  scale_color_binned(type = "viridis")
#ggsave("hist.png")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment