Last active
November 14, 2023 15:09
-
-
Save wassname/a9502f562d4d3e73729dc5b184db2501 to your computer and use it in GitHub Desktop.
Running stats (mean, standard deviation) for Python, PyTorch, etc.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np

# handle pytorch tensors etc, by using tensorboardX's method
try:
    from tensorboardX.x2num import make_np
except ImportError:
    def make_np(x):
        """Fallback converter: return ``x`` as a fresh float64 numpy array.

        Only defined when tensorboardX is not installed. float64 is used
        instead of float16 on purpose: float16 keeps only ~3 significant
        decimal digits and overflows to inf above 65504, which silently
        corrupts running means/variances.

        Args:
            x: anything ``np.array`` accepts (scalar, list, ndarray, CPU
               tensor exposing ``__array__``).

        Returns:
            A new float64 ndarray (``np.array`` copies by default).
        """
        return np.array(x, dtype=np.float64)
class RunningStats(object):
    """Computes running mean and standard deviation

    Url: https://gist.github.com/wassname/a9502f562d4d3e73729dc5b184db2501

    Adapted from:
        *
        <http://stackoverflow.com/questions/1174984/how-to-efficiently-\
        calculate-a-running-standard-deviation>
        * <http://mathcentral.uregina.ca/QQ/database/QQ.09.02/carlos1.html>
        * <https://gist.github.com/fvisin/5a10066258e43cf6acfa0a474fcdb59f>

    Uses Welford's online update. ``m`` is the running mean and ``s`` is the
    running sum of squared deviations from the mean (often called M2). The
    sample variance is ``s / (n - 1)``. Two instances can be merged with
    ``+`` using the pairwise (Chan et al.) combination formula.

    Usage:
        rs = RunningStats()
        for i in range(10):
            rs += np.random.randn()
        print(rs)
        print(rs.mean, rs.std)
    """

    def __init__(self, n=0., m=None, s=None):
        # n: number of observations seen so far (a float, as in the original)
        # m: running mean; an array when observations are pushed per_dim
        # s: running sum of squared deviations from the mean (M2)
        self.n = n
        self.m = m
        self.s = s

    def clear(self):
        """Reset to the empty state (no observations)."""
        self.n = 0.
        # Drop the old mean/M2 too, so repr() does not show stale values.
        self.m = None
        self.s = None

    def push(self, x, per_dim=False):
        """Add observations taken from ``x``.

        Args:
            x: anything ``make_np`` accepts (scalar, list, ndarray, tensor).
            per_dim: if True, ``x`` is ONE observation and statistics are
                tracked element-wise (mean/std take the shape of ``x``).
                If False, every element of ``x`` is a separate scalar
                observation.
        """
        x = make_np(x)
        if per_dim:
            self.update_params(x)
        else:
            for el in x.flatten():
                self.update_params(el)

    def update_params(self, x):
        """Fold a single observation ``x`` into the running statistics."""
        # Work on a private float64 copy. This never aliases or mutates the
        # caller's array, avoids casting errors for integer input, and
        # avoids low-precision (e.g. float16) accumulation.
        x = np.array(x, dtype=np.float64)
        if x.ndim == 0:
            x = x[()]  # unwrap 0-d arrays so scalar streams keep scalar stats
        self.n += 1
        if self.n == 1:
            self.m = x
            self.s = 0.
        else:
            delta = x - self.m
            # Rebind rather than update in place, so arrays shared with
            # other instances (e.g. after ``+``) are never modified.
            self.m = self.m + delta / self.n
            self.s = self.s + delta * (x - self.m)

    def __add__(self, other):
        """Merge with another RunningStats, or push ``other`` into self.

        With two RunningStats, returns a NEW instance that describes the
        union of both samples. Any other operand is pushed into ``self``,
        which is mutated and returned; this is what makes ``rs += x`` work.
        """
        if isinstance(other, RunningStats):
            # An empty side contributes nothing, and its m/s are None.
            if not other.n:
                return RunningStats(self.n, self.m, self.s)
            if not self.n:
                return RunningStats(other.n, other.m, other.s)
            sum_ns = self.n + other.n
            prod_ns = self.n * other.n
            delta2 = (other.m - self.m) ** 2.
            return RunningStats(sum_ns,
                                (self.m * self.n + other.m * other.n) / sum_ns,
                                self.s + other.s + delta2 * prod_ns / sum_ns)
        self.push(other)
        return self

    @property
    def mean(self):
        """Running mean; 0.0 before any observation."""
        return self.m if self.n else 0.0

    def variance(self):
        """Sample variance ``s / (n - 1)``; 0.0 until there are 2 observations."""
        # With one observation s == 0. and n - 1 == 0.; dividing those Python
        # floats would raise ZeroDivisionError, so guard on n > 1.
        return self.s / (self.n - 1) if self.n > 1 else 0.0

    @property
    def std(self):
        """Sample standard deviation, ``sqrt(variance())``."""
        return np.sqrt(self.variance())

    @staticmethod
    def _fmt(value):
        """Format a scalar as ``{: 2.4f}``, an array via numpy, None as 'None'."""
        if value is None:
            return 'None'
        if np.ndim(value) == 0:
            return '{: 2.4f}'.format(float(value))
        return np.array2string(np.asarray(value), precision=4)

    def __repr__(self):
        return '<RunningStats(mean={}, std={}, n={: 2f}, m={}, s={})>'.format(
            self._fmt(self.mean), self._fmt(self.std), self.n,
            self._fmt(self.m), self._fmt(self.s))

    def __str__(self):
        return 'mean={}, std={}'.format(self._fmt(self.mean), self._fmt(self.std))
@DrSkippy true! I added it, thanks.
According to this source: https://www.johndcook.com/blog/standard_deviation/
the variance calculation should use self.s / (self.n - 1):
"For $2 \le k \le n$, the $k$th estimate of the variance is $s^2 = S_k/(k-1)$."
Thanks Tomer, I added that too.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
The `__add__` method, for the case of combining two RunningStats objects, should be something like this: