Created October 5, 2014 18:11
-
-
Save soonraah/be02066bc45634df036d to your computer and use it in GitHub Desktop.
Compares the EM algorithm with Stan-based estimation (labelled "MCMC") for training a Gaussian mixture model (GMM).
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from sklearn import cross_validation, mixture | |
import pickle | |
import os | |
import pystan | |
import time | |
import matplotlib.pyplot as plt | |
def dump_stan_model(stan_model, compiled_file_name):
    """
    Serialize a compiled Stan model to disk with pickle so it can be reused.

    :param stan_model: compiled stan model instance
    :param compiled_file_name: path of the pickle file to write (overwritten if it exists)
    """
    with open(compiled_file_name, 'wb') as out_file:
        pickle.dump(stan_model, out_file)
def load_stan_model(compiled_file_name):
    """
    Restore a Stan model from a pickle file (as written by dump_stan_model).

    Only load files you produced yourself: unpickling can execute arbitrary code.

    :param compiled_file_name: path of the pickle file to read
    :return: the unpickled stan model instance
    """
    with open(compiled_file_name, 'rb') as in_file:
        return pickle.load(in_file)
def convert_model(stan_gmm_model):
    """
    Build a scikit-learn GMM from parameters estimated by Stan.

    :param stan_gmm_model: Stan's optimized result; must provide 'weights', 'mu'
        and 'sigma' entries via .get()
    :return: mixture.GMM instance (diagonal covariance) carrying those parameters
    """
    weights = stan_gmm_model.get('weights')
    gmm = mixture.GMM(n_components=weights.size, covariance_type='diag')
    gmm.weights_ = weights
    gmm.means_ = stan_gmm_model.get('mu')
    # 'sigma' is squared elementwise to produce the per-dimension variances stored in covars_.
    gmm.covars_ = np.square(stan_gmm_model.get('sigma'))
    return gmm
def draw_result_graph(likelihoods_em, likelihoods_mcmc): | |
""" | |
Draw a graph that shows likelihood of EM vs. MCMC | |
:param likelihoods_em: | |
:param likelihoods_mcmc: | |
:return: | |
""" | |
fix, ax = plt.subplots() | |
ax.scatter(likelihoods_em, likelihoods_mcmc, marker='o') | |
ax.plot([-7.5, -6.0], [-7.5, -6.0], color='gray', alpha=0.5) | |
plt.xlabel("Average Log Likelihood (EM)") | |
plt.ylabel("Average Log Likelihood (MCMC)") | |
plt.show() | |
def _load_data_set(data_file_name, num_feature_columns=11):
    """Load a ';'-separated CSV with one header row and keep the first feature columns."""
    raw_data_set = np.loadtxt(data_file_name, delimiter=";", skiprows=1)
    # Keep only the first 11 columns, i.e. remove the "quality" column.
    return raw_data_set[:, :num_feature_columns]


def _load_or_compile_stan_model(stan_code_file_name, stan_compiled_file_name):
    """Return the pickled Stan model if cached, otherwise compile it and cache it."""
    if os.path.isfile(stan_compiled_file_name):
        return load_stan_model(stan_compiled_file_name)
    stan_model = pystan.StanModel(file=stan_code_file_name)
    dump_stan_model(stan_model, stan_compiled_file_name)
    return stan_model


def main(data_file_name='winequality-white.csv',
         stan_code_file_name='multi_dimensional_gmm_diagonal.stan',
         stan_compiled_file_name='multi_dimensional_gmm_diagonal.pkl',
         num_validations=500,
         num_mixture_components=4):
    """
    Compare EM (scikit-learn) and Stan fitting of a diagonal-covariance GMM.

    The data is split randomly 50/50 into training/evaluation halves
    ``num_validations`` times. On every split both models are fit on the training
    half and scored on the evaluation half. The per-split average log likelihoods
    and the mean fitting times are printed, and the likelihoods are plotted.

    :param data_file_name: ';'-separated CSV with a header row
    :param stan_code_file_name: Stan program, compiled when no cached pickle exists
    :param stan_compiled_file_name: pickle cache for the compiled Stan model
    :param num_validations: number of random train/evaluation splits
    :param num_mixture_components: number of Gaussian components for both models
    """
    data_set = _load_data_set(data_file_name)
    stan_model = _load_or_compile_stan_model(stan_code_file_name, stan_compiled_file_name)

    time_sec_em = 0.0
    time_sec_mcmc = 0.0
    likelihoods_em = []
    likelihoods_mcmc = []
    # Use the row count. `data_set.data` is the raw buffer of a non-contiguous slice,
    # whose length is not reliably the number of samples.
    ss = cross_validation.ShuffleSplit(n=data_set.shape[0], n_iter=num_validations, test_size=0.5)
    for cnt, (training_indexes, evaluation_indexes) in enumerate(ss, start=1):
        print("--------------------------------")
        print("ITERATION {0}".format(cnt))
        print("--------------------------------")
        tr_data_set = data_set[training_indexes]
        ev_data_set = data_set[evaluation_indexes]

        # EM by scikit-learn; only fit() is timed.
        gmm_em = mixture.GMM(n_components=num_mixture_components, covariance_type='diag')
        t = time.time()
        gmm_em.fit(tr_data_set)
        time_sec_em += time.time() - t
        likelihoods_em.append(gmm_em.score(ev_data_set).mean())

        # Stan fit; only optimizing() is timed.
        # NOTE(review): StanModel.optimizing returns a point estimate from
        # optimization, not MCMC samples. The "mcmc" labels are kept for
        # output compatibility.
        data_dic = dict(D=tr_data_set.shape[1], N=tr_data_set.shape[0],
                        M=num_mixture_components, X=tr_data_set)
        t = time.time()
        optimizing_result = stan_model.optimizing(data=data_dic, iter=20000)
        time_sec_mcmc += time.time() - t
        gmm_mcmc = convert_model(optimizing_result)
        likelihoods_mcmc.append(gmm_mcmc.score(ev_data_set).mean())

    print("--------------------------------")
    print("COMPLETED")
    print("--------------------------------")
    print("likelihoods_em:", likelihoods_em)
    print("likelihoods_mcmc:", likelihoods_mcmc)
    # Average over the splits actually run; guard against a zero-split run.
    num_completed = max(len(likelihoods_em), 1)
    print("avg time em: {0:.3f} sec".format(time_sec_em / num_completed))
    print("avg time mcmc: {0:.3f} sec".format(time_sec_mcmc / num_completed))
    draw_result_graph(likelihoods_em, likelihoods_mcmc)
if __name__ == "__main__":
    # Run the experiment only when this file is executed as a script, not on import.
    main()
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment