Last active
March 8, 2024 20:56
-
-
Save camriddell/886d5cf0f268c88bf6955fe692f5281f to your computer and use it in GitHub Desktop.
A collection of univariate plots
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from functools import partial
from textwrap import fill

from matplotlib.axes import Axes
from matplotlib.lines import Line2D
from matplotlib.pyplot import subplots, show, rc
from numpy import (
    array, linspace, quantile, histogram, atleast_2d, mean, std, add
)
from numpy.lib.stride_tricks import sliding_window_view
from scipy.stats import norm, uniform, skewnorm, gaussian_kde, triang
import seaborn as sns
# Global matplotlib styling applied to every figure created below:
# a larger base font, and all four axes spines switched off.
rc('font', size=14)
rc(
    'axes.spines',
    top=False,
    right=False,
    left=False,
    bottom=False,
)
# Frozen scipy distributions; each one becomes a column of the figure.
dists = [
    norm(loc=10, scale=2),
    uniform(loc=0, scale=20),
    skewnorm(a=6, loc=10, scale=2),
    triang(c=1, loc=5, scale=7),
]

# 200 draws from each distribution. The seed is fixed so the rendered figure
# is reproducible from run to run.
samples = [dist.rvs(random_state=0, size=200) for dist in dists]
def tufte_quartiles(ax, data):
    """Draw a Tufte-style quartile plot of *data* on *ax*.

    Two horizontal segments at y=0 span minimum -> first quartile and
    third quartile -> maximum; the interquartile range is left blank and
    the median is drawn as a single point.

    Parameters
    ----------
    ax : object exposing ``hlines`` and ``scatter`` (a matplotlib Axes)
    data : array-like of numbers
    """
    lowest, lower_q, median, upper_q, highest = quantile(
        data, [0, .25, .5, .75, 1]
    )
    # One call draws both whiskers: (lowest, lower_q) and (upper_q, highest).
    ax.hlines([0, 0], [lowest, upper_q], [lower_q, highest])
    ax.scatter([median], [0])
def color_density(ax, data):
    """Render the kernel density estimate of *data* as a coloured strip.

    The KDE is evaluated at 400 evenly spaced points between the sample
    minimum and maximum and handed to ``ax.pcolormesh`` with the 'Blues'
    colormap.

    Parameters
    ----------
    ax : object exposing ``pcolormesh`` (a matplotlib Axes)
    data : numpy array of numbers (``.min()`` / ``.max()`` are called on it)
    """
    lo, hi = data.min(), data.max()
    xs = linspace(lo, hi, 400)
    kde = gaussian_kde(data)
    # pcolormesh wants a 2-D colour array, so the single row of densities is
    # duplicated to form a strip spanning y=0..1.
    strip = atleast_2d(kde(xs)).repeat(2, axis=0)
    ax.pcolormesh(xs, [0, 1], strip, cmap='Blues')
def point_decile(ax, data):
    """Draw the deciles of *data* as ten segments of varying thickness.

    Consecutive deciles (0%, 10%, ..., 100%) bound ten horizontal segments
    at y=0. Segments get thicker towards the median, which is marked with
    a white point drawn on top.

    Parameters
    ----------
    ax : object exposing ``hlines`` and ``scatter`` (a matplotlib Axes)
    data : array-like of numbers
    """
    deciles = quantile(data, linspace(0, 1, 11))
    # Each window of two neighbouring deciles is one segment (start, stop).
    starts, stops = sliding_window_view(deciles, 2).T
    widths = array([1, 2, 3, 4, 5, 5, 4, 3, 2, 1]) * 4
    ax.hlines([0] * len(starts), starts, stops, linewidths=widths)
    # deciles[5] is the 50% quantile, i.e. the median.
    ax.scatter(deciles[5], 0, color='white', zorder=7)
def point_multi_sigmas(ax, data):
    """Draw nested 1-, 2- and 3-sigma bands around the mean of *data*.

    Five horizontal segments at y=0 join the points mean + k*sd for
    k in (-3, -2, -1, 1, 2, 3); the central (+/-1 sd) segment is thickest
    and the outermost ones thinnest. The mean is marked with a white point
    drawn on top.

    Parameters
    ----------
    ax : object exposing ``hlines`` and ``scatter`` (a matplotlib Axes)
    data : array-like of numbers
    """
    center, spread = mean(data), std(data)
    offsets = spread * array([-3, -2, -1, 1, 2, 3])
    # Neighbouring edges pair up into the five (start, stop) segments.
    starts, stops = sliding_window_view(offsets + center, 2).T
    widths = array([1, 2, 3, 2, 1]) * 5
    ax.hlines([0] * len(starts), starts, stops, linewidths=widths)
    ax.scatter(center, 0, color='white', zorder=7)
# Registry of (row label, plotting callable), one entry per figure row, in
# top-to-bottom order. Seaborn callables are invoked as f(x=data, ax=ax);
# everything else (Axes methods and the helpers defined above) is invoked as
# f(ax, data). Extra options are pre-bound with functools.partial.
univariate_funcs = [
    # --- every observation drawn individually ---
    ('strip', partial(sns.stripplot, jitter=0.3, ec='white', size=3)),
    ('swarm', partial(sns.swarmplot, size=3)),
    ('rug', partial(Axes.eventplot, alpha=0.4)),

    # --- density / distribution-function views ---
    ('kernel density (area)', partial(sns.kdeplot, fill=True)),
    ('kernel density (color)', color_density),
    ('cumulative KDE', partial(sns.kdeplot, cumulative=True)),
    ('empirical CDF', sns.ecdfplot),
    ('histogram', partial(sns.histplot, bins='auto')),

    # --- quantile summaries ---
    ('Box', sns.boxplot),
    ('Boxen', sns.boxenplot),
    ('Tufte Quartile', tufte_quartiles),

    # --- point-and-interval summaries ---
    (
        r'Point $\bar{x}\pm\sigma$',
        partial(sns.pointplot, orient='h', errorbar='sd'),
    ),
    ('Point Deciles', point_decile),
    (
        r'Point $\bar{x}\pm$ 3$\sigma$,2$\sigma$,1$\sigma$',
        point_multi_sigmas,
    ),
]
# ---------------------------------------------------------------------------
# Figure assembly.
#
# Layout: one column per distribution. The first row shows each
# distribution's theoretical PDF; every following row renders the same
# samples with one of the plot types listed in ``univariate_funcs``.
# ---------------------------------------------------------------------------
gridspec_kw = dict(hspace=.1, wspace=.02, left=.15, right=.9, bottom=.05)
fig, axes = subplots(
    len(univariate_funcs) + 1, len(dists),
    sharey='row', sharex='col',
    figsize=(16, 12), gridspec_kw=gridspec_kw,
    dpi=106,
)

# Header row: theoretical density, titled with the distribution's name and
# the keyword parameters it was frozen with.
for ax, d in zip(axes[0], dists):
    grid = linspace(*d.ppf([.001, .999]), 400)
    y = d.pdf(grid)
    ax.plot(grid, y)
    ax.fill_between(grid, y, alpha=.4)
    params = ', '.join(f'{key}={value}' for key, value in d.kwds.items())
    ax.set_title(f"{d.dist.name.title()}\n{params}")

# Body rows: one plot type per row, one sample set per column.
for i, (name, func) in enumerate(univariate_funcs, start=1):
    # Unwrap partials so the calling convention can be chosen from the
    # underlying function while still forwarding the pre-bound arguments.
    if isinstance(func, partial):
        func, args, kwargs = func.func, func.args, func.keywords
    else:
        args, kwargs = (), {}

    # Seaborn functions take the data and target axes as keywords; Axes
    # methods and the local helpers take (ax, data) positionally. This
    # depends only on the function, so decide once per row.
    package, _, _ = func.__module__.partition('.')
    is_seaborn = package == 'seaborn'

    for j, s in enumerate(samples):
        ax = axes[i, j]
        if is_seaborn:
            func(*args, x=s, ax=ax, **kwargs)
        else:
            func(ax, s, *args, **kwargs)

        if j == 0:
            # Row label goes on the left-most axes only. It is set after
            # plotting so it replaces any y-label the plotting call added.
            # Words that are already all-caps (e.g. "KDE") are kept as-is.
            label = ' '.join(
                word if word.isupper() else word.capitalize()
                for word in name.split()
            )
            label = fill(label, width=20, break_long_words=False)
            ax.set_ylabel(label, rotation=0, ha='right', va='center')

# Strip all y ticks, and x ticks everywhere except the bottom row.
for ax in axes.flat:
    ax.yaxis.set_tick_params(length=0, width=0, labelleft=False)
for ax in axes[:-1, :].flat:
    ax.xaxis.set_tick_params(length=0, width=0, labelbottom=False)

# Horizontal rule separating the header row from the plot rows, placed
# halfway between the bottom of the header and the top of the first row
# (figure coordinates).
header_bbox = axes[0, 0].get_position()
row_bbox = axes[1, 0].get_position()
sepline = Line2D(
    [.1, .9], [(header_bbox.y0 - row_bbox.y1) / 2 + row_bbox.y1] * 2,
    color='k',
)
fig.add_artist(sepline)

# Figure title, centred over the grid of axes rather than over the whole
# figure (the left margin is wider to make room for the row labels).
gs = fig.axes[0].get_gridspec()
centered = (gs.right - gs.left) / 2 + gs.left
fig.text(
    x=centered, y=.98, s='A Collection of Univariate Plots',
    fontsize='xx-large', va='top', ha='center',
)

# The figure is written to disk instead of being shown interactively.
fig.savefig('univariateplots.png')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment