Entropy estimator. Source: https://x.com/adad8m/status/1745109776458187138?s=20
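For context: ent below implements the Kozachenko-Leonenko nearest-neighbour entropy estimator. Writing \rho_i for the distance from sample i to its nearest neighbour, one common form of the estimator is

\hat{H}(X) \approx \frac{D}{N} \sum_{i=1}^{N} \log \rho_i + \log(N - 1) + \log V_D + \gamma,

where V_D is the volume of the D-dimensional unit ball and \gamma is the Euler-Mascheroni constant. The code returns the mean of \log\big((N-1)\,\rho_i^D\big), which equals exactly the first two terms, hence "up to an irrelevant additive constant".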
import jax
import jax.numpy as jnp


def ent(X):
    """Kozachenko-Leonenko estimate of the entropy of a dataset X."""
    N, D = X.shape  # number of samples and number of dimensions
    # Pairwise squared Euclidean distances between all samples.
    dist_sq = jnp.sum((X[:, jnp.newaxis, :] - X[jnp.newaxis, :, :]) ** 2, axis=-1)
    # Mask the diagonal so each point's zero distance to itself
    # does not affect the min calculation.
    dist_sq = jnp.where(jnp.eye(N, dtype=bool), jnp.inf, dist_sq)
    # Distance from each point to its nearest neighbour.
    min_dist = jnp.sqrt(jnp.min(dist_sq, axis=1))
    # Kozachenko-Leonenko estimator of the entropy (up to an irrelevant additive constant).
    return jnp.mean(jnp.log((N - 1) * min_dist ** D))


@jax.jit
def functional(X):
    """Negative entropy plus mean negative log-density of X, i.e. (up to an
    additive constant) an estimate of KL(sample distribution || density).
    Assumes a log_density function is defined elsewhere."""
    return -ent(X) + jnp.mean(-log_density(X))
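Note that functional references a log_density that the gist never defines; it is assumed to be supplied by the caller. Below is a minimal usage sketch in which log_density is a hypothetical stand-in (the log-density of a standard multivariate normal), not part of the original code:

def log_density(X):
    """Hypothetical example target: log-density of a standard normal."""
    D = X.shape[-1]
    return -0.5 * jnp.sum(X ** 2, axis=-1) - 0.5 * D * jnp.log(2 * jnp.pi)

key = jax.random.PRNGKey(0)
X = jax.random.normal(key, (512, 2))  # 512 samples in 2 dimensions
print(ent(X))         # entropy estimate (up to the constant noted above)
print(functional(X))  # -entropy + mean negative log-density

Minimizing functional trades off spreading the samples out (high entropy) against keeping them in high-density regions of log_density; it is smallest when the sample distribution matches the target density.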