Last active
November 15, 2020 11:22
-
-
Save gubenkoved/d9876ccf3ceb935e81f45c8208931fa4 to your computer and use it in GitHub Desktop.
Empirical asymptotic complexity estimator
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import bisect | |
import inspect | |
import itertools | |
import inspect | |
import random | |
import time | |
import unittest | |
import numpy as np | |
from matplotlib import pyplot | |
from matplotlib import rcParams | |
from scipy.optimize import curve_fit | |
import scipy.stats | |
# Global matplotlib settings: render every figure text with the Ubuntu font family.
rcParams.update({'font.family': 'Ubuntu'})
class Options:
    """Bundle of switches controlling interpolation, plotting and debugging."""

    def __init__(self,
                 interpolation: bool,
                 interpolation_total_points_count: int,
                 show_interpolated_points: bool,
                 show_estimated_complexity: bool,
                 show_linear_approximation: bool,
                 debug: bool) -> None:
        """Store every setting on the instance.

        Note that the ``interpolation`` argument is exposed as the
        ``interpolate`` attribute; all other arguments keep their name.
        """
        # Interpolation settings.
        self.interpolate = interpolation
        self.interpolation_total_points_count = interpolation_total_points_count

        # Plot decorations.
        self.show_interpolated_points = show_interpolated_points
        self.show_estimated_complexity = show_estimated_complexity
        self.show_linear_approximation = show_linear_approximation

        # Diagnostics.
        self.debug = debug
# Module-wide settings instance read by the plotting/estimation helpers below.
options = Options(
    # estimation
    show_estimated_complexity=True,
    interpolation=True,
    interpolation_total_points_count=100,
    # extra plot decorations
    show_interpolated_points=False,
    show_linear_approximation=False,
    # diagnostics
    debug=True,
)
# TODO: Try to guess complexity based on the derivative of function, not the actual values | |
class AsymptoticTestsBase(unittest.TestCase):
    """Base test case that measures a callable over growing input sizes.

    ``measure`` collects ``(size, average seconds)`` points, ``process``
    plots them on four linear/log axis combinations and, depending on the
    module-level ``options``, overlays the best-fitting complexity curve
    picked by ``guess_complexity_curve_fit``.
    """

    def plot(self, figure, x, y,
             xscale='linear', yscale='log',
             rows=1, cols=1, idx=1,
             extensions=None):
        """Draw the measured points on one subplot of ``figure``.

        :param figure: matplotlib figure to add the subplot to
        :param x: input sizes
        :param y: measured times
        :param xscale: scale name passed to ``set_xscale``
        :param yscale: scale name passed to ``set_yscale``
        :param rows: number of subplot rows in the figure grid
        :param cols: number of subplot columns in the figure grid
        :param idx: 1-based subplot index (matplotlib rejects 0)
        :param extensions: optional callables, each invoked with the subplot
            before the measured points are drawn
        """
        plot = figure.add_subplot(rows, cols, idx)
        plot.set_xscale(xscale)
        plot.set_yscale(yscale)
        plot.set_xlabel('size ({})'.format(xscale))
        plot.set_ylabel('time ({})'.format(yscale))

        for ext_fn in extensions or ():
            ext_fn(plot)

        plot.plot(x, y, color='tab:red', linewidth=2, marker='o', markersize=5)

        # a zero lower limit is invalid on a log axis, so keep the automatic
        # lower limit there and only anchor the other scales at zero
        bottom = None if yscale == 'log' else 0
        plot.set_ylim(bottom=bottom, top=np.max(y) * 1.1)
        plot.grid(linestyle='dotted')

    @staticmethod
    def approximate_poly(plot, x, y, degree=1):
        """Add a polynomial approximation of the given degree to the graph
        for the reference (linear by default)."""
        f = AsymptoticTestsBase.poly_approximation(x, y, degree)
        plot.plot(x, f(x), color='gray', linewidth=1, linestyle='dashed')

    @staticmethod
    def preprocess(points):
        """Split ``(size, time)`` pairs into two numpy arrays.

        Times are clamped to be at least 0.0001 so that every value stays
        strictly positive.
        """
        x = np.array([p[0] for p in points])
        y = np.array([max(0.0001, p[1]) for p in points])
        return x, y

    @staticmethod
    def interpolate(x, y):
        """Linearly resample the points onto an evenly spaced grid.

        The grid spans ``min(x)..max(x)`` and has
        ``options.interpolation_total_points_count`` points.
        """
        points_count = options.interpolation_total_points_count
        x_interpolated = np.linspace(np.min(x), np.max(x), points_count)
        y_interpolated = np.interp(x_interpolated, x, y)
        return x_interpolated, y_interpolated

    @staticmethod
    def get_test_method_name():
        """Return the name of the innermost ``test_*`` function on the call
        stack, or ``'unknown'`` when there is none."""
        for frame in inspect.stack():
            if frame.function.startswith('test_'):
                return frame.function
        return 'unknown'

    def process(self, points, label=None):
        """Plot the measured points and, if enabled, the estimated complexity.

        :param points: list of ``(size, time)`` pairs
        :param label: figure title; defaults to the calling test method name
        """
        x, y = self.preprocess(points)
        label = label or self.get_test_method_name()

        extensions = []

        if options.show_linear_approximation:
            def draw_linear_approximation(plot):
                self.approximate_poly(plot, x, y, degree=1)

            extensions.append(draw_linear_approximation)

        if options.show_estimated_complexity:
            x2, y2 = x, y

            if options.interpolate:
                x2, y2 = self.interpolate(x, y)

                if options.show_interpolated_points:
                    def draw_interpolated_points(plot):
                        plot.plot(x2, y2, color='red', marker='x', markersize=5, linestyle='none')

                    extensions.append(draw_interpolated_points)

            best_fit_fn = self.guess_complexity_curve_fit(x2, y2)

            # None when no candidate model could be fitted
            if best_fit_fn is not None:
                label += ', estimated: {0}'.format(best_fit_fn.label)

                def extend_plot(plot):
                    plot.plot(x2, best_fit_fn(x2), color='blue', linewidth=1, linestyle='solid')

                extensions.append(extend_plot)

        fig = pyplot.figure(figsize=(10, 10))
        fig.suptitle(label, fontsize=10, fontweight='bold')

        # the same data on every combination of linear/log axes
        self.plot(fig, x, y, 'linear', 'linear', rows=2, cols=2, idx=1, extensions=extensions)
        self.plot(fig, x, y, 'log', 'linear', rows=2, cols=2, idx=2, extensions=extensions)
        self.plot(fig, x, y, 'linear', 'log', rows=2, cols=2, idx=3, extensions=extensions)
        self.plot(fig, x, y, 'log', 'log', rows=2, cols=2, idx=4, extensions=extensions)

        pyplot.show()

    def measure(self, init_size, max_size, alpha, repetitions, init_fn, target_fn):
        """Time ``target_fn`` on inputs of geometrically growing size.

        :param init_size: first input size to measure
        :param max_size: measuring stops once the size exceeds this value
        :param alpha: growth factor applied to the size after every step
        :param repetitions: number of timed runs averaged per size
        :param init_fn: called with the size (outside the timed section);
            returns the tuple of arguments for ``target_fn``
        :param target_fn: the callable being timed
        :raises ValueError: if ``alpha`` is not greater than 1 or
            ``repetitions`` is less than 1
        """
        if alpha <= 1:
            raise ValueError('alpha must be greater than 1, got {!r}'.format(alpha))
        if repetitions < 1:
            raise ValueError('repetitions must be at least 1, got {!r}'.format(repetitions))

        size = init_size
        points = []

        while size <= max_size:
            elapsed = 0.0

            for _ in range(repetitions):
                init_data = init_fn(size)
                # perf_counter is monotonic and has the best available resolution
                start = time.perf_counter()
                target_fn(*init_data)
                elapsed += time.perf_counter() - start

            # record the timing against the size that was actually measured
            points.append((size, elapsed / repetitions))

            # int() truncation can leave small sizes unchanged; always advance
            size = max(size + 1, int(alpha * size))

        self.process(points)

    @staticmethod
    def guess_complexity_curve_fit(x, y):
        """Fit a library of complexity models to the points and pick the best.

        Every model is fitted with ``scipy.optimize.curve_fit``; the model
        with the smallest residual sum of squares wins (the constant model
        gets its metric halved, see the comment below).

        :return: the fitted function of ``x`` with a ``label`` attribute such
            as ``'O(n)'``, or ``None`` if no model could be fitted
        """
        def fn_const(x, c):
            return 0 * x + c

        def fn_linear(x, k, c):
            return k * x + c

        def fn_squared(x, k, c):
            return k * x ** 2 + c

        def fn_pow3(x, k, c):
            return k * x ** 3 + c

        def fn_log(x, k, c):
            return k * np.log2(x) + c

        def fn_nlogn(x, k, c):
            return k * x * np.log2(x) + c

        library = [
            # TODO: How to handle constants efficiently?
            (fn_const, 'O(1)'),
            (fn_linear, 'O(n)'),
            (fn_squared, 'O(n^2)'),
            (fn_pow3, 'O(n^3)'),
            (fn_log, 'O(logn)'),
            (fn_nlogn, 'O(nlogn)'),
        ]

        def resolve(f, coefficients):
            def resolved(x):
                return f(x, *coefficients)

            return resolved

        print('picking the best fit...')

        best_fit = None

        for fn, fn_label in library:
            try:
                coefficients, _ = curve_fit(fn, x, y)
            except RuntimeError as err:
                # curve_fit signals a failed minimization with RuntimeError;
                # skip this model instead of aborting the whole estimation
                print('{name:11} fit failed: {err}'.format(name=fn.__name__, err=err))
                continue

            fn_solved = resolve(fn, coefficients)
            fn_solved.label = fn_label

            rss = AsymptoticTestsBase.sum_of_squares_diff(x, y, fn_solved)
            r, p = scipy.stats.pearsonr(y, fn_solved(x))

            coefficients_names = inspect.getfullargspec(fn).args[1:]

            # this is the metric used to pick the best fit -- the smaller the
            # metric, the better the fit
            metric = rss

            # edge case for the constant approximation -- it has to be picked
            # when it describes the data almost as well as the others, so it
            # gets a "head start": the more complex functions would otherwise
            # always win with near-zero x-dependent coefficients, effectively
            # being the constant approximation anyway
            if fn is fn_const:
                metric /= 2.0

            print('{name:11} RSS={rss:11.4e}, r={r:10f}, p={p:10.3e}, coefficients: {coeff}'.format(
                name=fn.__name__,
                rss=rss,
                coeff=', '.join(['{c_name}={c_val:10.3e}'.format(c_name=c_name, c_val=c_val)
                                 for c_name, c_val in zip(coefficients_names, coefficients)]),
                r=r,
                p=p,
            ))

            if best_fit is None or metric < best_fit[1]:
                best_fit = (fn_solved, metric)

        return best_fit[0] if best_fit is not None else None

    @staticmethod
    def guess_complexity_legacy(x, y):
        """Print the sum of squared differences of polynomial fits of degree
        1 to 4 (kept for reference; nothing is returned)."""
        cls = AsymptoticTestsBase

        # the higher the degree of the polynomial, the smaller the sum of
        # squared differences... so this approach does not work too well
        # (maybe we should discard lower degree coefficients and basically
        # approximate as f(x) = c + a*x^n)

        # good wiki on topic:
        # https://en.wikipedia.org/wiki/Curve_fitting
        # https://en.wikipedia.org/wiki/Polynomial_interpolation

        approximations = [
            ('O(n)', cls.poly_approximation(x, y, 1)),
            ('O(n^2)', cls.poly_approximation(x, y, 2)),
            ('O(n^3)', cls.poly_approximation(x, y, 3)),
            ('O(n^4)', cls.poly_approximation(x, y, 4)),
        ]

        for label, f in approximations:
            squared_diffs = cls.sum_of_squares_diff(x, y, f)
            print('{0:7} {1:f}'.format(label, squared_diffs))

    @staticmethod
    def poly_approximation(x, y, degree=1):
        """Return the least-squares polynomial of the given degree as a
        callable ``numpy.poly1d``."""
        p = np.polyfit(x, y, degree)
        f = np.poly1d(p)
        return f

    @staticmethod
    def sum_of_squares_diff(x, y, f):
        """Return the sum of squared differences between ``y`` and ``f(x)``."""
        return np.sum((y - f(x)) ** 2)
class AsymptoticTests(AsymptoticTestsBase):
    """Concrete measurements of a few well-known operations.

    Every test hands a size range, an input builder and the callable to time
    over to ``measure``.
    """

    @staticmethod
    def _range_list(n):
        """Build the single-argument tuple ``([0, 1, ..., n - 1],)``."""
        return (list(range(n)),)

    @staticmethod
    def _sort_in_place(array):
        """Sort the list in place (the timed operation of the sorting tests)."""
        array.sort()

    @staticmethod
    def _sum_by_index(array, to_search):
        """Read ``array`` at every index of ``to_search``, accumulating the values."""
        dummy = 0

        for x in to_search:
            dummy += array[x]

    def test_sorting_sorted_array(self):
        self.measure(
            init_size=10 ** 5,
            max_size=5 * 10 ** 7,
            alpha=1.5,
            repetitions=5,
            init_fn=self._range_list,
            target_fn=self._sort_in_place)

    def test_sorting_array(self):
        def random_list(n):
            return ([random.randrange(0, 10**7) for _ in range(n)],)

        self.measure(
            init_size=10 ** 5,
            max_size=2 * 10 ** 7,
            alpha=1.5,
            repetitions=1,
            init_fn=random_list,
            target_fn=self._sort_in_place)

    def test_bisect(self):
        def build_input(n):
            # always look up the middle element, 500000 times
            return list(range(n)), [n // 2] * 500000

        def run_searches(array, to_search):
            for x in to_search:
                bisect.bisect_left(array, x)

        self.measure(
            init_size=1000,
            max_size=10 ** 7,
            alpha=2,
            repetitions=3,
            init_fn=build_input,
            target_fn=run_searches)

    def test_cartesian_product(self):
        def pairs(x):
            return list(itertools.product(x, x))

        self.measure(
            init_size=500,
            max_size=10 ** 4,
            alpha=1.5,
            repetitions=1,
            init_fn=self._range_list,
            target_fn=pairs)

    def test_cartesian_product_pow3(self):
        # illustrates O(n^3) complexity
        def triples(x):
            return list(itertools.product(x, itertools.product(x, x)))

        self.measure(
            init_size=30,
            max_size=500,
            alpha=1.4,
            repetitions=1,
            init_fn=self._range_list,
            target_fn=triples)

    def test_access_array_by_index_random_order(self):
        def build_input(n):
            array = list(range(n))
            to_search = [random.randrange(0, n) for _ in range(10 ** 5)]
            return array, to_search

        self.measure(
            init_size=1000,
            max_size=3 * 10 ** 7,
            alpha=1.7,
            repetitions=10,
            init_fn=build_input,
            target_fn=self._sum_by_index)

    def test_access_array_by_index_sequential(self):
        def build_input(n):
            array = list(range(n))
            to_search = [x % n for x in range(10 ** 5)]
            return array, to_search

        self.measure(
            init_size=1000,
            max_size=3 * 10 ** 7,
            alpha=1.7,
            repetitions=10,
            init_fn=build_input,
            target_fn=self._sum_by_index)

    # sanity testing...
    def test_sleep_100ms(self):
        def no_input(_):
            return (None, )

        def sleep_100ms(_):
            return time.sleep(0.100)

        self.measure(
            init_size=10,
            max_size=10 ** 6,
            alpha=2,
            repetitions=10,
            init_fn=no_input,
            target_fn=sleep_100ms)
# Script entry point: hand control to the unittest command-line runner.
if __name__ == '__main__':
    unittest.main()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
numpy==1.19.3 | |
matplotlib~=3.3.2 | |
scipy~=1.5.4 |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment