|
""" |
|
SIGReg demo: Sliced Isotropic Gaussian Regularizer from LeWorldModel. |
|
|
|
Implements the Epps-Pulley normality test via numerical quadrature of the |
|
empirical characteristic function, following the LeWorldModel paper's |
|
appendix (Maes, Le Lidec, Scieur, LeCun, Balestriero, arXiv 2603.19312v2). |
|
|
|
Usage: |
|
uv run --with torch --with matplotlib python sigreg_demo.py |
|
|
|
Produces `sigreg_results.png` and `sigreg_results.svg`, each with three panels:
|
1. SIGReg score across collapse / failure modes |
|
2. Training curve for SIGReg-only whitening |
|
3. Marginal distribution before vs after whitening |
|
""" |
|
|
|
from __future__ import annotations |
|
|
|
import matplotlib.pyplot as plt |
|
import torch |
|
|
|
# Seed the global torch RNG once at import time so every random draw in this
# demo (synthetic data, slicing directions) is the same from run to run.
_SEED = 0
torch.manual_seed(_SEED)
|
|
|
|
|
def sigreg(
    Z: torch.Tensor,
    M: int = 1024,
    t_nodes: torch.Tensor | None = None,
    bandwidth: float = 1.0,
) -> torch.Tensor:
    """Sliced Epps-Pulley test statistic averaged over M random directions.

    The rows of ``Z`` are projected onto ``M`` random unit directions. For
    each 1-D slice:

    - the squared modulus of (empirical characteristic function minus the
      standard-normal CF ``exp(-t^2 / 2)``) is weighted by
      ``exp(-t^2 / (2 * bandwidth^2))``;
    - the weighted value is integrated over ``t_nodes`` with the trapezoid
      rule.

    The per-slice integrals are then averaged.

    Args:
        Z: 2-D ``(N, D)`` floating-point embeddings, one sample per row.
        M: Number of random slicing directions; must be >= 1.
        t_nodes: 1-D quadrature nodes. Defaults to 64 evenly spaced points on
            [0.2, 4.0] with ``Z``'s device and dtype. Supplied nodes are moved
            to ``Z``'s device.
        bandwidth: Width of the Gaussian weighting window over ``t``.

    Returns:
        A 0-d tensor that is differentiable with respect to ``Z``.

    Raises:
        ValueError: If ``Z`` is not 2-D, has no rows, or ``M < 1``.
    """
    if Z.dim() != 2:
        raise ValueError(f"Z must be 2-D (N, D), got shape {tuple(Z.shape)}")
    if Z.shape[0] == 0:
        raise ValueError("Z must contain at least one sample (row)")
    if M < 1:
        raise ValueError(f"M must be >= 1, got {M}")

    if t_nodes is None:
        t_nodes = torch.linspace(0.2, 4.0, 64, device=Z.device, dtype=Z.dtype)
    else:
        t_nodes = t_nodes.to(Z.device)

    # 1. Directions uniform on the hypersphere S^(D-1). U must share Z's dtype:
    #    matmul does not promote, so a default-dtype U crashes on float64 input.
    U = torch.randn(Z.shape[1], M, device=Z.device, dtype=Z.dtype)
    U = U / U.norm(dim=0, keepdim=True)

    # 2. Project to M 1D slices: (N, M)
    H = Z @ U

    # 3. Empirical CF per slice at each node, computed as cos/sin means.
    #    This equals mean(exp(1j * tH)) but stays real-valued: no complex
    #    (T, N, M) intermediates, and it works where complex dtypes are
    #    unsupported.
    tH = t_nodes[:, None, None] * H[None, :, :]  # (T, N, M)
    ecf_re = torch.cos(tH).mean(dim=1)  # (T, M)
    ecf_im = torch.sin(tH).mean(dim=1)  # (T, M)

    # 4. Standard-normal CF (purely real) and Gaussian weighting, both (T, 1)
    phi_0 = torch.exp(-t_nodes ** 2 / 2)[:, None]
    w = torch.exp(-t_nodes ** 2 / (2 * bandwidth ** 2))[:, None]

    # 5. |ecf - phi_0|^2 = (Re - phi_0)^2 + Im^2, since phi_0 is real.
    #    Trapezoid over t, then average over the M slices.
    integrand = ((ecf_re - phi_0) ** 2 + ecf_im ** 2) * w
    return torch.trapezoid(integrand, t_nodes, dim=0).mean()
|
|
|
|
|
def baseline_scenarios() -> dict[str, float]:
    """Score synthetic embedding clouds with ``sigreg``.

    Builds one isotropic Gaussian reference and five failure modes, each with
    2048 samples in 32 dimensions:

    - constant collapse;
    - rank-1 subspace;
    - Student-t heavy tails;
    - anisotropic scaling;
    - mean shift.

    Returns ``{label: score}`` in that order. Labels contain an embedded line
    break so they render as two-line tick labels.
    """
    n_samples, dim = 2048, 32

    # Anisotropic case: axis 0 gets std 5, every other axis std 1.
    axis_scale = torch.ones(dim)
    axis_scale[0] = 5.0
    # Unnormalised random direction spanning the rank-1 case. It is drawn
    # before the scenario samples, which fixes the RNG stream.
    line_dir = torch.randn(dim)

    # Listed in order so the samples are drawn in exactly this sequence.
    scenarios = [
        ("Isotropic\nGaussian", torch.randn(n_samples, dim)),
        ("Full\ncollapse", torch.zeros(n_samples, dim) + 0.3),
        ("Rank-1\nsubspace", torch.randn(n_samples, 1) * line_dir),
        ("Student-t\n(df=3)",
         torch.distributions.StudentT(df=3).sample((n_samples, dim))),
        ("Anisotropic\nGaussian", torch.randn(n_samples, dim) * axis_scale),
        ("Shifted\n(mean=2)", torch.randn(n_samples, dim) + 2.0),
    ]

    scores: dict[str, float] = {}
    for label, embeddings in scenarios:
        scores[label] = float(sigreg(embeddings))
    return scores
|
|
|
|
|
def whitening_run() -> tuple[list[float], torch.Tensor, torch.Tensor, float]:
    """Learn an affine map ``Z = X @ W + b`` by minimizing ``sigreg`` alone.

    The input ``X`` is a 4096 x 16 Gaussian with per-axis std spaced linearly
    from 0.2 to 5.0, shifted by +1.5. ``W`` starts at the identity and ``b``
    at zero; both are trained with Adam (lr 5e-2) for 301 steps using 512
    slices per step.

    Returns:
        - The loss recorded at each step, before that step's update.
        - ``X``.
        - ``Z`` after all updates.
        - The ``sigreg`` score of a fresh N(0, I) sample of the same size,
          using the default slice count.
    """
    n_samples, dim = 4096, 16
    per_axis_std = torch.linspace(0.2, 5.0, dim)
    data = torch.randn(n_samples, dim) * per_axis_std + 1.5

    weight = torch.eye(dim, requires_grad=True)
    bias = torch.zeros(dim, requires_grad=True)
    optimizer = torch.optim.Adam([weight, bias], lr=5e-2)

    history: list[float] = []
    for _step in range(301):
        # Clearing grads before the forward pass is equivalent to clearing
        # them just before backward().
        optimizer.zero_grad()
        objective = sigreg(data @ weight + bias, M=512)
        history.append(objective.item())
        objective.backward()
        optimizer.step()

    with torch.no_grad():
        whitened = data @ weight + bias
        baseline = sigreg(torch.randn(n_samples, dim)).item()

    return history, data.detach(), whitened.detach(), baseline
|
|
|
|
|
def plot(
    results: dict[str, float],
    curve: list[float],
    X: torch.Tensor,
    Z_final: torch.Tensor,
    ref: float,
    path: str,
) -> None:
    """Render the three-panel SIGReg summary figure and save it to ``path``.

    Panels:

    1. SIGReg score per scenario in ``results``, on a log axis.
    2. The whitening loss ``curve``, with a dashed line at the N(0, I)
       reference score ``ref``.
    3. Histograms of the first column of ``X`` (before) and ``Z_final``
       (after), with the N(0, 1) density overlaid.

    The output format is whatever matplotlib's ``savefig`` infers from
    ``path``. The figure is closed even if saving fails, so repeated calls
    (one per output format) do not accumulate open figures in pyplot's
    global state.
    """
    fig, axes = plt.subplots(1, 3, figsize=(15, 4.2), constrained_layout=True)
    try:
        _plot_scores(axes[0], results)
        _plot_curve(axes[1], curve, ref)
        _plot_marginal(axes[2], X, Z_final)
        fig.suptitle(
            "SIGReg: sliced Epps-Pulley normality regularizer (LeWorldModel, 2026)",
            fontsize=12,
        )
        fig.savefig(path, dpi=140)
    finally:
        plt.close(fig)
    print(f"Wrote {path}")


def _plot_scores(ax, results: dict[str, float]) -> None:
    """Panel 1: bar chart of SIGReg per scenario (log scale), isotropic highlighted."""
    names = list(results.keys())
    values = [results[n] for n in names]
    bar_colors = ["#2a9d8f" if "Isotropic" in n else "#e76f51" for n in names]
    ax.bar(range(len(names)), values, color=bar_colors)
    ax.set_xticks(range(len(names)))
    ax.set_xticklabels(names, fontsize=9)
    ax.set_yscale("log")
    ax.set_ylabel("SIGReg (log)")
    ax.set_title("Scores across failure modes")
    ax.grid(True, axis="y", linestyle=":", alpha=0.4)
    for i, v in enumerate(values):
        # A multiplicative offset keeps the label just above the bar on a log axis.
        ax.text(i, v * 1.15, f"{v:.4f}", ha="center", fontsize=8)


def _plot_curve(ax, curve: list[float], ref: float) -> None:
    """Panel 2: SIGReg per optimization step against the N(0, I) reference."""
    ax.plot(curve, color="#264653", linewidth=1.6)
    ax.axhline(ref, color="#2a9d8f", linestyle="--", linewidth=1.2,
               label=f"N(0, I) reference = {ref:.5f}")
    ax.set_yscale("log")
    ax.set_xlabel("Optimization step")
    ax.set_ylabel("SIGReg (log)")
    ax.set_title("Whitening by minimizing SIGReg alone")
    ax.grid(True, linestyle=":", alpha=0.4)
    ax.legend(fontsize=9)


def _plot_marginal(ax, X: torch.Tensor, Z_final: torch.Tensor) -> None:
    """Panel 3: first-axis histograms before/after whitening, with the N(0, 1) pdf."""
    axis = 0
    bins = 60
    # detach().cpu() lets .numpy() work for tensors that require grad or live
    # on an accelerator.
    before = X[:, axis].detach().cpu().numpy()
    after = Z_final[:, axis].detach().cpu().numpy()
    ax.hist(before, bins=bins, alpha=0.55, label="Before", color="#e76f51",
            density=True)
    ax.hist(after, bins=bins, alpha=0.55, label="After",
            color="#264653", density=True)
    # Overlay N(0,1) pdf
    xs = torch.linspace(-6, 8, 300)
    pdf = (1 / (2 * torch.pi) ** 0.5) * torch.exp(-xs ** 2 / 2)
    ax.plot(xs.numpy(), pdf.numpy(), color="#2a9d8f", linewidth=1.6, label="N(0, 1)")
    ax.set_xlabel("Axis 0 value")
    ax.set_ylabel("Density")
    ax.set_title("First-axis marginal (before / after)")
    ax.legend(fontsize=9)
    ax.grid(True, linestyle=":", alpha=0.4)
|
|
|
|
|
def main() -> None:
    """Run the scenario sweep and the whitening experiment, then save figures.

    Prints the per-scenario scores and a few training checkpoints. Writes the
    summary figure twice, once as PNG and once as SVG.
    """
    print("## Baseline scenarios (N=2048, D=32)\n")
    scores = baseline_scenarios()
    for label, score in scores.items():
        # Scenario labels carry a line break for plotting; flatten for the console.
        flat_label = label.replace(chr(10), ' ')
        print(f" {flat_label:30s} SIGReg = {score:.6f}")

    print("\n## Whitening an anisotropic shifted Gaussian with SIGReg alone\n")
    curve, X, Z_final, ref = whitening_run()
    for checkpoint in range(0, 301, 60):
        print(f" step {checkpoint:3d} SIGReg = {curve[checkpoint]:.6f}")

    axis_std = Z_final.std(0)
    print(f"\n Gaussian reference SIGReg: {ref:.6f}")
    print(f" Final optimized SIGReg: {curve[-1]:.6f}")
    print(f" Final mean norm: {Z_final.mean(0).norm().item():.4f}")
    print(f" Final per-axis std min/max: "
          f"{axis_std.min().item():.3f} / {axis_std.max().item():.3f}")

    for out_path in ("sigreg_results.png", "sigreg_results.svg"):
        plot(scores, curve, X, Z_final, ref, out_path)
|
|
|
|
|
if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    main()