Created
November 16, 2024 00:19
-
-
Save dapurv5/a77d20aa61a9b51797951f65c76d1314 to your computer and use it in GitHub Desktop.
Approx KL Divergence (Courtesy John Schulman)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn.functional as F | |
def approx_kl_divergence(
    p: torch.Tensor, q: torch.Tensor, num_samples: int = 10_000
) -> float:
    """
    Estimate KL(p || q) between two categorical distributions using the k3 estimator.

    http://joschu.net/blog/kl-approx.html

    KL(p || q) is an expectation under ``p``, so samples are drawn from ``p``
    and the estimator is evaluated on the ratio ``r = q(x) / p(x)``:
    ``k3 = (r - 1) - log r``. Each k3 term is non-negative, so the returned
    estimate is never negative. The result is directly comparable to
    ``torch.distributions.kl.kl_divergence(Categorical(p), Categorical(q))``.

    Args:
        p: First distribution as 1D tensor of probabilities that sum to 1
        q: Second distribution as 1D tensor of probabilities that sum to 1
        num_samples: Number of samples to use for estimation.
            This can be lower than the size of p and q

    Returns:
        float: Monte Carlo estimate of KL(p || q)

    Raises:
        ValueError: If ``num_samples`` is less than 1 (the mean over zero
            samples would otherwise silently be NaN).
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be at least 1, got {num_samples}")

    # Create categorical distributions
    p_dist = torch.distributions.Categorical(probs=p)
    q_dist = torch.distributions.Categorical(probs=q)

    # Sample from p: the expectation defining KL(p || q) is taken under p
    x = p_dist.sample(sample_shape=(num_samples,))

    # log r, where r = q(x) / p(x)
    logr = q_dist.log_prob(x) - p_dist.log_prob(x)

    # k3 = (r - 1) - log r; expm1 avoids cancellation when r is close to 1
    k3 = torch.expm1(logr) - logr
    return k3.mean().item()
def example():
    """
    Print a sampled KL estimate next to two exact reference values.

    Builds two random categorical distributions (softmax of uniform noise),
    prints the estimate returned by ``approx_kl_divergence(p, q)``, then
    prints the exact KL(p || q) computed by ``torch.distributions`` and by
    the closed-form sum ``sum(p * log(p / q))``.
    """
    vocab_size = 100_000

    # Draw p first, then q, so the random stream is consumed in a fixed order
    p = F.softmax(torch.rand(vocab_size), dim=0)
    q = F.softmax(torch.rand(vocab_size), dim=0)

    estimate = approx_kl_divergence(p, q)
    print("Approx KL Divergence:", estimate)

    # Exact value from torch's registered Categorical/Categorical KL
    exact = torch.distributions.kl.kl_divergence(
        torch.distributions.Categorical(probs=p),
        torch.distributions.Categorical(probs=q),
    )
    print("True KL Divergence:", exact.item())

    # Same quantity written out analytically
    closed_form = torch.sum(p * torch.log(p / q)).item()
    print("Manual KL Divergence:", closed_form)
def main() -> None:
    """Script entry point: run the KL divergence comparison demo."""
    example()
# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment