Created July 31, 2022 00:53
Save devmotion/a6c3561f6c593160744147e7c5165f62 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * One-pass, numerically careful computation of `log(sum(exp(x)))`.
 *
 * The reduction carries a pair `(m, s)` that stands for the partial sum
 * `exp(m)*(1 + s)`, where `m` is the largest element folded in so far.
 * The leading `1` is kept implicit and the final step uses `log1p(s)`, so
 * a tiny `s` is not lost by first rounding `1 + s` to `1`.
 *
 * - x: Values.
 *
 * Returns: `log(sum(exp(x)))`, or `-inf` when `x` is empty.
 */
function log_sum_exp_onepass(x:Real[_]) -> Real {
  if length(x) == 0 {
    return -inf;
  }
  let (mx, r) <- transform_reduce(x, (-inf, 0.0),
      \(lhs:(Real, Real), rhs:(Real, Real)) -> {
        let (lm, ls) <- lhs;
        let (rm, rs) <- rhs;
        // Rescale the partial sum with the smaller maximum onto the larger
        // one. NOTE(review): `nan_exp` presumably maps a NaN argument
        // (e.g. `-inf - -inf`) to 0 — confirm against the library.
        if lm > rm {
          return (lm, ls + (rs + 1.0)*nan_exp(rm - lm));
        } else {
          return (rm, rs + (ls + 1.0)*nan_exp(lm - rm));
        }
      },
      \(v:Real) -> {
        // a single element v is exp(v)*(1 + 0)
        return (v, 0.0);
      });
  return mx + log1p(r);
}
/**
 * One-pass computation of the effective sample size (ESS) and the log of
 * the sum of weights, given log-weights `w`.
 *
 * The reduction carries a triple `(m, s, s2)` standing for
 * `sum(exp(w)) = exp(m)*(1 + s)` and `sum(exp(2*w)) = exp(2*m)*(1 + s2)`,
 * where `m` is the largest log-weight folded in so far.
 *
 * - w: Log-weights.
 *
 * Returns: Pair of the ESS `(1 + s)^2/(1 + s2)` and the log of the sum of
 * weights `m + log1p(s)`. For empty `w` the pair `(0.0, 0.0)` is returned
 * (NOTE(review): presumably chosen to match `resample_reduce` — confirm).
 */
function resample_reduce_onepass(w:Real[_]) -> (Real, Real) {
  if length(w) == 0 {
    return (0.0, 0.0);
  }
  let (mw, r, rsq) <- transform_reduce(w, (-inf, 0.0, 0.0),
      \(lhs:(Real, Real, Real), rhs:(Real, Real, Real)) -> {
        let (lm, ls, ls2) <- lhs;
        let (rm, rs, rs2) <- rhs;
        if lm > rm {
          // squared weights are rescaled by the square of the factor
          let scale <- nan_exp(rm - lm);
          return (lm, ls + (rs + 1.0)*scale, ls2 + (rs2 + 1.0)*scale*scale);
        } else {
          let scale <- nan_exp(lm - rm);
          return (rm, rs + (ls + 1.0)*scale, rs2 + (ls2 + 1.0)*scale*scale);
        }
      },
      \(v:Real) -> {
        return (v, 0.0, 0.0);
      });
  let total <- r + 1.0;
  return (total*total/(rsq + 1.0), mw + log1p(r));
}
/**
 * Write a scalar result to standard output, without a trailing newline.
 *
 * - x: Value to write.
 */
function print_result(x:Real) {
  stdout.print(x);
}
/**
 * Write a pair of scalars to standard output formatted as `(a, b)`,
 * without a trailing newline. Each component is written through the
 * scalar `print_result` overload.
 *
 * - x: Pair to write.
 */
function print_result(x:(Real, Real)) {
  let (first, second) <- x;
  stdout.print("(");
  print_result(first);
  stdout.print(", ");
  print_result(second);
  stdout.print(")");
}
/**
 * Underflow demonstration: apply `f` to `[1e-20, log(1e-20)]` and print
 * `f([1e-20, log(1e-20)]) = ` followed by the result (no trailing newline).
 *
 * The exponentials of the two inputs are about `1` and `1e-20`, so their
 * sum is lost to rounding if it is formed as `1 + 1e-20` before taking the
 * logarithm; the exact log-sum-exp is about `2e-20`.
 *
 * - f: Function taking a `Real[_]`; its result must be accepted by
 *   `print_result`.
 */
function underflow_example<F>(f:F) {
  input:Real[_] <- [1e-20, log(1e-20)];
  stdout.print("f([1e-20, log(1e-20)]) = ");
  let outcome <- f(input);
  print_result(outcome);
}
/**
 * Run the underflow demonstration for `log_sum_exp`, or for
 * `log_sum_exp_onepass` when `onepass` is true, then print the reference
 * value for comparison.
 *
 * - onepass: Use the one-pass implementation?
 */
program log_sum_exp_underflow(onepass:Boolean) {
  if !onepass {
    stdout.print("f: log_sum_exp\n");
    underflow_example(log_sum_exp);
  } else {
    stdout.print("f: log_sum_exp_onepass\n");
    underflow_example(log_sum_exp_onepass);
  }
  stdout.print(" (correct: ~1.999999999999999999985e-20)\n");
}
/**
 * Run the underflow demonstration for `resample_reduce`, or for
 * `resample_reduce_onepass` when `onepass` is true, then print the
 * reference value. Only the second component (log of the sum of weights)
 * is given; `_` marks the ESS, which is not compared.
 *
 * - onepass: Use the one-pass implementation?
 */
program resample_reduce_underflow(onepass:Boolean) {
  if !onepass {
    stdout.print("f: resample_reduce\n");
    underflow_example(resample_reduce);
  } else {
    stdout.print("f: resample_reduce_onepass\n");
    underflow_example(resample_reduce_onepass);
  }
  stdout.print(" (correct: (_, ~1.999999999999999999985e-20))\n");
}
/**
 * Time a single call of `f` on `x` and print one line of the form
 * `<label>: <result> (result), <elapsed> (time)`, where `<elapsed>` is the
 * value returned by `toc()`.
 *
 * - label: Name printed in front of the result.
 * - f: Function to time; its result must be accepted by `print_result`.
 * - x: Input passed to `f`.
 */
function timed_run<F>(label:String, f:F, x:Real[_]) {
  tic();
  let y <- f(x);
  let elapsed <- toc();
  stdout.print(label);
  stdout.print(": ");
  print_result(y);
  stdout.print(" (result), ");
  stdout.print(elapsed);
  stdout.print(" (time)\n");
}

/**
 * Compare `f` (current implementation) against `g` (one-pass
 * implementation) on the same vector of 1000 draws from `Gaussian(0.0, 1.0)`,
 * timing a single call of each and printing its result and elapsed time.
 *
 * - f: Current implementation.
 * - g: One-pass implementation.
 */
function timings<F,G>(f:F, g:G) {
  // single source of truth for the input size
  let n <- 1000;
  x:Real[n];
  for t in 1..n {
    x[t] <~ Gaussian(0.0, 1.0);
  }
  timed_run("current", f, x);
  timed_run("onepass", g, x);
}
/**
 * Benchmark `log_sum_exp` against `log_sum_exp_onepass`.
 */
program log_sum_exp_timings() {
  timings(log_sum_exp, log_sum_exp_onepass);
}
/**
 * Benchmark `resample_reduce` against `resample_reduce_onepass`.
 */
program resample_reduce_timings() {
  timings(resample_reduce, resample_reduce_onepass);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.