Created
July 11, 2024 02:13
-
-
Save tai/2c8a15f660fa88ea98c38e4076ee0f34 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
# | |
# POC of T-Digest algorithm based distributed processing | |
# | |
from tdigest import TDigest | |
from multiprocessing import Pool | |
import numpy as np | |
import time | |
import functools | |
## test parameters
# batch_count: number of batches every benchmark below draws.
# batch_size:  samples per batch, so each run sees batch_count * batch_size values.
batch_count, batch_size = 1024, 128
######################################################################
## quick benchmarking wrapper
def timeit(func):
    """Decorate *func* so every call prints how long it took, in whole ms.

    The wrapped function's return value is passed through unchanged, and
    ``functools.wraps`` preserves its ``__name__`` and docstring.
    """
    @functools.wraps(func)
    def new_func(*args, **kwargs):
        # perf_counter() is monotonic and high-resolution; time.time() follows
        # the adjustable system clock and can jump mid-measurement.
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed_time = time.perf_counter() - start_time
        print('function [{}] finished in {} ms'.format(
            func.__name__, int(elapsed_time * 1_000)))
        return result
    return new_func
## sample generator
class Sampler:
    """Source of uniform [0, 1) samples handed out in multiples of a batch.

    Constructing a Sampler re-seeds NumPy's *global* legacy RNG with 0, so
    each new instance restarts the same random stream.
    """

    def __init__(self, batch_size):
        np.random.seed(0)
        self.batch_size = batch_size

    def sample(self, nr=1):
        """Return a 1-D array holding ``nr * batch_size`` samples."""
        count = nr * self.batch_size
        return np.random.random_sample(count)
######################################################################
## truth
@timeit
def np_median(sampler):
    """Exact median of all ``batch_count`` batches, drawn in a single call.

    This is the reference value the T-Digest estimates are compared against.
    """
    return np.median(sampler.sample(batch_count))
# Guarded so processes that merely import this module (e.g. multiprocessing
# workers under the spawn/forkserver start methods) do not re-run it.
if __name__ == "__main__":
    print("np.median (true median):", np_median(Sampler(batch_size)))
## TD with a single batch
@timeit
def td_at_once(sampler):
    """T-Digest 50th percentile with every sample fed in one batch_update."""
    samples = sampler.sample(batch_count)
    digest = TDigest()
    digest.batch_update(samples)
    median_estimate = digest.percentile(50)
    return median_estimate
# Guarded so processes that merely import this module (e.g. multiprocessing
# workers under the spawn/forkserver start methods) do not re-run it.
if __name__ == "__main__":
    print("td50 of alldata:", td_at_once(Sampler(batch_size)))
## TD with batched updates
@timeit
def td_batched(sampler):
    """T-Digest 50th percentile with samples streamed in one batch at a time."""
    digest = TDigest()
    for _ in range(batch_count):
        batch = sampler.sample()
        digest.batch_update(batch)
    median_estimate = digest.percentile(50)
    return median_estimate
# Guarded so processes that merely import this module (e.g. multiprocessing
# workers under the spawn/forkserver start methods) do not re-run it.
if __name__ == "__main__":
    print("td50 of batched updates:", td_batched(Sampler(batch_size)))
## TD with distributed updates
@timeit
def td_distributed(sampler):
    """T-Digest 50th percentile from summing one independent digest per batch.

    Simulates distributed processing sequentially: each batch gets its own
    TDigest, and the per-batch digests are folded left-to-right with ``+``.
    """
    def digest_for_next_batch():
        part = TDigest()
        part.batch_update(sampler.sample())
        return part

    # Lazy generator: each partial digest is built right before it is merged.
    partial_digests = (digest_for_next_batch() for _ in range(batch_count))
    merged = functools.reduce(lambda acc, part: acc + part,
                              partial_digests, TDigest())
    return merged.percentile(50)
# Guarded so processes that merely import this module (e.g. multiprocessing
# workers under the spawn/forkserver start methods) do not re-run it.
if __name__ == "__main__":
    print("td50 of distributed updates:", td_distributed(Sampler(batch_size)))
## TD with parallel distributed updates
def do_one_batch(sampler):
    """Return a new T-Digest built from a single ``sampler.sample()`` batch."""
    digest = TDigest()
    batch = sampler.sample()
    digest.batch_update(batch)
    return digest
def add_tds(tds):
    """Merge an indexable, non-empty sequence of digests left-to-right with ``+``.

    ``tds[0]`` seeds the fold, so an empty sequence raises IndexError.
    """
    return functools.reduce(lambda merged, nxt: merged + nxt, tds[1:], tds[0])
def _digest_batch(batch):
    """Pool worker task: build a T-Digest from one pre-drawn batch of samples."""
    td = TDigest()
    td.batch_update(batch)
    return td


@timeit
def td_distributed_parallel(sampler, processes=32):
    """T-Digest 50th percentile from per-batch digests built in a process pool.

    All ``batch_count`` batches are drawn here, in the parent process, and
    shipped to the workers. Sampler draws from NumPy's global RNG, which is
    not reseeded per worker: forked workers would all start from the same
    copied state and emit duplicate batches, and spawned workers would be
    unseeded. Drawing in the parent keeps the data identical to what
    ``sampler.sample()`` yields sequentially.

    Digests are merged pairwise in rounds (a reduction tree). When a round
    has an odd number of digests, the unpaired last one is carried into the
    next round instead of being dropped by ``zip``.

    :param sampler: batch source; ``sampler.sample()`` is called
        ``batch_count`` times.
    :param processes: number of worker processes (default 32).
    :returns: the merged digest's 50th percentile.
    """
    batches = [sampler.sample() for _ in range(batch_count)]
    # One pool for the whole run; leaving the ``with`` terminates its workers.
    with Pool(processes) as pool:
        tds = pool.map(_digest_batch, batches)
        while len(tds) > 1:
            carry = tds[-1:] if len(tds) % 2 else []
            tds = pool.map(add_tds, zip(tds[0::2], tds[1::2])) + carry
    return tds[0].percentile(50)
# Required for multiprocessing: under the spawn/forkserver start methods each
# worker re-imports this module, and an unguarded call would try to create a
# Pool during worker bootstrap (RuntimeError) or re-run the benchmark.
if __name__ == "__main__":
    print("td50 of parallel distributed updates:",
          td_distributed_parallel(Sampler(batch_size)))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment