Skip to content

Instantly share code, notes, and snippets.

@tai
Created July 11, 2024 02:13
Show Gist options
  • Save tai/2c8a15f660fa88ea98c38e4076ee0f34 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
#
# POC of T-Digest algorithm based distributed processing
#
from tdigest import TDigest
from multiprocessing import Pool
import numpy as np
import time
import functools
## test parameter
# number of independent batches processed per benchmark run
batch_count = 1024
# number of random samples drawn per batch
batch_size = 128
######################################################################
## quick benchmarking wrapper
def timeit(func):
    """Decorator that prints how long each call to *func* takes.

    The wrapped function's return value is passed through unchanged.
    Uses time.perf_counter() (monotonic, high resolution) rather than
    time.time(), which is wall-clock time and can jump under NTP
    adjustments — the wrong tool for measuring an interval.
    """
    @functools.wraps(func)
    def new_func(*args, **kwargs):
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed_time = time.perf_counter() - start_time
        print('function [{}] finished in {} ms'.format(
            func.__name__, int(elapsed_time * 1_000)))
        return result
    return new_func
## sample generator
class Sampler(object):
    """Deterministic source of uniform random samples.

    Construction seeds NumPy's *global* RNG with 0, so every run of the
    benchmark draws the exact same sequence of values.
    """

    def __init__(self, batch_size):
        # number of samples that make up one batch
        self.batch_size = batch_size
        np.random.seed(0)

    def sample(self, nr=1):
        """Return a flat array of (nr * batch_size) uniform [0, 1) samples."""
        return np.random.random(self.batch_size * nr)
######################################################################
## truth
@timeit
def np_median(sampler):
    """Exact median over the full data set — the reference answer the
    T-Digest estimates below are compared against."""
    everything = sampler.sample(batch_count)
    return np.median(everything)
print("np.median (true median):", np_median(Sampler(batch_size)))
## TD with a single batch
@timeit
def td_at_once(sampler):
    """Median estimate from a single digest fed the whole data set at once."""
    digest = TDigest()
    digest.batch_update(sampler.sample(batch_count))
    return digest.percentile(50)
print("td50 of alldata:", td_at_once(Sampler(batch_size)))
## TD with batched updates
@timeit
def td_batched(sampler):
    """Median estimate from one digest updated incrementally, batch by batch."""
    digest = TDigest()
    for _ in range(batch_count):
        digest.batch_update(sampler.sample())
    return digest.percentile(50)
print("td50 of batched updates:", td_batched(Sampler(batch_size)))
## TD with distributed updates
@timeit
def td_distributed(sampler):
    """Median estimate from independent per-batch digests merged one at a
    time — simulates distributed workers without actual parallelism."""
    merged = TDigest()
    for _ in range(batch_count):
        local = TDigest()
        local.batch_update(sampler.sample())
        merged = merged + local
    return merged.percentile(50)
print("td50 of distributed updates:", td_distributed(Sampler(batch_size)))
## TD with parallel distributed updates
def do_one_batch(sampler):
    """Worker task: build and return a digest holding one batch of samples."""
    digest = TDigest()
    digest.batch_update(sampler.sample())
    return digest
def add_tds(tds):
    """Fold a non-empty sequence of digests into one with the + operator.

    Raises IndexError when *tds* is empty (unchanged from the original).
    """
    merged, rest = tds[0], tds[1:]
    for other in rest:
        merged = merged + other
    return merged
@timeit
def td_distributed_parallel(sampler):
    """Median estimate using a process pool: per-batch digests built in
    parallel, then combined with a pairwise (tree) reduction.

    Fixes over the original:
    - One Pool is created and properly closed via a context manager; the
      original built a fresh Pool(32) on every reduction round and never
      terminated any of them, leaking worker processes.
    - A reduction round over an odd-length list now carries the unpaired
      digest into the next round; the original's zip() silently dropped it
      (masked only because batch_count is a power of two).

    NOTE(review): every worker inherits a copy of the parent's global NumPy
    RNG state, so distinct workers may draw identical sample sequences —
    confirm this is acceptable for the POC.
    """
    with Pool(32) as pool:
        tds = pool.map(do_one_batch, (sampler, ) * batch_count)
        while len(tds) > 1:
            leftover = [tds[-1]] if len(tds) % 2 else []
            tds = pool.map(add_tds, zip(tds[0::2], tds[1::2])) + leftover
    return tds[0].percentile(50)
print("td50 of parallel distributed updates:",
      td_distributed_parallel(Sampler(batch_size)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment