Skip to content

Instantly share code, notes, and snippets.

@vwxyzjn
Created April 7, 2025 21:31
Show Gist options
  • Save vwxyzjn/635a2abb234a5e665767e175b7ddb21b to your computer and use it in GitHub Desktop.
Save vwxyzjn/635a2abb234a5e665767e175b7ddb21b to your computer and use it in GitHub Desktop.
import numpy as np
def stable_softmax(x, axis=None):
    """Return softmax(x) along ``axis`` without overflowing (taken from scipy.special.softmax).

    The maximum along ``axis`` is subtracted before exponentiating, which leaves
    the result mathematically unchanged but keeps ``np.exp`` from overflowing on
    large inputs. With ``axis=None`` the normalisation runs over every element.
    The output has the same shape as ``x``.
    """
    shift = np.max(x, axis=axis, keepdims=True)
    weights = np.exp(x - shift)
    normaliser = np.sum(weights, axis=axis, keepdims=True)
    return weights / normaliser
def get_prob(arr: np.ndarray, temp: float) -> np.ndarray:
    """Turn scores ``arr`` into a probability distribution at temperature ``temp``.

    ``temp`` is floored at 1e-6 before dividing, so a zero or negative
    temperature does not divide by zero. It instead yields an almost one-hot
    distribution on the largest entries. The softmax is taken over all
    elements of ``arr``.
    """
    min_temp = 1e-6
    scaled = arr / max(temp, min_temp)
    return stable_softmax(scaled)
def gumbel_sample_without_replacement(logits, n: int, temperature=1.0, eps=1e-10):
    """
    Sample n items without replacement using the Gumbel-max (Gumbel-top-k) trick.

    Independent Gumbel(0, 1) noise is added to the temperature-scaled logits,
    and the indices of the n largest perturbed values are returned. This is
    equivalent to drawing n items one at a time from
    softmax(logits / temperature), removing each item once it has been drawn.

    Args:
        logits: 1-D array-like of unnormalized log probabilities.
        n: Number of items to sample; must satisfy 0 <= n <= len(logits).
        temperature: Temperature parameter (lower = more deterministic). It is
            floored at ``eps``, so 0 or a negative value behaves like an
            (almost) greedy top-n.
        eps: Small constant used both as the temperature floor and to keep the
            logarithms of the Gumbel noise finite.

    Returns:
        Integer array of n distinct indices, ordered from highest to lowest
        perturbed logit (i.e. the order in which sequential draws would have
        produced them). Empty when n == 0.

    Raises:
        ValueError: If n is negative or larger than len(logits).
    """
    # Accept plain sequences as well as ndarrays; `list / float` would raise.
    logits = np.asarray(logits)
    if n < 0:
        raise ValueError(f"Cannot sample a negative number of items (n={n})")
    if n > len(logits):
        raise ValueError(f"Cannot sample {n} items from a distribution of size {len(logits)}")
    if n == 0:
        # Must be special-cased: np.argpartition(x, -0)[-0:] is x[0:], which
        # would select *every* index instead of none.
        return np.empty(0, dtype=np.intp)

    # Scale logits by temperature; the eps floor avoids division by zero.
    scaled_logits = logits / max(temperature, eps)

    # Gumbel(0, 1) noise by inverse transform, -log(-log(U)) with U ~ Uniform[0, 1).
    # The eps terms keep the logarithms away from log(0) at the edges of the
    # uniform range (at the cost of a negligible bias).
    u = np.random.random(scaled_logits.shape)
    g = -np.log(-np.log(u + eps) + eps)
    noisy_logits = scaled_logits + g

    # The n largest perturbed logits form a sample of n items without
    # replacement. argpartition finds them without a full sort, but returns
    # them in arbitrary order.
    indices = np.argpartition(noisy_logits, -n)[-n:]

    # Order the picks from highest to lowest perturbed logit. This does not
    # change *which* items were sampled, only reports them in draw order.
    return indices[np.argsort(-noisy_logits[indices])]
def sample_based_on_success_rate(success_rate, temperature, n=2):
    """
    Sample indices based on success rate and temperature.

    Each item's logit is ``1 - success_rate``, so items that succeed less often
    get higher logits and are more likely to be picked.

    Args:
        success_rate: 1-D array-like of success rates (a plain list is accepted).
        temperature: Temperature parameter (lower = more deterministic).
        n: Number of items to sample without replacement.

    Returns:
        Tuple containing:
        - Indices sampled using gumbel_sample_without_replacement
        - Probability distribution from get_prob, i.e. a softmax of the logits
          at this temperature. NOTE: this is the distribution of a single
          draw, not the probability that an index appears among the n
          sampled ones.
    """
    # np.asarray lets plain Python sequences through; `1 - list` raises TypeError.
    logits = 1 - np.asarray(success_rate)
    sampled_indices = gumbel_sample_without_replacement(logits, n, temperature=temperature)
    probabilities = get_prob(logits, temperature)
    return sampled_indices, probabilities
# Example usage: sample the same success rates at a near-deterministic
# temperature (0.0) and at a more random one (1.0).
if __name__ == "__main__":
    success_rate = np.array([1.0, 1.0, 0.5, 0.4, 0.3, 0.2, 0.1, 0.0])
    print(f"success_rate: {success_rate}")
    for temperature in (0.0, 1.0):
        print("-" * 100)
        indices, probs = sample_based_on_success_rate(success_rate, temperature)
        print(f"temperature: {temperature}")
        print(f"gumbel_sample_without_replacement: {indices}")
        print(f"get_prob: {probs}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment