Skip to content

Instantly share code, notes, and snippets.

@alexsavio
Last active April 22, 2026 08:26
Show Gist options
  • Select an option

  • Save alexsavio/2f045ea80086cce85f32a5c8cf5f6ef5 to your computer and use it in GitHub Desktop.

Select an option

Save alexsavio/2f045ea80086cce85f32a5c8cf5f6ef5 to your computer and use it in GitHub Desktop.
SIGReg: Sliced Isotropic Gaussian Regularizer from LeWorldModel (Maes et al. 2026) — runnable PyTorch demo

SIGReg demo

Runnable PyTorch implementation of the Sliced Isotropic Gaussian Regularizer (SIGReg) from the LeWorldModel paper (Maes, Le Lidec, Scieur, LeCun, Balestriero, arXiv:2603.19312, March 2026).

SIGReg is a single-term regularizer that replaces the usual six-knob stack (EMA teachers, stop-gradient, frozen pretrained features, VICReg variance/covariance terms, momentum encoders, LR warmup schedules) used to prevent representation collapse in JEPA-style self-supervised world models.

Companion post: LeWorldModel: One Regularizer Instead of Six for Pixel JEPA.

How it works

For each batch of embeddings Z of shape (N, D):

  1. Draw M random unit directions on the hypersphere S^(D-1).
  2. Project Z onto each direction to get M one-dimensional slices.
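  3. For every slice, compute the Epps-Pulley statistic: the weighted squared distance between the slice's empirical characteristic function $\hat\varphi(t) = \frac{1}{N}\sum_{n} e^{i t h_n}$ and the standard-normal characteristic function $\varphi_0(t) = e^{-t^2/2}$, i.e. $\int |\hat\varphi(t) - \varphi_0(t)|^2\, e^{-t^2/(2\sigma^2)}\, dt$ with bandwidth $\sigma = 1$ by default, approximated by the trapezoid rule on 64 nodes in $t \in [0.2, 4]$.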
  4. Average across slices.

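The Cramér-Wold theorem says a distribution is determined by all of its 1D projections: if every unit-direction slice is N(0, 1), the joint distribution is N(0, I). SIGReg samples only M random slices per batch, so it enforces this approximately (in expectation over the random directions) rather than exactly.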

Run it

uv run --with torch --with matplotlib python sigreg_demo.py

Writes sigreg_results.png (plus an identical sigreg_results.svg) with three panels: per-scenario scores, the whitening training curve, and a first-axis before/after histogram.

Results

SIGReg results

Baseline scenarios (N=2048, D=32)

| Distribution | SIGReg |
|---|---|
| Isotropic standard Gaussian N(0, I) | 0.000281 |
| Full collapse (all points identical) | 0.246018 |
| Rank-1 subspace (dimensional collapse) | 0.081067 |
| Student-t, df=3 (heavy tails) | 0.047251 |
| Anisotropic Gaussian (one big axis) | 0.026291 |
| Shifted Gaussian (mean=2) | 0.480116 |

The isotropic Gaussian baseline is ~0.0003. Every failure mode that would sink a JEPA encoder scores roughly two to three orders of magnitude higher, from about 94× for the anisotropic case to about 1700× for the shifted mean. Shifted mean is punished the hardest, which is the exact failure a predictor-only loss would miss.

Whitening by minimizing SIGReg alone

Input: anisotropic Gaussian with per-axis std from 0.2 to 5.0 and mean 1.5. Output: affine map trained for 300 Adam steps on SIGReg only (no reconstruction, no target network).

| Step | SIGReg |
|---|---|
| 0 | 0.314595 |
| 60 | 0.002082 |
| 120 | 0.000989 |
| 180 | 0.000913 |
| 240 | 0.000919 |
| 300 | 0.000841 |

Final state: mean norm 0.079, per-axis std in [0.868, 1.066], SIGReg 0.000841 vs reference 0.000109.

License

MIT. Do what you want with it.

"""
SIGReg demo: Sliced Isotropic Gaussian Regularizer from LeWorldModel.

Implements the Epps-Pulley normality test via numerical quadrature of the
empirical characteristic function, following the LeWorldModel paper's
appendix (Maes, Le Lidec, Scieur, LeCun, Balestriero, arXiv 2603.19312v2).

Usage:
    uv run --with torch --with matplotlib python sigreg_demo.py

Prints a text report and writes the same three-panel figure to both
`sigreg_results.png` and `sigreg_results.svg`:
    1. SIGReg score across collapse / failure modes
    2. Training curve for SIGReg-only whitening
    3. Marginal distribution before vs after whitening

Memory note: `sigreg` materializes a complex tensor of shape
(T, N, M) = (quadrature nodes, samples, slices). With the baseline settings
(T=64, N=2048, M=1024) that is ~134M complex64 values (~1 GiB) per
intermediate.
"""
from __future__ import annotations
import matplotlib.pyplot as plt
import torch
# Seed the global torch RNG once so the data, the random slicing directions
# and the Student-t samples are the same on every run of the script.
torch.manual_seed(0)
def sigreg(
    Z: torch.Tensor,
    M: int = 1024,
    t_nodes: torch.Tensor | None = None,
    bandwidth: float = 1.0,
) -> torch.Tensor:
    """Sliced Epps-Pulley test statistic averaged over M random directions.

    Every row of ``Z`` is projected onto ``M`` random unit directions. For each
    resulting 1D slice, the squared distance between its empirical
    characteristic function and that of N(0, 1), weighted by a Gaussian
    ``w(t)``, is integrated over ``t_nodes`` with the trapezoid rule. The
    per-slice values are then averaged.

    Args:
        Z: Batch of embeddings of shape (N, D), floating point. The random
            directions and the default quadrature nodes are created with
            ``Z``'s dtype and device, so float32 and float64 inputs both work.
        M: Number of random slicing directions, drawn fresh on every call.
        t_nodes: 1D quadrature nodes. Defaults to 64 evenly spaced points in
            [0.2, 4.0].
        bandwidth: Width ``sigma`` of the weight ``w(t) = exp(-t^2 / (2 sigma^2))``.

    Returns:
        A 0-dim tensor, differentiable with respect to ``Z``.

    Notes:
        The statistic is not rescaled by N. An exact N(0, I) sample therefore
        still scores slightly above zero, because the empirical characteristic
        function of N samples is noisy.
        Peak memory is dominated by the (T, N, M) complex tensor built in
        step 3, where T = len(t_nodes).
    """
    if t_nodes is None:
        t_nodes = torch.linspace(0.2, 4.0, 64, device=Z.device, dtype=Z.dtype)
    # 1. Directions uniform on the hypersphere S^(D-1). The dtype must match Z:
    #    matmul does not promote dtypes, so a float32 U would crash on float64 Z.
    U = torch.randn(Z.shape[1], M, device=Z.device, dtype=Z.dtype)
    U = U / U.norm(dim=0, keepdim=True)
    # 2. Project to M 1D slices: (N, M)
    H = Z @ U
    # 3. Empirical characteristic function per slice at each quadrature node
    tH = t_nodes[:, None, None] * H[None, :, :]  # (T, N, M)
    phi_emp = torch.exp(1j * tH).mean(dim=1)  # (T, M), complex
    # 4. Standard-normal CF and Gaussian weighting, broadcast over slices
    phi_0 = torch.exp(-t_nodes ** 2 / 2)[:, None]
    w = torch.exp(-t_nodes ** 2 / (2 * bandwidth ** 2))[:, None]
    # 5. Trapezoid quadrature of |phi_emp - phi_0|^2 * w(t), then average slices
    integrand = (phi_emp - phi_0).abs() ** 2 * w
    return torch.trapezoid(integrand, t_nodes, dim=0).mean()
def baseline_scenarios() -> dict[str, float]:
    """Score a fixed set of reference distributions with SIGReg.

    Builds six (2048, 32) samples: an isotropic standard Gaussian, a fully
    collapsed constant batch, a rank-1 line, Student-t (df=3) noise, a Gaussian
    with one axis stretched to std 5, and a Gaussian shifted to mean 2. Each
    sample is scored with ``sigreg`` at its default settings.

    Returns:
        Mapping from a plot-ready label (containing a newline) to the score.
    """
    n_samples, dim = 2048, 32

    axis_scales = torch.ones(dim)
    axis_scales[0] = 5.0
    line_dir = torch.randn(dim)

    # Build every sample first, then score them. This keeps the global RNG
    # consumption in the same order as before: data draws, then slicing draws.
    samples: dict[str, torch.Tensor] = {}
    samples["Isotropic\nGaussian"] = torch.randn(n_samples, dim)
    samples["Full\ncollapse"] = torch.zeros(n_samples, dim) + 0.3
    samples["Rank-1\nsubspace"] = torch.randn(n_samples, 1) * line_dir
    samples["Student-t\n(df=3)"] = torch.distributions.StudentT(df=3).sample(
        (n_samples, dim)
    )
    samples["Anisotropic\nGaussian"] = torch.randn(n_samples, dim) * axis_scales
    samples["Shifted\n(mean=2)"] = torch.randn(n_samples, dim) + 2.0

    scores: dict[str, float] = {}
    for label, Z in samples.items():
        scores[label] = sigreg(Z).item()
    return scores
def whitening_run(
    *,
    steps: int = 300,
    lr: float = 5e-2,
    M: int = 512,
) -> tuple[list[float], torch.Tensor, torch.Tensor, float]:
    """Learn an affine whitening map by minimizing SIGReg alone.

    The input is a shifted, anisotropic Gaussian of shape (4096, 16), with
    per-axis std evenly spaced from 0.2 to 5.0 and mean 1.5 on every axis. An
    affine map ``Z = X @ W + b``, initialised to the identity, is optimised
    with Adam on ``sigreg(Z)`` only. There is no reconstruction or
    target-network term.

    Args:
        steps: Number of Adam updates (keyword-only).
        lr: Adam learning rate (keyword-only).
        M: Number of random slicing directions per SIGReg evaluation. It is
            used for both the training loss and the reference score, so the
            two are directly comparable (keyword-only).

    Returns:
        A tuple ``(curve, X, Z_final, ref)``:
            curve: ``steps + 1`` losses, where ``curve[k]`` is the loss after
                ``k`` updates. ``curve[-1]`` is therefore measured at the same
                parameters that produce ``Z_final``.
            X: The raw input.
            Z_final: The whitened output.
            ref: The SIGReg score of a fresh N(0, I) sample of the same shape.
    """
    N, D = 4096, 16
    scales = torch.linspace(0.2, 5.0, D)
    X = torch.randn(N, D) * scales + 1.5
    W = torch.eye(D, requires_grad=True)
    b = torch.zeros(D, requires_grad=True)
    opt = torch.optim.Adam([W, b], lr=lr)
    curve: list[float] = []
    for step in range(steps + 1):
        Z = X @ W + b
        loss = sigreg(Z, M=M)
        curve.append(loss.item())
        # The last pass only records the final loss. Another update here would
        # leave Z_final one step ahead of curve[-1].
        if step == steps:
            break
        opt.zero_grad()
        loss.backward()
        opt.step()
    with torch.no_grad():
        Z_final = X @ W + b
        ref = sigreg(torch.randn(N, D), M=M).item()
    return curve, X, Z_final, ref
def _draw_scores(ax, results: dict[str, float]) -> None:
    """Panel 1: bar chart of the SIGReg score per scenario on a log scale."""
    names = list(results.keys())
    values = [results[n] for n in names]
    # The isotropic Gaussian is the "healthy" reference; everything else is a failure mode.
    bar_colors = ["#2a9d8f" if "Isotropic" in n else "#e76f51" for n in names]
    ax.bar(range(len(names)), values, color=bar_colors)
    ax.set_xticks(range(len(names)))
    ax.set_xticklabels(names, fontsize=9)
    ax.set_yscale("log")
    ax.set_ylabel("SIGReg (log)")
    ax.set_title("Scores across failure modes")
    ax.grid(True, axis="y", linestyle=":", alpha=0.4)
    # Multiplicative offset places each label just above its bar on a log axis.
    for i, v in enumerate(values):
        ax.text(i, v * 1.15, f"{v:.4f}", ha="center", fontsize=8)


def _draw_training_curve(ax, curve: list[float], ref: float) -> None:
    """Panel 2: SIGReg loss per optimisation step against the N(0, I) reference."""
    ax.plot(curve, color="#264653", linewidth=1.6)
    ax.axhline(ref, color="#2a9d8f", linestyle="--", linewidth=1.2,
               label=f"N(0, I) reference = {ref:.5f}")
    ax.set_yscale("log")
    ax.set_xlabel("Optimization step")
    ax.set_ylabel("SIGReg (log)")
    ax.set_title("Whitening by minimizing SIGReg alone")
    ax.grid(True, linestyle=":", alpha=0.4)
    ax.legend(fontsize=9)


def _draw_marginals(ax, X: torch.Tensor, Z_final: torch.Tensor) -> None:
    """Panel 3: histogram of axis 0 before and after whitening, with the N(0, 1) pdf."""
    axis = 0
    bins = 60
    # detach/cpu so .numpy() also works for tensors on an accelerator or in a graph.
    before = X[:, axis].detach().cpu().numpy()
    after = Z_final[:, axis].detach().cpu().numpy()
    ax.hist(before, bins=bins, alpha=0.55, label="Before", color="#e76f51",
            density=True)
    ax.hist(after, bins=bins, alpha=0.55, label="After",
            color="#264653", density=True)
    # Overlay N(0,1) pdf
    xs = torch.linspace(-6, 8, 300)
    pdf = (1 / (2 * torch.pi) ** 0.5) * torch.exp(-xs ** 2 / 2)
    ax.plot(xs.numpy(), pdf.numpy(), color="#2a9d8f", linewidth=1.6, label="N(0, 1)")
    ax.set_xlabel("Axis 0 value")
    ax.set_ylabel("Density")
    ax.set_title("First-axis marginal (before / after)")
    ax.legend(fontsize=9)
    ax.grid(True, linestyle=":", alpha=0.4)


def plot(
    results: dict[str, float],
    curve: list[float],
    X: torch.Tensor,
    Z_final: torch.Tensor,
    ref: float,
    path: str,
) -> None:
    """Render the three-panel summary figure and save it to ``path``.

    Args:
        results: Scenario label -> SIGReg score (from ``baseline_scenarios``).
        curve: SIGReg loss per whitening step.
        X: Input sample before whitening, shape (N, D).
        Z_final: Whitened sample, shape (N, D).
        ref: SIGReg score of a fresh N(0, I) sample, drawn as a dashed line.
        path: Output file. matplotlib picks the format from the extension.
    """
    fig, axes = plt.subplots(1, 3, figsize=(15, 4.2), constrained_layout=True)
    try:
        _draw_scores(axes[0], results)
        _draw_training_curve(axes[1], curve, ref)
        _draw_marginals(axes[2], X, Z_final)
        fig.suptitle(
            "SIGReg: sliced Epps-Pulley normality regularizer (LeWorldModel, 2026)",
            fontsize=12,
        )
        fig.savefig(path, dpi=140)
    finally:
        # plot() is called once per output format. Closing avoids accumulating
        # open pyplot figures, even if saving fails.
        plt.close(fig)
    print(f"Wrote {path}")
def main() -> None:
    """Run both experiments, print a text report, and save the figure.

    The report contains:
    - the SIGReg score of every baseline scenario;
    - the whitening loss at every 60th step;
    - the Gaussian reference score;
    - the norm of the mean and the per-axis std range of the whitened sample.

    The summary figure is then written as ``sigreg_results.png`` and
    ``sigreg_results.svg``.
    """
    print("## Baseline scenarios (N=2048, D=32)\n")
    scores = baseline_scenarios()
    for label, score in scores.items():
        # Labels contain a newline for the bar-chart ticks; flatten for the console.
        flat_label = label.replace("\n", " ")
        print(f" {flat_label:30s} SIGReg = {score:.6f}")

    print("\n## Whitening an anisotropic shifted Gaussian with SIGReg alone\n")
    curve, X, Z_final, ref = whitening_run()
    for step in range(0, 301, 60):
        print(f" step {step:3d} SIGReg = {curve[step]:.6f}")

    mean_norm = Z_final.mean(0).norm().item()
    axis_std = Z_final.std(0)
    print(f"\n Gaussian reference SIGReg: {ref:.6f}")
    print(f" Final optimized SIGReg: {curve[-1]:.6f}")
    print(f" Final mean norm: {mean_norm:.4f}")
    print(
        " Final per-axis std min/max: "
        f"{axis_std.min().item():.3f} / {axis_std.max().item():.3f}"
    )

    for out_path in ("sigreg_results.png", "sigreg_results.svg"):
        plot(scores, curve, X, Z_final, ref, out_path)
if __name__ == "__main__":
    # Run the experiments only when executed as a script, not when imported.
    main()
Display the source blob
Display the rendered blob
Raw
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment