Last active
June 30, 2025 19:43
-
-
Save yoniLavi/54470b8ec9c17b21a75a8a1dff603019 to your computer and use it in GitHub Desktop.
Discrete probability distribution simulator from first principles, written with the help of Claude Sonnet 4
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Discrete probability distribution simulator with proper dependency tracking. | |
This module implements discrete probability distributions from first principles, | |
correctly handling complex dependencies between random variables. | |
Examples: | |
>>> die = RandomVariable.die() | |
>>> coin = RandomVariable.coin() | |
# Independent variables: full Cartesian product | |
>>> (die + coin).possible_values() | |
[1, 2, 3, 4, 5, 6, 7] | |
# Same variable: dependent operations | |
>>> (die - die).possible_values() # Always zero | |
[0] | |
# Dependency tracking: (X + X) + X correctly recognizes all three X's are the same | |
>>> from fractions import Fraction | |
>>> x = RandomVariable("x", Distribution({1: Fraction(1, 2), 2: Fraction(1, 2)})) | |
>>> ((x + x) + x).possible_values() # 3*1=3 or 3*2=6 | |
[3, 6] | |
# Test case that should work - mixed dependent/independent operations | |
>>> c = RandomVariable.coin() | |
>>> d = RandomVariable.die() | |
>>> result = c * (d + c) | |
>>> result.possible_values() # This should not raise ValueError | |
[0, 2, 3, 4, 5, 6, 7] | |
# Test exact arithmetic in operations | |
>>> die1 = RandomVariable.die() | |
>>> die2 = RandomVariable.die() | |
>>> sum_dice = die1 + die2 | |
>>> sum_dice.expectation == 7.0 # Exact comparison to float | |
True | |
# Test dependent operations preserve exact arithmetic | |
>>> same_die = die1 + die1 | |
>>> same_die.possible_values() # Should be the doubles of the originals | |
[2, 4, 6, 8, 10, 12] | |
>>> len(same_die) # Should have 6 outcomes, not 36 | |
6 | |
""" | |
from collections import defaultdict | |
from collections.abc import Collection, Iterator, ItemsView, KeysView, ValuesView | |
from fractions import Fraction | |
from math import isclose, sqrt | |
import itertools | |
import operator | |
class Distribution:
    """A discrete probability distribution mapping values to probabilities.

    The mapping handed to the constructor is copied into a plain ``dict`` and
    validated: every probability must be non-negative and the probabilities
    must sum to exactly 1. Using ``Fraction`` probabilities keeps that check
    (and all derived quantities) exact.

    Examples:
        >>> from fractions import Fraction
        >>> dist = Distribution({1: Fraction(3, 10), 2: Fraction(7, 10)})
        >>> dist.expectation()  # 1.7
        Fraction(17, 10)
        >>> len(dist)
        2
        >>> uniform = Distribution.uniform([1, 2, 3, 4])
        >>> uniform.expectation() == 2.5
        True

        # Test exact arithmetic - probabilities sum to exactly 1
        >>> die = Distribution.uniform([1, 2, 3, 4, 5, 6])
        >>> sum(die.probabilities.values()) == 1
        True
        >>> die.probabilities[1]
        Fraction(1, 6)

        # Test Bernoulli with exact fractions
        >>> fair_coin = Distribution.bernoulli(Fraction(1, 2))
        >>> fair_coin.probabilities[0] + fair_coin.probabilities[1] == 1
        True

        # A copy owns its own mapping
        >>> clone = fair_coin.copy()
        >>> clone.probabilities is fair_coin.probabilities
        False
    """

    def __init__(self, probabilities: dict[float, Fraction]):
        """Store a snapshot of ``probabilities`` and validate it.

        Args:
            probabilities: Mapping from each possible value to its probability.

        Raises:
            ValueError: If any probability is negative or the total is not 1.
        """
        # Snapshot into a plain dict so that (a) later mutation of the caller's
        # mapping cannot invalidate an already-validated distribution and
        # (b) a defaultdict passed in (as from_outcomes does) does not make
        # lookups of unknown values silently insert new entries.
        self.probabilities = dict(probabilities)
        self.validate()

    def validate(self) -> None:
        """Check that probabilities are non-negative and sum to exactly 1.

        Raises:
            ValueError: If a probability is negative or the total differs from 1.
        """
        negative = {
            value: prob for value, prob in self.probabilities.items() if prob < 0
        }
        if negative:
            raise ValueError(f"Probabilities must be non-negative, got {negative}")
        total_prob = sum(self.probabilities.values())
        if total_prob != 1:
            raise ValueError(f"Probabilities must sum to 1, got {total_prob}")

    def copy(self) -> "Distribution":
        """Return a new Distribution with its own independent mapping."""
        # The constructor snapshots the mapping, so the copy shares no state.
        return Distribution(self.probabilities)

    def possible_values(self) -> list[float]:
        """Return a sorted list of all possible values."""
        return sorted(self.probabilities)

    def expectation(self) -> float:
        """Return the expectation (mean): the sum of value * probability.

        The result is an exact ``Fraction`` when values and probabilities are
        integers/Fractions.
        """
        return sum(value * prob for value, prob in self.probabilities.items())

    @classmethod
    def from_outcomes(cls, outcomes: dict["Outcome", Fraction]) -> "Distribution":
        """Create a Distribution by summing outcome probabilities per value.

        Args:
            outcomes: Mapping from outcome objects (anything exposing a
                ``value`` attribute) to their probabilities.
        """
        probabilities: dict[float, Fraction] = defaultdict(Fraction)
        for outcome, prob in outcomes.items():
            # Distinct outcomes may map to the same observed value.
            probabilities[outcome.value] += prob
        return cls(probabilities)

    @classmethod
    def uniform(cls, values: Collection[float]) -> "Distribution":
        """Create a uniform distribution over the given values.

        Raises:
            ValueError: If ``values`` is empty.
        """
        count = len(values)
        if count == 0:
            raise ValueError("Cannot create a uniform distribution over no values")
        prob = Fraction(1, count)
        return cls({value: prob for value in values})

    @classmethod
    def bernoulli(cls, p: Fraction) -> "Distribution":
        """Create a Bernoulli distribution: 1 with probability p, else 0."""
        return cls({0: 1 - p, 1: p})

    @classmethod
    def constant(cls, value: float) -> "Distribution":
        """Create a degenerate distribution (always returns the same value)."""
        return cls({value: Fraction(1)})

    def __getitem__(self, key: float) -> Fraction:
        """Return the probability of ``key``; raises KeyError if unknown."""
        return self.probabilities[key]

    def __iter__(self) -> Iterator[float]:
        """Iterate over the possible values."""
        return iter(self.probabilities)

    def items(self) -> ItemsView[float, Fraction]:
        """Return a view of (value, probability) pairs."""
        return self.probabilities.items()

    def values(self) -> ValuesView[Fraction]:
        """Return a view of the probabilities."""
        return self.probabilities.values()

    def keys(self) -> KeysView[float]:
        """Return a view of the possible values."""
        return self.probabilities.keys()

    def __len__(self) -> int:
        """Return the number of possible values."""
        return len(self.probabilities)

    def __bool__(self) -> bool:
        """Return True if distribution has values."""
        return bool(self.probabilities)

    def __repr__(self) -> str:
        return f"Distribution({self.probabilities})"
class Outcome:
    """An outcome in probability theory - a specific realization from the sample space.

    Each outcome consists of:
    - value: The observed value (what the random variable maps to)
    - variable_assignments: Which primitive variables (by integer ID) took
      which values to produce this outcome

    Outcomes are hashable and compare equal when both the value and the
    assignments are equal, so they can be used as dictionary keys.
    """

    def __init__(self, value: float, variable_assignments: dict[int, float]):
        self.value = value
        # Keep a private copy: the assignments feed __eq__ and __hash__, so
        # sharing the caller's dict would let later mutation of that dict
        # change this outcome's hash while it is in use as a dict key.
        self.variable_assignments = dict(variable_assignments)

    def shared_variables_with(self, other: "Outcome") -> set[int]:
        """Return the set of variable IDs that both outcomes depend on."""
        return set(self.variable_assignments) & set(other.variable_assignments)

    def is_compatible_with(self, other: "Outcome") -> bool:
        """Check if outcomes are compatible (same values for any shared variables).

        Outcomes with no shared variables are always compatible.
        """
        shared_vars = self.shared_variables_with(other)
        return all(
            self.variable_assignments[v] == other.variable_assignments[v]
            for v in shared_vars
        )

    def combine_with(self, other: "Outcome", operation) -> "Outcome":
        """Combine this outcome with another using the given operation.

        Args:
            other: The outcome to combine with.
            operation: Binary callable applied as ``operation(self.value, other.value)``.

        Returns:
            A new Outcome holding the resulting value and the union of both
            assignment mappings.

        Raises:
            ValueError: If the outcomes disagree on a shared variable.
        """
        if not self.is_compatible_with(other):
            raise ValueError("Cannot combine incompatible outcomes")
        # Compatible outcomes agree on every shared key, so merge order is irrelevant.
        merged_assignments = self.variable_assignments | other.variable_assignments
        result_value = operation(self.value, other.value)
        return Outcome(result_value, merged_assignments)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Outcome):
            # Let Python try the reflected comparison; it ends up False.
            return NotImplemented
        return (
            self.value == other.value
            and self.variable_assignments == other.variable_assignments
        )

    def __hash__(self) -> int:
        # Make the dict hashable by converting to a sorted tuple; keys are
        # unique, so sorting never has to compare the assigned values.
        assignments_tuple = tuple(sorted(self.variable_assignments.items()))
        return hash((self.value, assignments_tuple))

    def __repr__(self) -> str:
        return f"Outcome(value={self.value}, variables={self.variable_assignments})"
class RandomVariable:
    """A random variable with a discrete probability distribution.

    A variable built directly from a Distribution is a *primitive* variable and
    receives a unique integer ID. Every outcome records which primitive
    variables took which values, so arithmetic between variables that share
    primitives (fully or partially) only pairs up consistent outcomes and
    weights them with the correct joint probability. Distinct primitive
    variables are treated as independent.

    Args:
        name: Descriptive name for the random variable
        distribution: Distribution object

    Examples:
        # Partially overlapping dependencies: both factors share the coin
        >>> c = RandomVariable.coin()
        >>> d = RandomVariable.die()
        >>> e = RandomVariable.die()
        >>> ((c + d) * (c + e)).expectation
        Fraction(65, 4)
    """

    # Source of primitive-variable IDs. A counter is used instead of id(self)
    # because two objects with non-overlapping lifetimes may share an id();
    # a fresh variable could then be mistaken for a garbage-collected one that
    # is still referenced inside the outcomes of a derived variable.
    _variable_ids = itertools.count()

    def __init__(self, name: str, distribution: Distribution):
        self.name = name
        if not distribution:
            raise ValueError("Distribution cannot be empty")
        variable_id = next(self._variable_ids)
        # One outcome per possible value, each tagged with this variable's ID.
        self.outcomes: dict[Outcome, Fraction] = {
            Outcome(value, {variable_id: value}): prob
            for value, prob in distribution.items()
        }

    @classmethod
    def _from_outcomes(
        cls, name: str, outcomes: dict[Outcome, Fraction]
    ) -> "RandomVariable":
        """Internal constructor for creating RandomVariable from outcomes."""
        # Bypass __init__: a derived variable is not primitive and needs no ID.
        instance = cls.__new__(cls)
        instance.name = name
        instance.outcomes = outcomes
        return instance

    def possible_values(self) -> list[float]:
        """Return sorted list of all possible values."""
        return Distribution.from_outcomes(self.outcomes).possible_values()

    @property
    def expectation(self) -> float:
        """Calculate the expectation (mean) of the distribution."""
        return Distribution.from_outcomes(self.outcomes).expectation()

    @property
    def variance(self) -> float:
        """Calculate the variance using Var(X) = E[X²] - E[X]²."""
        # self * self is fully dependent, so it yields the distribution of X².
        return (self * self).expectation - self.expectation**2

    @property
    def std_dev(self) -> float:
        """Calculate the standard deviation as sqrt(variance)."""
        return sqrt(self.variance)

    def _marginal_probability(self, assignment: dict[int, float]) -> Fraction:
        """Return the probability that this variable's outcomes match ``assignment``.

        Sums the probabilities of all outcomes of ``self`` whose recorded
        primitive-variable values agree with every entry of ``assignment``.
        """
        total = Fraction(0)
        for outcome, prob in self.outcomes.items():
            known = outcome.variable_assignments
            if all(
                var in known and known[var] == val
                for var, val in assignment.items()
            ):
                total += prob
        return total

    def _compute_joint_probability(
        self,
        outcome1: Outcome,
        prob1: Fraction,
        outcome2: Outcome,
        prob2: Fraction,
        marginal_cache: dict | None = None,
    ) -> Fraction:
        """Compute the joint probability of two compatible outcomes.

        With no shared primitive variables the outcomes are independent and
        the joint probability is ``prob1 * prob2``. Otherwise the shared
        assignment S is counted in both factors, so it is divided out once:
        ``P(A1 and A2) = P(A1) * P(A2) / P(S)``. When one outcome's variables
        are a subset of the other's this reduces to the smaller probability.

        Args:
            outcome1: Outcome belonging to ``self``.
            prob1: Probability of ``outcome1``.
            outcome2: Outcome of the other operand, compatible with ``outcome1``.
            prob2: Probability of ``outcome2``.
            marginal_cache: Optional dict used to memoize P(S) across calls
                for the same ``self``.

        Returns:
            The joint probability of both outcomes occurring together.
        """
        shared_vars = outcome1.shared_variables_with(outcome2)
        if not shared_vars:
            return prob1 * prob2

        shared_assignment = {
            var: outcome1.variable_assignments[var] for var in shared_vars
        }
        cache_key = frozenset(shared_assignment.items())
        if marginal_cache is not None and cache_key in marginal_cache:
            marginal = marginal_cache[cache_key]
        else:
            marginal = self._marginal_probability(shared_assignment)
            if marginal_cache is not None:
                marginal_cache[cache_key] = marginal

        if not marginal:
            # The shared assignment itself has probability zero (e.g. a
            # zero-probability value), so the joint event cannot occur.
            return Fraction(0)
        return prob1 * prob2 / marginal

    def _combine_outcomes(
        self, other: "RandomVariable", operation, op_symbol: str
    ) -> "RandomVariable":
        """Combine outcomes from two random variables using proper joint probabilities.

        Args:
            other: Right-hand operand.
            operation: Binary callable applied to each pair of outcome values.
            op_symbol: Operator text used to build the result's name.

        Returns:
            A new RandomVariable named ``(<self> <op> <other>)``.
        """
        result_outcomes: dict[Outcome, Fraction] = defaultdict(Fraction)
        marginal_cache: dict = {}
        for (outcome1, prob1), (outcome2, prob2) in itertools.product(
            self.outcomes.items(), other.outcomes.items()
        ):
            # Pairs that disagree on a shared primitive variable cannot occur.
            if not outcome1.is_compatible_with(outcome2):
                continue
            new_outcome = outcome1.combine_with(outcome2, operation)
            joint_prob = self._compute_joint_probability(
                outcome1, prob1, outcome2, prob2, marginal_cache
            )
            result_outcomes[new_outcome] += joint_prob
        return RandomVariable._from_outcomes(
            f"({self.name} {op_symbol} {other.name})", dict(result_outcomes)
        )

    def __add__(self, other: "RandomVariable") -> "RandomVariable":
        """Add two random variables.

        Examples:
            >>> die1 = RandomVariable.die()
            >>> die2 = RandomVariable.die()
            >>> sum_dice = die1 + die2
            >>> sum_dice.possible_values()
            [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
            >>> sum_dice.expectation == 7
            True
            >>> same_die = die1 + die1  # Same variable - dependent
            >>> same_die.possible_values()
            [2, 4, 6, 8, 10, 12]
        """
        return self._combine_outcomes(other, operator.add, "+")

    def __sub__(self, other: "RandomVariable") -> "RandomVariable":
        """Subtract two random variables."""
        return self._combine_outcomes(other, operator.sub, "-")

    def __mul__(self, other: "RandomVariable") -> "RandomVariable":
        """Multiply two random variables."""
        return self._combine_outcomes(other, operator.mul, "*")

    def __truediv__(self, other: "RandomVariable") -> "RandomVariable":
        """Divide two random variables."""
        return self._combine_outcomes(other, operator.truediv, "/")

    def __len__(self) -> int:
        """Return the number of possible values."""
        return len(self.possible_values())

    def __bool__(self) -> bool:
        """Return True if the random variable has possible outcomes."""
        return bool(self.outcomes)

    def probability_of(self, value: float) -> Fraction:
        """Return the probability of a specific value (0 if it cannot occur)."""
        return Distribution.from_outcomes(self.outcomes).probabilities.get(
            value, Fraction(0)
        )

    @classmethod
    def die(cls, sides: int = 6, name: str | None = None) -> "RandomVariable":
        """Create a fair die with the given number of sides.

        Args:
            sides: Number of sides (default: 6)
            name: Name for the die (default: "d{sides}")

        Examples:
            >>> die = RandomVariable.die()
            >>> die.possible_values()
            [1, 2, 3, 4, 5, 6]
            >>> die.expectation == 3.5
            True
            >>> d20 = RandomVariable.die(20)
            >>> len(d20)
            20
        """
        name = name or f"d{sides}"
        return cls(name, Distribution.uniform(range(1, sides + 1)))

    @classmethod
    def coin(cls, p: Fraction = Fraction(1, 2), name: str = "coin") -> "RandomVariable":
        """Create a coin flip with heads probability p.

        Args:
            p: Probability of heads (1) vs tails (0)
            name: Name for the coin

        Examples:
            >>> coin = RandomVariable.coin()
            >>> coin.possible_values()
            [0, 1]
            >>> coin.expectation
            Fraction(1, 2)
            >>> from fractions import Fraction
            >>> biased = RandomVariable.coin(Fraction(7, 10))
            >>> biased.probability_of(0)
            Fraction(3, 10)

            # Test exact probability calculations
            >>> fair = RandomVariable.coin()
            >>> fair.outcomes[list(fair.outcomes.keys())[0]]  # Should be exact Fraction(1,2)
            Fraction(1, 2)
        """
        return cls(name, Distribution.bernoulli(p))

    def __repr__(self) -> str:
        # Convert outcomes back to simple distribution for display; dict()
        # keeps a defaultdict wrapper from leaking into the output.
        distribution = Distribution.from_outcomes(self.outcomes)
        return f"RandomVariable('{self.name}', {dict(distribution.probabilities)})"
if __name__ == "__main__":
    # When executed as a script, run every doctest embedded in this module's
    # docstrings; doctest.testmod() reports any failing examples.
    import doctest

    doctest.testmod()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment