Skip to content

Instantly share code, notes, and snippets.

@al6x
Created January 11, 2025 09:20
Show Gist options
  • Save al6x/7808b1d0cd936689f361f5dcf5e3a751 to your computer and use it in GitHub Desktop.
Save al6x/7808b1d0cd936689f361f5dcf5e3a751 to your computer and use it in GitHub Desktop.
import numpy as np
import matplotlib.pyplot as plt
from sklearn.mixture import GaussianMixture
from scipy.stats import norm
def fit_normal_mixture(*, n_components, values, random_state, n_init):
    """Fit a normal (Gaussian) mixture model to 1-D data with sklearn's EM.

    Args:
        n_components: number of mixture components to fit.
        values: 1-D sequence of observations.
        random_state: seed forwarded to sklearn for reproducibility.
        n_init: number of EM restarts; the best-scoring fit is kept.

    Returns:
        Tuple ``(weights, means, sigmas)`` as plain Python lists.
    """
    # sklearn expects a 2-D (n_samples, n_features) array.
    values = np.array(values).reshape(-1, 1)
    nmm = GaussianMixture(
        n_components=n_components,
        covariance_type='diag',  # one variance per component (data is 1-D)
        random_state=random_state,
        n_init=n_init,
    )
    nmm.fit(values)
    means = nmm.means_.flatten().tolist()
    # covariances_ holds variances for 'diag'; convert to standard deviations.
    sigmas = np.sqrt(nmm.covariances_.flatten()).tolist()
    weights = nmm.weights_.flatten().tolist()
    return weights, means, sigmas
def sample_normal_mixture(*, weights, means, sigmas, n):
    """Draw ``n`` samples from the normal mixture given by the parameters.

    Args:
        weights: component probabilities; must sum to 1.
        means: per-component means (same length as ``weights``).
        sigmas: per-component standard deviations.
        n: number of samples to draw.

    Returns:
        1-D numpy array of ``n`` samples.

    Raises:
        ValueError: if the weights do not sum to 1.
    """
    if not np.isclose(sum(weights), 1):
        raise ValueError("Weights must sum to 1")
    # Choose a component index for every draw, then sample its normal.
    components = np.random.choice(len(weights), size=n, p=weights)
    return np.random.normal(loc=np.array(means)[components], scale=np.array(sigmas)[components])
# Plotting function
def plot_mixture(weights, means, sigmas, label, color, ax):
    """Plot the mixture's probability density on [-10, 10] onto ``ax``.

    Args:
        weights, means, sigmas: mixture parameters, one entry per component.
        label: legend label for the curve.
        color: matplotlib color for the curve.
        ax: matplotlib axes to draw on.
    """
    x = np.linspace(-10, 10, 1000)
    y = np.zeros_like(x)
    # Mixture density is the weighted sum of the component densities.
    for weight, mean, sigma in zip(weights, means, sigmas):
        y += weight * norm.pdf(x, loc=mean, scale=sigma)
    ax.plot(x, y, label=label, color=color)
# Test
if __name__ == '__main__':
    # Ground-truth mixture: equal weights, identical means, different spreads.
    # NOTE(review): with both means equal the mixture is hard to identify
    # from samples, which likely explains the "wrong" fitted parameters.
    original_weights = [0.5, 0.5]
    original_means = [0, 0]
    original_sigmas = [1, 2]
    n_samples = 20000
    nmm_sample = sample_normal_mixture(weights=original_weights, means=original_means, sigmas=original_sigmas, n=n_samples)

    # Recover the parameters from the sample via EM.
    fitted_weights, fitted_means, fitted_sigmas = fit_normal_mixture(
        n_components=2, values=nmm_sample, random_state=0, n_init=10
    )

    print("Original parameters:")
    print(f"Weights: {original_weights}, Means: {original_means}, Sigmas: {original_sigmas}")
    print("Fitted parameters:")
    print(f"Weights: {fitted_weights}, Means: {fitted_means}, Sigmas: {fitted_sigmas}")

    # Overlay the true density, the fitted density, and the sample histogram.
    fig, ax = plt.subplots(figsize=(10, 6))
    plot_mixture(original_weights, original_means, original_sigmas, label="Original Model", color="blue", ax=ax)
    plot_mixture(fitted_weights, fitted_means, fitted_sigmas, label="Fitted Model", color="red", ax=ax)
    ax.hist(nmm_sample, bins=100, density=True, alpha=0.5, color='gray', label='Sample Histogram')
    ax.set_title("Original and Fitted Normal Mixture Models")
    ax.set_xlabel("Value")
    ax.set_ylabel("Density")
    ax.legend()
    plt.show()
@al6x
Copy link
Author

al6x commented Jan 11, 2025

Bayesian estimation was also tried; it produced the same wrong results.

import numpy as np
import matplotlib.pyplot as plt
import pymc as pm
from scipy.stats import norm

def fit_normal_mixture(*, n_components, values):
    """Estimate normal-mixture parameters from 1-D data via MCMC (PyMC).

    Returns:
        Tuple ``(weights, means, sigmas)`` of posterior-mean estimates,
        each a plain Python list of length ``n_components``.
    """
    data = np.array(values)  # PyMC's NormalMixture takes 1-D observations

    with pm.Model():
        # Priors: flat Dirichlet over the weights, wide normals/half-normals
        # for the component locations and scales.
        w = pm.Dirichlet('weights', a=np.ones(n_components))
        mu = pm.Normal('means', mu=0, sigma=10, shape=n_components)
        sd = pm.HalfNormal('sigmas', sigma=10, shape=n_components)

        # Likelihood: the observed data under the mixture.
        pm.NormalMixture('mixture', w=w, mu=mu, sigma=sd, observed=data)

        # Posterior sampling via MCMC.
        trace = pm.sample(1000, return_inferencedata=True, random_seed=0, tune=1000)

    # Posterior means across chains and draws, converted to lists.
    posterior = trace.posterior
    weights_fitted = posterior['weights'].mean(dim=('chain', 'draw')).values
    means_fitted = posterior['means'].mean(dim=('chain', 'draw')).values
    sigmas_fitted = posterior['sigmas'].mean(dim=('chain', 'draw')).values
    return weights_fitted.tolist(), means_fitted.tolist(), sigmas_fitted.tolist()

def sample_normal_mixture(*, weights, means, sigmas, n):
    """Draw ``n`` values from the normal mixture defined by the parameters.

    Raises:
        ValueError: if ``weights`` does not sum to 1.
    """
    if not np.isclose(sum(weights), 1):
        raise ValueError("Weights must sum to 1")
    # One component index per draw, then one normal sample per index.
    idx = np.random.choice(len(weights), size=n, p=weights)
    mu = np.array(means)[idx]
    sd = np.array(sigmas)[idx]
    return np.random.normal(loc=mu, scale=sd)

# Plotting function
def plot_mixture(weights, means, sigmas, label, color, ax):
    """Draw the mixture's density curve over [-10, 10] on the given axes."""
    grid = np.linspace(-10, 10, 1000)
    density = np.zeros_like(grid)

    # Accumulate the weighted component densities.
    for w, m, s in zip(weights, means, sigmas):
        density = density + w * norm.pdf(grid, loc=m, scale=s)

    ax.plot(grid, density, label=label, color=color)

# Test
if __name__ == '__main__':
    # Ground-truth mixture: equal-weight normals sharing a mean but with
    # different spreads.
    true_weights = [0.5, 0.5]
    true_means = [0, 0]
    true_sigmas = [1, 2]
    sample_size = 20000

    draws = sample_normal_mixture(
        weights=true_weights, means=true_means, sigmas=true_sigmas, n=sample_size
    )

    # Recover the parameters with the PyMC-based fitter.
    est_weights, est_means, est_sigmas = fit_normal_mixture(n_components=2, values=draws)

    print("Original parameters:")
    print(f"Weights: {true_weights}, Means: {true_means}, Sigmas: {true_sigmas}")

    print("Fitted parameters:")
    print(f"Weights: {est_weights}, Means: {est_means}, Sigmas: {est_sigmas}")

    # Overlay the true density, the fitted density, and the sample histogram.
    fig, ax = plt.subplots(figsize=(10, 6))
    plot_mixture(true_weights, true_means, true_sigmas, label="Original Model", color="blue", ax=ax)
    plot_mixture(est_weights, est_means, est_sigmas, label="Fitted Model", color="red", ax=ax)
    ax.hist(draws, bins=100, density=True, alpha=0.5, color='gray', label='Sample Histogram')

    ax.set_title("Original and Fitted Normal Mixture Models")
    ax.set_xlabel("Value")
    ax.set_ylabel("Density")
    ax.legend()

    plt.show()

@al6x
Copy link
Author

al6x commented Jan 11, 2025

And Julia

using Distributions
using GaussianMixtures

# Two-component normal mixture: same mean (0.0), different spreads (1.0, 2.0),
# equal weights.
components = [Normal(0.0, 1.0), Normal(0.0, 2.0)]
nmm = MixtureModel(components, [0.5, 0.5])

# Draw observations from the true mixture, then fit a 2-component GMM
# with k-means initialization and print the result.
sample = rand(nmm, 1000)
m = GMM(2, sample; method=:kmeans)
print(m)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment