Created
March 26, 2022 11:40
-
-
Save WarrenWeckesser/636b537ee889679227d53543d333a720 to your computer and use it in GitHub Desktop.
Compute the mean, variance, skewness, and excess kurtosis of the truncnorm distribution using mpmath
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import mpmath

# Working precision, in significant decimal digits, for every mpmath
# computation in this script.  The stats below are assembled from moments
# whose terms can nearly cancel (e.g. for intervals far out in a tail such
# as [-40, -39]), so far more precision than a double is carried; results
# are converted to float only when they are printed.
mpmath.mp.dps = 80
def truncnorm_delta_cdf(a, b):
    """Return the standard normal probability mass of the interval [a, b].

    When the whole interval lies to the right of 0 (a > 0), the mass is
    computed from upper-tail values, ncdf(-a) - ncdf(-b).  Those are small
    numbers rather than values close to 1, so no digits are lost to
    cancellation.  Otherwise ncdf(b) - ncdf(a) is used directly.
    """
    if a > 0:
        # Reflect into the left tail: P(a < Z < b) == P(-b < Z < -a).
        return mpmath.ncdf(-a) - mpmath.ncdf(-b)
    return mpmath.ncdf(b) - mpmath.ncdf(a)
def truncnorm_pdf(x, a, b):
    """Return the PDF at x of the standard normal truncated to [a, b].

    Parameters
    ----------
    x : real (int, float or mpmath mpf)
        Point at which the density is evaluated.
    a, b : real (int, float or mpmath mpf)
        Truncation bounds; either may be infinite.  Must satisfy a < b.

    Returns
    -------
    mpf
        npdf(x) / (Phi(b) - Phi(a)) for a <= x <= b, and 0 for x outside
        the support [a, b].

    Raises
    ------
    ValueError
        If a >= b.
    RuntimeError
        If the probability mass of [a, b] evaluates to 0 at the current
        working precision (mpmath.mp.dps).
    """
    if a >= b:
        raise ValueError("'a' must be less than 'b'")
    delta_cdf = truncnorm_delta_cdf(a, b)
    if delta_cdf == 0:
        raise RuntimeError("delta_cdf is 0; try increasing mpmath.mp.dps.")
    if x < a or x > b:
        # A truncated density vanishes outside its support.
        return mpmath.mpf(0)
    return mpmath.npdf(x) / delta_cdf
def truncnorm_stats(a, b):
    """Return (mean, variance, skewness, excess kurtosis) of truncnorm(a, b).

    The distribution is the standard normal truncated to [a, b].  All four
    statistics follow in closed form from the truncated density at the
    two endpoints, p(a) and p(b), through the noncentral moments

        m1 = p(a) - p(b)
        m2 = 1 + a*p(a) - b*p(b)
        m3 = 2*m1 + a**2*p(a) - b**2*p(b)
        m4 = 3*m2 + a**3*p(a) - b**3*p(b)

    which are then converted to central moments.  The returned values are
    mpmath numbers computed at the current mpmath.mp.dps.
    """
    lo = mpmath.mpf(a)
    hi = mpmath.mpf(b)
    p_lo = truncnorm_pdf(lo, lo, hi)
    p_hi = truncnorm_pdf(hi, lo, hi)

    # The density at an infinite endpoint is 0, but 0*inf is nan.  Replacing
    # an infinite endpoint by 0 makes every product p(endpoint)*endpoint**k
    # below evaluate to the intended 0.
    if hi == mpmath.inf:
        hi = 0
    if lo == -mpmath.inf:
        lo = 0

    # Noncentral moments E[X**k].
    m1 = p_lo - p_hi
    m2 = 1 + p_lo*lo - p_hi*hi
    m3 = 2*m1 + p_lo*lo**2 - p_hi*hi**2
    m4 = 3*m2 + p_lo*lo**3 - p_hi*hi**3

    mean = m1
    # Central moments E[(X - mean)**k].
    var = (lo - mean)*p_lo - (hi - mean)*p_hi + 1
    mu3 = m3 + m1 * (-3*m2 + 2*m1**2)
    mu4 = m4 + m1*(-4*m3 + 3*m1*(2*m2 - m1**2))

    skew = mu3 / mpmath.power(var, 1.5)
    excess_kurt = mu4 / var**2 - 3
    return mean, var, skew, excess_kurt
def truncnorm_stats_quad(a, b):
    """Return (mean, variance, skewness, excess kurtosis) via quadrature.

    Computes the same statistics as truncnorm_stats(a, b).  The mean and
    variance use the same closed-form expressions.  The third and fourth
    central moments are instead obtained by numerically integrating
    pdf(t) * (t - mean)**k over [a, b] with mpmath.quad.  This is slow,
    but it provides an independent check of the moment formulas used in
    truncnorm_stats.

    Raises
    ------
    ValueError
        If a or b is infinite, or if a >= b.
    RuntimeError
        If the probability mass of [a, b] is 0 at the current precision.
    """
    if mpmath.isinf(a) or mpmath.isinf(b):
        raise ValueError("truncnorm_stats_quad requires both 'a' and 'b' "
                         "to be finite.")
    a = mpmath.mpf(a)
    b = mpmath.mpf(b)
    # These calls also validate (a, b), raising before any quadrature runs.
    pa = truncnorm_pdf(a, a, b)
    pb = truncnorm_pdf(b, a, b)
    # The normalizing constant does not depend on t, so compute it once
    # instead of at every quadrature node.
    delta_cdf = truncnorm_delta_cdf(a, b)

    # mu is the mean.
    # mu# are moments about the mean (i.e. central moments).
    mu = pa - pb
    mu2 = (a - mu)*pa - (b - mu)*pb + 1

    def central_moment(k):
        # E[(X - mu)**k] by quadrature; the integrand equals
        # truncnorm_pdf(t, a, b) * (t - mu)**k for t in [a, b].
        return mpmath.quad(lambda t: mpmath.npdf(t) / delta_cdf * (t - mu)**k,
                           [a, b])

    mu3 = central_moment(3)
    g1 = mu3 / mpmath.power(mu2, 1.5)
    mu4 = central_moment(4)
    g2 = mu4 / mu2**2 - 3
    return mu, mu2, g1, g2
def print_table(intervals):
    """Print a fixed-width table of truncnorm stats, one row per interval.

    intervals is an iterable of (a, b) pairs.  Each row shows a and b
    right-aligned in 6 columns, followed by the mean, variance, skewness
    and excess kurtosis from truncnorm_stats(a, b), each converted to
    float and shown via repr() right-aligned in 24 columns.
    """
    labels = ('a', 'b', 'mean', 'variance', 'skewness', 'excess kurtosis')
    bounds_header = ' '.join(format(label, '>6') for label in labels[:2])
    stats_header = ' '.join(format(label, '>24s') for label in labels[2:])
    print(bounds_header + ' ' + stats_header)
    for lo, hi in intervals:
        # The bounds are written before the (possibly slow) stats are computed.
        print(format(str(lo), '>6') + ' ' + format(str(hi), '>6'), end='')
        values = [float(v) for v in truncnorm_stats(lo, hi)]
        print(''.join(' ' + format(repr(x), '>24s') for x in values))
def parstr(value):
    """Return source text for an interval endpoint in generated test data.

    Infinite values become "np.inf" or "-np.inf", for use in code where
    numpy is available as `np`.  Finite values are rendered with str().
    """
    if not mpmath.isinf(value):
        return str(value)
    return ("" if value >= 0 else "-") + "np.inf"
def print_test_code(intervals):
    """Print Python source defining a `_truncnorm_stats_data` list.

    Each (a, b) in intervals yields one entry of the form
    [a, b, mean, variance, skewness, excess kurtosis].  The stats come
    from truncnorm_stats(a, b), converted to float.  Endpoints are
    rendered with parstr(), so infinities appear as np.inf.  If the values
    would make a line longer than 79 characters, they are written one
    per line instead.
    """
    preamble = (
        "# Test data for the truncnorm stats() method.",
        "# The data in each row is:",
        "# a, b, mean, variance, skewness, excess kurtosis.",
        "_truncnorm_stats_data = [",
    )
    for text in preamble:
        print(text)
    # NOTE(review): the copy of this script was recovered from a web page
    # that may have collapsed runs of spaces; confirm the intended widths
    # of the leading whitespace in " [" and in pad.
    pad = " "
    for lo, hi in intervals:
        print(f" [{parstr(lo)}, {parstr(hi)},")
        values = [str(float(v)) for v in truncnorm_stats(lo, hi)]
        row = pad + ', '.join(values) + '],'
        if len(row) > 79:
            # Too wide for a single line: put each value on its own line.
            row = pad + (',\n' + pad).join(values) + '],'
        print(row)
    print("]")
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Compute the mean, variance, skewness and excess "
                    "kurtosis of truncnorm for a fixed set of intervals "
                    "using mpmath.")
    # Replaces the former commented-out print_test_code() toggle.
    parser.add_argument(
        "--test-code", action="store_true",
        help="print the results as Python test data instead of a table")
    args = parser.parse_args()

    check_intervals = [
        [-30, 30],
        [-10, 10],
        [-3, 3],
        [-2, 2],
        [0, mpmath.inf],
        [-mpmath.inf, 0],
        [-1, 3],
        [-3, 1],
        [-10, -9],
        [-20, -19],
        [-30, -29],
        [-40, -39],
        [39, 40],
    ]
    if args.test_code:
        print_test_code(check_intervals)
    else:
        print_table(check_intervals)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment