Last active
May 24, 2024 15:06
-
-
Save camriddell/cf787bc25296caf8bfd83bb7c915cfc8 to your computer and use it in GitHub Desktop.
Create a bump chart in Python using matplotlib
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pandas as pd | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from matplotlib.ticker import MultipleLocator | |
def sigmoid(xs, ys, smooth=8, n=100):
    """Interpolate an S-shaped (logistic) curve between pairs of points.

    The curve is rescaled so that it starts exactly at the "from" point and
    ends exactly at the "to" point, for any steepness.

    Parameters
    ----------
    xs, ys: array_like, shape (2, N)
        ``xs[0]``/``ys[0]`` are the N starting coordinates and
        ``xs[1]``/``ys[1]`` are the N stopping coordinates.
    smooth: float
        Steepness of the sigmoid slope. Larger values give a sharper "S";
        ``0`` degenerates to straight-line interpolation.
    n: int
        Number of points to evaluate along each curve.

    Returns
    -------
    tuple (xs, ys)
        Two arrays of shape (n, N); column ``j`` holds the curve running from
        ``(xs[0][j], ys[0][j])`` to ``(xs[1][j], ys[1][j])``.
    """
    (x_from, x_to), (y_from, y_to) = np.asarray(xs), np.asarray(ys)
    # u spans [-1, 1] (the sigmoid's argument, before scaling by `smooth`);
    # t is the same positions mapped to [0, 1] and drives the x-values linearly.
    u = np.linspace(-1.0, 1.0, num=n)[:, None]
    t = (u + 1) / 2
    if smooth == 0:
        # Limit of the normalised sigmoid as the steepness goes to 0.
        s = t
    else:
        # logistic(k*u) == (1 + tanh(k*u/2)) / 2. Dividing the tanh term by
        # tanh(k/2) rescales the curve to hit exactly 0 at u=-1 and 1 at u=1
        # (the raw logistic stops short of both endpoints, badly so for small
        # `smooth`), and tanh cannot overflow the way exp(k) does for large k.
        half = smooth / 2
        s = (1 + np.tanh(half * u) / np.tanh(half)) / 2
    return (
        t * (x_to - x_from) + x_from,
        s * (y_to - y_from) + y_from,
    )
def sigmoid_pairwise(xs, ys, smooth=8, n=100):
    """Interpolate a sigmoid curve between every consecutive pair of points.

    For example::

        xs = [0, 1, 2, 3]
        ys = [2, 5, 3, 7]

    will interpolate between:

    - [0, 1], [2, 5]
    - [1, 2], [5, 3]
    - [2, 3], [3, 7]

    Parameters
    ----------
    xs, ys: array_like
        1-d sequences of identical length.
    smooth: int
        see sigmoid func
    n: int
        see sigmoid func (points per segment)

    Returns
    -------
    tuple (xs, ys)
        Two 1-d arrays holding the segments concatenated in order
        (``n`` points per segment, ``n * (len(xs) - 1)`` in total). With fewer
        than two input points there is nothing to connect, and the points are
        returned unchanged as arrays.

    Raises
    ------
    ValueError
        If ``xs`` and ``ys`` are not 1-d or differ in length.
    """
    xs = np.asarray(xs)
    ys = np.asarray(ys)
    if xs.ndim != 1 or xs.shape != ys.shape:
        raise ValueError('xs and ys must be 1-d arrays of identical length')
    if len(xs) < 2:
        # sliding_window_view rejects a window larger than the input.
        return xs.copy(), ys.copy()
    # Each row is one (start, stop) pair of consecutive values.
    x_pairs = np.lib.stride_tricks.sliding_window_view(xs, 2)
    y_pairs = np.lib.stride_tricks.sliding_window_view(ys, 2)
    # sigmoid wants shape (2, N) and returns (n, N); transpose back so that
    # ravel() lays the segments out one after another.
    interp_x, interp_y = sigmoid(x_pairs.T, y_pairs.T, smooth=smooth, n=n)
    # ravel() gives reusable arrays; `.flat` would be a one-shot iterator.
    return interp_x.T.ravel(), interp_y.T.ravel()
df = pd.DataFrame({
    'year': [*range(2019, 2022)] * 3,
    'company': np.repeat(['A', 'B', 'C'], 3),
    'revenue': [100, 200, 300, 150, 250, 100, 200, 300, 400],
})

# Global matplotlib style for this figure: larger font, no top/right spines.
plt.rc('font', size=14)
plt.rc('axes.spines', right=False, top=False)

fig, ax = plt.subplots(figsize=(9, 6))
for company, group in df.groupby('company'):
    # Curves are drawn between consecutive rows, so rows must be in year order.
    group = group.sort_values('year')
    xs, ys = group['year'], group['revenue']
    interp_x, interp_y = sigmoid_pairwise(xs, ys)
    line, = ax.plot(interp_x, interp_y, lw=3)
    color = line.get_color()
    ax.scatter(xs, ys, s=100, color=color)
    # Label sits at the right edge of the axes (x in axes coordinates) level
    # with the company's final value (y in data coordinates), nudged 5pt right.
    ax.annotate(
        f'Company {company.title()}',
        xy=(1, ys.iloc[-1]), xycoords=(ax.transAxes, ax.transData),
        xytext=(5, 0), textcoords='offset points',
        color=color,
        va='center',
    )

ax.xaxis.set_major_locator(MultipleLocator(1))  # one tick per year
ax.set_ylabel('Revenue')
ax.margins(x=.02)
ax.set_title('Revenue Bump Chart', size='x-large')
fig.tight_layout()
fig.savefig('bumpchart.png')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment