Skip to content

Instantly share code, notes, and snippets.

@yoniLavi
Last active June 30, 2025 19:43
Show Gist options
  • Save yoniLavi/54470b8ec9c17b21a75a8a1dff603019 to your computer and use it in GitHub Desktop.
Save yoniLavi/54470b8ec9c17b21a75a8a1dff603019 to your computer and use it in GitHub Desktop.
Discrete probability distribution simulator from first principles, written with the help of Claude Sonnet 4
"""
Discrete probability distribution simulator with proper dependency tracking.

This module implements discrete probability distributions from first principles,
correctly handling complex dependencies between random variables.

Examples:
    >>> die = RandomVariable.die()
    >>> coin = RandomVariable.coin()

    # Independent variables: full Cartesian product
    >>> (die + coin).possible_values()
    [1, 2, 3, 4, 5, 6, 7]

    # Same variable: dependent operations
    >>> (die - die).possible_values()  # Always zero
    [0]

    # Dependency tracking: (X + X) + X correctly recognizes all three X's are the same
    >>> from fractions import Fraction
    >>> x = RandomVariable("x", Distribution({1: Fraction(1, 2), 2: Fraction(1, 2)}))
    >>> ((x + x) + x).possible_values()  # 3*1=3 or 3*2=6
    [3, 6]

    # Test case that should work - mixed dependent/independent operations
    >>> c = RandomVariable.coin()
    >>> d = RandomVariable.die()
    >>> result = c * (d + c)
    >>> result.possible_values()  # This should not raise ValueError
    [0, 2, 3, 4, 5, 6, 7]

    # Test exact arithmetic in operations
    >>> die1 = RandomVariable.die()
    >>> die2 = RandomVariable.die()
    >>> sum_dice = die1 + die2
    >>> sum_dice.expectation == 7.0  # Exact comparison to float
    True

    # Test dependent operations preserve exact arithmetic
    >>> same_die = die1 + die1
    >>> same_die.possible_values()  # Should be the doubles of the originals
    [2, 4, 6, 8, 10, 12]
    >>> len(same_die)  # Should have 6 outcomes, not 36
    6
"""
from collections import defaultdict
from collections.abc import Collection, Iterator, ItemsView, KeysView, ValuesView
from fractions import Fraction
from math import isclose, sqrt
import itertools
import operator
class Distribution:
"""A probability distribution mapping values to their probabilities.
Examples:
>>> from fractions import Fraction
>>> dist = Distribution({1: Fraction(3, 10), 2: Fraction(7, 10)})
>>> dist.expectation() # 1.7
Fraction(17, 10)
>>> len(dist)
2
>>> uniform = Distribution.uniform([1, 2, 3, 4])
>>> uniform.expectation() == 2.5
True
# Test exact arithmetic - probabilities sum to exactly 1
>>> die = Distribution.uniform([1, 2, 3, 4, 5, 6])
>>> sum(die.probabilities.values()) == 1
True
>>> die.probabilities[1]
Fraction(1, 6)
# Test Bernoulli with exact fractions
>>> fair_coin = Distribution.bernoulli(Fraction(1, 2))
>>> fair_coin.probabilities[0] + fair_coin.probabilities[1] == 1
True
"""
def __init__(self, probabilities: dict[float, Fraction]):
# Store probabilities as exact Fractions
self.probabilities = probabilities
self.validate()
def validate(self) -> None:
"""Validate that probabilities sum to 1."""
total_prob = sum(self.probabilities.values())
if total_prob != 1:
raise ValueError(f"Probabilities must sum to 1, got {total_prob}")
def copy(self) -> "Distribution":
"""Create a copy of this distribution."""
return Distribution(self.probabilities)
def possible_values(self) -> list[float]:
"""Return sorted list of all possible values."""
return sorted(self.probabilities.keys())
def expectation(self) -> float:
"""Calculate the expectation (mean) of the distribution."""
return sum(value * prob for value, prob in self.probabilities.items())
@classmethod
def from_outcomes(cls, outcomes: dict["Outcome", Fraction]) -> "Distribution":
"""Create a Distribution from a dict of Outcomes to probabilities."""
probabilities: dict[float, Fraction] = defaultdict(Fraction)
for outcome, prob in outcomes.items():
probabilities[outcome.value] += prob
return cls(probabilities)
@classmethod
def uniform(cls, values: Collection[float]) -> "Distribution":
"""Create a uniform distribution over the given values."""
prob = Fraction(1, len(values))
return cls({value: prob for value in values})
@classmethod
def bernoulli(cls, p: Fraction) -> "Distribution":
"""Create a Bernoulli distribution with success probability p."""
return cls({0: 1 - p, 1: p})
@classmethod
def constant(cls, value: float) -> "Distribution":
"""Create a degenerate distribution (always returns the same value)."""
return cls({value: Fraction(1)})
def __getitem__(self, key: float) -> Fraction:
return self.probabilities[key]
def __iter__(self) -> Iterator[float]:
return iter(self.probabilities)
def items(self) -> ItemsView[float, Fraction]:
return self.probabilities.items()
def values(self) -> ValuesView[Fraction]:
return self.probabilities.values()
def keys(self) -> KeysView[float]:
return self.probabilities.keys()
def __len__(self) -> int:
"""Return the number of possible values."""
return len(self.probabilities)
def __bool__(self) -> bool:
"""Return True if distribution has values."""
return bool(self.probabilities)
def __repr__(self) -> str:
return f"Distribution({self.probabilities})"
class Outcome:
    """One concrete realization drawn from the sample space.

    An outcome pairs two pieces of information:
    - value: the observed value the random variable maps to
    - variable_assignments: for every primitive variable involved, the value
      it took (keyed by the variable's integer ID) to produce this outcome
    """

    def __init__(self, value: float, variable_assignments: dict[int, float]):
        self.variable_assignments = variable_assignments
        self.value = value

    def shared_variables_with(self, other: "Outcome") -> set[int]:
        """Return the IDs of the variables that appear in both outcomes."""
        return self.variable_assignments.keys() & other.variable_assignments.keys()

    def is_compatible_with(self, other: "Outcome") -> bool:
        """Tell whether every shared variable has the same value in both outcomes."""
        mine = self.variable_assignments
        theirs = other.variable_assignments
        for variable in self.shared_variables_with(other):
            if mine[variable] == theirs[variable]:
                continue
            # One disagreement on a shared variable is enough to rule it out.
            return False
        return True

    def combine_with(self, other: "Outcome", operation) -> "Outcome":
        """Build the outcome obtained by applying *operation* to both values.

        The new outcome carries the union of both assignment dicts.

        Raises:
            ValueError: If the two outcomes disagree on a shared variable.
        """
        if not self.is_compatible_with(other):
            raise ValueError("Cannot combine incompatible outcomes")
        union = {**self.variable_assignments, **other.variable_assignments}
        return Outcome(operation(self.value, other.value), union)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Outcome):
            return False
        same_value = self.value == other.value
        return same_value and self.variable_assignments == other.variable_assignments

    def __hash__(self) -> int:
        # Dicts are unhashable, so hash an order-independent tuple of the
        # (variable, value) pairs instead.
        pairs = sorted(self.variable_assignments.items())
        return hash((self.value, tuple(pairs)))

    def __repr__(self) -> str:
        return f"Outcome(value={self.value}, variables={self.variable_assignments})"
class RandomVariable:
    """A random variable with a discrete probability distribution.

    A variable built through the constructor is a *primitive* source of
    randomness and receives a unique integer ID.  Arithmetic between
    variables yields derived variables whose outcomes record which primitive
    variables (and which of their values) produced them, so reusing the same
    variable in an expression is treated as dependent.  Distinct primitive
    variables are assumed to be mutually independent.

    Args:
        name: Descriptive name for the random variable
        distribution: Distribution object

    Examples:
        # Operands that merely overlap (neither contains the other) are
        # still combined with the correct joint probabilities.
        >>> a, b, c = (RandomVariable.coin(name=n) for n in "abc")
        >>> product = (a + b) * (b + c)
        >>> product.possible_values()
        [0, 1, 2, 4]
        >>> product.probability_of(0)
        Fraction(3, 8)
    """

    # Source of primitive-variable IDs.  id(self) is deliberately not used:
    # objects with non-overlapping lifetimes may share an id, which would
    # make a new variable look identical to one that was garbage collected
    # but is still referenced by ID inside a derived variable's outcomes.
    _id_source = itertools.count()

    def __init__(self, name: str, distribution: Distribution):
        self.name = name
        if not distribution:
            raise ValueError("Distribution cannot be empty")
        self._variable_id = next(RandomVariable._id_source)
        # One outcome per value, each tagged with this variable's own ID.
        self.outcomes: dict[Outcome, Fraction] = {
            Outcome(value, {self._variable_id: value}): prob
            for value, prob in distribution.items()
        }

    @classmethod
    def _from_outcomes(
        cls, name: str, outcomes: dict[Outcome, Fraction]
    ) -> "RandomVariable":
        """Internal constructor for creating RandomVariable from outcomes.

        Bypasses ``__init__`` so the derived variable keeps the assignments
        already recorded in *outcomes* instead of getting a new primitive ID.
        """
        instance = cls.__new__(cls)
        instance.name = name
        instance.outcomes = outcomes
        return instance

    def possible_values(self) -> list[float]:
        """Return sorted list of all possible values."""
        return Distribution.from_outcomes(self.outcomes).possible_values()

    @property
    def expectation(self) -> Fraction | float:
        """Calculate the expectation (mean) of the distribution."""
        return Distribution.from_outcomes(self.outcomes).expectation()

    @property
    def variance(self) -> Fraction | float:
        """Calculate the variance using Var(X) = E[X²] - E[X]²."""
        return (self * self).expectation - self.expectation**2

    @property
    def std_dev(self) -> float:
        """Calculate the standard deviation as sqrt(variance)."""
        return sqrt(self.variance)

    def _marginal_probability(
        self, reference: Outcome, variables: set[int]
    ) -> Fraction:
        """Return the probability that *variables* take the values in *reference*.

        Computed by adding up every outcome of this variable that agrees with
        *reference* on all of *variables*.  Relies on each of this variable's
        outcomes recording a value for every ID in *variables*.
        """
        pinned = [(var, reference.variable_assignments[var]) for var in variables]
        total = Fraction(0)
        for outcome, prob in self.outcomes.items():
            assignments = outcome.variable_assignments
            if all(assignments[var] == value for var, value in pinned):
                total += prob
        return total

    def _compute_joint_probability(
        self,
        outcome1: Outcome,
        prob1: Fraction,
        outcome2: Outcome,
        prob2: Fraction,
        marginal_cache: dict | None = None,
    ) -> Fraction:
        """Compute the joint probability of two compatible outcomes.

        With independent primitive variables, the two outcomes are
        conditionally independent given the variables they share, so

            P(o1 and o2) = P(o1) * P(o2) / P(shared assignment)

        With no shared variables this is the plain product.  When one
        outcome's variables contain the other's it reduces to the smaller of
        the two probabilities.

        Args:
            outcome1: An outcome of ``self`` (the marginal of the shared
                variables is computed from ``self.outcomes``).
            prob1: Probability of *outcome1*.
            outcome2: A compatible outcome of the other operand.
            prob2: Probability of *outcome2*.
            marginal_cache: Optional dict reused across calls to avoid
                recomputing the marginal of the same shared assignment.
        """
        shared_vars = outcome1.shared_variables_with(outcome2)
        if not shared_vars:
            return prob1 * prob2

        cache = {} if marginal_cache is None else marginal_cache
        key = frozenset(
            (var, outcome1.variable_assignments[var]) for var in shared_vars
        )
        try:
            shared_prob = cache[key]
        except KeyError:
            shared_prob = cache[key] = self._marginal_probability(
                outcome1, shared_vars
            )

        # A zero marginal means outcome1 itself has probability zero; return
        # zero directly rather than dividing by it.
        if not shared_prob:
            return Fraction(0)
        return prob1 * prob2 / shared_prob

    def _combine_outcomes(
        self, other: "RandomVariable", operation, op_symbol: str
    ) -> "RandomVariable":
        """Helper method to combine outcomes from two random variables using proper joint probabilities.

        Every pair of compatible outcomes (one from each operand) produces one
        outcome of the result; incompatible pairs are skipped.
        """
        result_outcomes: dict[Outcome, Fraction] = defaultdict(Fraction)
        marginal_cache: dict = {}
        for (outcome1, prob1), (outcome2, prob2) in itertools.product(
            self.outcomes.items(), other.outcomes.items()
        ):
            if not outcome1.is_compatible_with(outcome2):
                continue
            new_outcome = outcome1.combine_with(outcome2, operation)
            joint_prob = self._compute_joint_probability(
                outcome1, prob1, outcome2, prob2, marginal_cache
            )
            result_outcomes[new_outcome] += joint_prob
        return RandomVariable._from_outcomes(
            f"({self.name} {op_symbol} {other.name})", dict(result_outcomes)
        )

    def __add__(self, other: "RandomVariable") -> "RandomVariable":
        """Add two random variables.

        Examples:
            >>> die1 = RandomVariable.die()
            >>> die2 = RandomVariable.die()
            >>> sum_dice = die1 + die2
            >>> sum_dice.possible_values()
            [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
            >>> sum_dice.expectation == 7
            True
            >>> same_die = die1 + die1  # Same variable - dependent
            >>> same_die.possible_values()
            [2, 4, 6, 8, 10, 12]
        """
        return self._combine_outcomes(other, operator.add, "+")

    def __sub__(self, other: "RandomVariable") -> "RandomVariable":
        """Subtract two random variables."""
        return self._combine_outcomes(other, operator.sub, "-")

    def __mul__(self, other: "RandomVariable") -> "RandomVariable":
        """Multiply two random variables."""
        return self._combine_outcomes(other, operator.mul, "*")

    def __truediv__(self, other: "RandomVariable") -> "RandomVariable":
        """Divide two random variables.

        Uses ``operator.truediv``, so a possible divisor value of zero raises
        ``ZeroDivisionError`` and integer values produce floats.
        """
        return self._combine_outcomes(other, operator.truediv, "/")

    def __len__(self) -> int:
        """Return the number of possible values."""
        return len(self.possible_values())

    def __bool__(self) -> bool:
        """Return True if the random variable has possible outcomes."""
        return bool(self.outcomes)

    def probability_of(self, value: float) -> Fraction:
        """Return the probability of a specific value (zero if impossible)."""
        return Distribution.from_outcomes(self.outcomes).probabilities.get(
            value, Fraction(0)
        )

    @classmethod
    def die(cls, sides: int = 6, name: str | None = None) -> "RandomVariable":
        """Create a fair die with the given number of sides.

        Args:
            sides: Number of sides (default: 6)
            name: Name for the die (default: "d{sides}")

        Examples:
            >>> die = RandomVariable.die()
            >>> die.possible_values()
            [1, 2, 3, 4, 5, 6]
            >>> die.expectation == 3.5
            True
            >>> d20 = RandomVariable.die(20)
            >>> len(d20)
            20
        """
        name = name or f"d{sides}"
        return cls(name, Distribution.uniform(range(1, sides + 1)))

    @classmethod
    def coin(cls, p: Fraction = Fraction(1, 2), name: str = "coin") -> "RandomVariable":
        """Create a coin flip with heads probability p.

        Args:
            p: Probability of heads (1) vs tails (0)
            name: Name for the coin

        Examples:
            >>> coin = RandomVariable.coin()
            >>> coin.possible_values()
            [0, 1]
            >>> coin.expectation
            Fraction(1, 2)
            >>> from fractions import Fraction
            >>> biased = RandomVariable.coin(Fraction(7, 10))
            >>> biased.probability_of(0)
            Fraction(3, 10)

            # Test exact probability calculations
            >>> fair = RandomVariable.coin()
            >>> fair.outcomes[list(fair.outcomes.keys())[0]]  # Should be exact Fraction(1,2)
            Fraction(1, 2)
        """
        return cls(name, Distribution.bernoulli(p))

    def __repr__(self) -> str:
        # Collapse outcomes to a value -> probability mapping for display and
        # show it as a plain dict, whatever mapping type the Distribution holds.
        distribution = Distribution.from_outcomes(self.outcomes)
        return f"RandomVariable('{self.name}', {dict(distribution.probabilities)})"
if __name__ == "__main__":
    # When executed as a script, run every doctest embedded in this module.
    from doctest import testmod

    testmod()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment