Skip to content

Instantly share code, notes, and snippets.

@dapurv5
Created November 16, 2024 00:19
Show Gist options
  • Save dapurv5/a77d20aa61a9b51797951f65c76d1314 to your computer and use it in GitHub Desktop.
Approx KL Divergence (Courtesy John Schulman)
import torch
import torch.nn.functional as F
def approx_kl_divergence(
    p: torch.Tensor, q: torch.Tensor, num_samples: int = 10_000
) -> float:
    """
    Estimate KL(p || q) between two categorical distributions with the k3 estimator.

    Uses John Schulman's k3 estimator (http://joschu.net/blog/kl-approx.html):
    with samples x ~ p and ratio r = q(x) / p(x), the statistic
    (r - 1) - log r is an unbiased, low-variance, always non-negative
    estimator of E_p[-log r] = KL(p || q).

    Args:
        p: First distribution as a 1D tensor of probabilities that sum to 1.
        q: Second distribution as a 1D tensor of probabilities that sum to 1.
        num_samples: Number of Monte Carlo samples to draw.
            This can be far smaller than the support size of p and q.

    Returns:
        float: Estimated KL(p || q).
    """
    # Create categorical distributions
    p_dist = torch.distributions.Categorical(probs=p)
    q_dist = torch.distributions.Categorical(probs=q)
    # BUG FIX: the original sampled x ~ q with r = p(x)/q(x), which estimates
    # KL(q || p) -- but the accompanying example compares the result against
    # the exact KL(p || q). To estimate KL(p || q), sample x ~ p and use
    # r = q(x) / p(x) instead.
    x = p_dist.sample(sample_shape=(num_samples,))
    # log r = log q(x) - log p(x)
    logr = q_dist.log_prob(x) - p_dist.log_prob(x)
    # k3 = (r - 1) - log r: unbiased for KL(p || q), and non-negative because
    # e^t - 1 - t >= 0 for all t.
    k3 = (logr.exp() - 1) - logr
    return k3.mean().item()
def example():
    """Compare the sampled KL estimate against two exact computations."""
    n = 100_000
    p = F.softmax(torch.rand(n), dim=0)
    q = F.softmax(torch.rand(n), dim=0)

    print("Approx KL Divergence:", approx_kl_divergence(p, q))

    # Exact KL divergence via torch.distributions.
    true_kl = torch.distributions.kl.kl_divergence(
        torch.distributions.Categorical(probs=p),
        torch.distributions.Categorical(probs=q),
    )
    print("True KL Divergence:", true_kl.item())

    # Exact KL divergence computed by hand: sum_i p_i * log(p_i / q_i).
    print("Manual KL Divergence:", (p * (p / q).log()).sum().item())
def main():
    """Script entry point: run the KL-divergence comparison demo."""
    example()


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment