Created
April 7, 2025 21:31
-
-
Save vwxyzjn/635a2abb234a5e665767e175b7ddb21b to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
def stable_softmax(x, axis=None):
    """Compute a numerically stable softmax (adapted from scipy.special.softmax).

    The maximum along ``axis`` is subtracted before exponentiating, so the
    largest exponent evaluated is ``exp(0)`` and large inputs do not overflow.

    Args:
        x: Input values.
        axis: Axis to normalize over; ``None`` normalizes over all elements.

    Returns:
        Array with the same shape as ``x`` whose entries sum to 1 along ``axis``.
    """
    shifted = x - np.amax(x, axis=axis, keepdims=True)
    numerator = np.exp(shifted)
    denominator = np.sum(numerator, axis=axis, keepdims=True)
    return numerator / denominator
def get_prob(arr: np.ndarray, temp: float) -> np.ndarray:
    """Convert ``arr`` into a probability distribution at temperature ``temp``.

    Args:
        arr: Unnormalized scores (logits).
        temp: Softmax temperature; values below 1e-6 are raised to 1e-6 so the
            division never hits zero.

    Returns:
        Softmax of ``arr / temp`` taken over all elements of ``arr``.
    """
    effective_temp = max(temp, 1e-6)
    tempered = arr / effective_temp
    return stable_softmax(tempered)
def gumbel_sample_without_replacement(logits, n: int, temperature=1.0, eps=1e-10):
    """Sample n items without replacement using the Gumbel-max (top-k) trick.

    Independent Gumbel(0, 1) noise is added to the temperature-scaled logits
    and the indices of the n largest perturbed values are returned.

    Args:
        logits: Unnormalized log probabilities (array-like; the code indexes
            it as a 1-D sequence).
        n: Number of items to sample (0 <= n <= len(logits)).
        temperature: Temperature parameter (lower = more deterministic).
            Values below ``eps`` are raised to ``eps``.
        eps: Small constant for numerical stability (temperature floor and
            guard inside the logarithms).

    Returns:
        Integer array of n distinct indices, ordered from the highest to the
        lowest perturbed logit. Empty when ``n == 0``.

    Raises:
        ValueError: If n is negative or larger than ``len(logits)``.
    """
    # Accept lists/tuples as well as arrays; the division below needs an ndarray.
    logits = np.asarray(logits)

    if n < 0:
        raise ValueError(f"Cannot sample a negative number of items (n={n})")
    if n > len(logits):
        raise ValueError(f"Cannot sample {n} items from a distribution of size {len(logits)}")
    if n == 0:
        # Without this guard the slice ``[-0:]`` below would be ``[0:]`` and
        # return *every* index instead of none.
        return np.empty(0, dtype=np.intp)

    # Scale logits by temperature
    scaled_logits = logits / max(temperature, eps)

    # Sample from Gumbel(0, 1) via inverse transform of uniform noise
    u = np.random.random(scaled_logits.shape)
    g = -np.log(-np.log(u + eps) + eps)

    # Add Gumbel noise to logits
    noisy_logits = scaled_logits + g

    # Get the indices of the n largest values (unordered);
    # this gives us n samples without replacement
    indices = np.argpartition(noisy_logits, -n)[-n:]

    # Order the selected indices by their noisy logit values (highest to
    # lowest) so the result reflects the sampling order
    return indices[np.argsort(-noisy_logits[indices])]
def sample_based_on_success_rate(success_rate, temperature, n=2):
    """Pick indices so that entries with lower success rates are favored.

    Each entry's logit is ``1 - success_rate``, so a lower success rate means
    a higher logit.

    Args:
        success_rate: Array of success rates.
        temperature: Temperature parameter (lower = more deterministic).
        n: Number of items to sample without replacement.

    Returns:
        Tuple of:
            - the indices drawn by gumbel_sample_without_replacement
            - the probability distribution computed by get_prob
    """
    failure_logits = 1 - success_rate
    return (
        gumbel_sample_without_replacement(failure_logits, n, temperature=temperature),
        get_prob(failure_logits, temperature),
    )
# Example usage: run the sampler once per temperature and print the results.
if __name__ == "__main__":
    success_rate = np.array([1.0, 1.0, 0.5, 0.4, 0.3, 0.2, 0.1, 0.0])
    print(f"success_rate: {success_rate}")

    # 0.0 -> deterministic, 1.0 -> more random
    for temperature in (0.0, 1.0):
        print("-" * 100)
        indices, probs = sample_based_on_success_rate(success_rate, temperature)
        print(f"temperature: {temperature}")
        print(f"gumbel_sample_without_replacement: {indices}")
        print(f"get_prob: {probs}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment