Skip to content

Instantly share code, notes, and snippets.

@vankesteren
Last active July 19, 2025 18:59
Show Gist options
  • Save vankesteren/c651b75b5f0172fbd6b3f16c569b1409 to your computer and use it in GitHub Desktop.
Save vankesteren/c651b75b5f0172fbd6b3f16c569b1409 to your computer and use it in GitHub Desktop.
Tidy simulation example in Python
import polars as pl
import numpy as np
from polarsgrid import expand_grid
from scipy.stats import norm, t, uniform, ttest_ind
from tqdm import tqdm
import plotnine as p9
# Simulation design. expand_grid presumably builds the full cross product of
# these factors (as R's expand.grid does), so each resulting row describes one
# simulation run: data-generating settings, analysis method and replicate id.
grid = expand_grid(
    # data generating process parameters
    sample_size=[10 * 2**k for k in range(7)],  # 10, 20, 40, ..., 640
    distribution=["normal", "t", "uniform"],
    effect_size=[k / 10 for k in range(11)],  # 0.0, 0.1, ..., 1.0
    # method parameters
    test=["welch", "student"],
    # replicate index
    iter=list(range(500)),
)
def generate_data(
    sample_size: int = 40, distribution: str = "normal", effect_size: float = 0.0
):
    """Simulate one two-group dataset.

    Draws ``sample_size`` observations for each of two groups from the same
    base distribution, then shifts the treated group by ``effect_size``.

    Args:
        sample_size: Number of observations per group.
        distribution: ``"normal"`` (standard normal), ``"t"`` (Student's t with
            1 degree of freedom, i.e. a heavy-tailed Cauchy case) or
            ``"uniform"`` (scipy's default uniform on [0, 1]).
        effect_size: Additive location shift applied to the treated group.

    Returns:
        A polars DataFrame with columns ``group`` and ``value``. The first
        ``sample_size`` rows are "treated" and the next ``sample_size`` rows
        are "control".

    Raises:
        ValueError: If ``distribution`` is not one of the supported names.
    """
    if distribution == "normal":
        dist = norm()
    elif distribution == "t":
        dist = t(1)
    elif distribution == "uniform":
        dist = uniform()
    else:
        # Previously an unknown name fell through and hit an UnboundLocalError.
        raise ValueError(
            f"unknown distribution {distribution!r}; "
            "expected 'normal', 't' or 'uniform'"
        )
    # Draw treated first, then control. This is the same draw order as before.
    # The effect is added to the rows labelled "treated", so that
    # treated - control has the sign of effect_size.
    treated = dist.rvs(sample_size) + effect_size
    control = dist.rvs(sample_size)
    return pl.DataFrame(
        {
            "group": ["treated"] * sample_size + ["control"] * sample_size,
            "value": np.hstack([treated, control]),
        }
    )
def analyze_data(df: "pl.DataFrame", test: str = "welch"):
    """Run a two-sample t-test comparing treated against control values.

    Args:
        df: Frame with a ``group`` column holding "treated"/"control" labels
            and a numeric ``value`` column, such as the one ``generate_data``
            returns.
        test: ``"welch"`` for Welch's unequal-variance t-test or ``"student"``
            for Student's pooled-variance t-test.

    Returns:
        ``(tstat, pval)`` from ``scipy.stats.ttest_ind``, with the treated
        values as the first sample.

    Raises:
        ValueError: If ``test`` is neither ``"welch"`` nor ``"student"``.
    """
    # Validate up front. Previously any other string (e.g. a typo) silently
    # ran Welch's test.
    if test not in ("welch", "student"):
        raise ValueError(f"unknown test {test!r}; expected 'welch' or 'student'")
    treated = df.filter(pl.col.group == "treated")["value"]
    control = df.filter(pl.col.group == "control")["value"]
    tstat, pval = ttest_ind(treated, control, equal_var=test == "student")
    return (tstat, pval)
# Run one simulation per grid row: generate a dataset, test it, and keep the
# (tstat, pval) pair in grid order.
results_table = []
for row in tqdm(grid.iter_rows(named=True), total=len(grid)):
    df = generate_data(row["sample_size"], row["distribution"], row["effect_size"])
    res = analyze_data(df, test=row["test"])
    results_table.append(res)
# Each (tstat, pval) tuple is one row. State the orientation explicitly rather
# than letting polars infer it. Inference may warn in recent polars versions
# and can choose column orientation when there happen to be exactly 2 rows.
results_df = pl.DataFrame(results_table, schema=["tstat", "pval"], orient="row")
# results_df rows are in the same order as grid rows, so a horizontal concat
# attaches each run's design parameters to its test result.
df = pl.concat([grid, results_df], how="horizontal")
df.write_parquet("results.parquet")
# Summarise each design cell by its rejection rate at alpha = 0.05. When
# effect_size is 0 this is the empirical type I error rate; otherwise it is
# the empirical power. "n" counts the non-null p-values in the cell.
agg_df = df.group_by("sample_size", "distribution", "effect_size", "test").agg(
    (pl.col.pval < 0.05).mean().alias("power"),
    pl.col.pval.count().alias("n"),
)
# Discrete legend levels for sample_size, in numeric order. group_by does not
# preserve row order, and a plain Categorical cast orders categories by first
# appearance. That would make the legend order arbitrary, so an Enum with
# sorted levels is used instead.
# NOTE(review): the first-appearance behaviour of Categorical is assumed;
# confirm it for the installed polars version.
size_levels = [str(size) for size in sorted(agg_df["sample_size"].unique())]
plt = (
    p9.ggplot(
        agg_df.with_columns(
            pl.col.sample_size.cast(pl.String).cast(pl.Enum(size_levels))
        ),
        p9.aes(x="effect_size", y="power", color="sample_size", linetype="test"),
    )
    + p9.geom_point()
    + p9.geom_line()
    + p9.facet_grid(cols="distribution")
    + p9.theme_linedraw()
)
plt.save("result", width=12, height=8, dpi=300)
@vankesteren
Copy link
Author

image

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment