Skip to content

Instantly share code, notes, and snippets.

@ssghost
Forked from alexander-soare/epps_pulley_statistic.py
Last active April 4, 2026 21:20
Show Gist options
  • Select an option

  • Save ssghost/706ba37a8100e8080e83bf565e82de07 to your computer and use it in GitHub Desktop.

Select an option

Save ssghost/706ba37a8100e8080e83bf565e82de07 to your computer and use it in GitHub Desktop.
Visualizing optimization of the SigReg loss from LeJEPA
import os
import time

import matplotlib.pyplot as plt
import torch
D = 2 # dimension
N = 8 # number of projections
B = 512 # batch size
K = 17 # knots
U = 3 # upper bound of fourier domain
# Choose a starting distribution
# X = torch.normal(0.0, 0.5, size=(B, D))
X = torch.rand(size=(B, D)) * 2 - 1
X = torch.nn.Parameter(X.clone())
t = torch.linspace(0, U, K)
dt = U / (K - 1)
target = (-t.pow(2) / 2.0).exp()
weights = torch.full((K,), 2 * dt) * target
weights[[0, -1]] = dt
step = 0
def plot(t, X, re_fX, im_fX, target):
fig = plt.figure(figsize=(12, 8))
gs = fig.add_gridspec(D, 2, width_ratios=[0.8, 1.2], hspace=0.4, wspace=0.3)
for i in range(D):
ax = fig.add_subplot(gs[i, 0])
ax.plot(t, re_fX[i], "o", label="real")
ax.plot(t, im_fX[i], "o", label="imag")
ax.plot(t, target, color="red", label="targ")
ax.legend()
# Scatter plot spanning all rows in the right column
scatter_ax = fig.add_subplot(gs[:, 1], aspect="equal")
scatter_ax.plot(X[:, 0], X[:, 1], ".")
theta = torch.linspace(0, 2 * 3.14159265, 200)
for r in [1, 2, 3]:
scatter_ax.plot((r * theta.cos()).numpy(), (r * theta.sin()).numpy(), "r-", linewidth=1)
scatter_ax.set_xlim(-U, U)
scatter_ax.set_ylim(-U, U)
scatter_ax.axhline(0, color="gray", linewidth=0.5)
scatter_ax.axvline(0, color="gray", linewidth=0.5)
scatter_ax.grid(True, linewidth=0.5, alpha=0.5)
save_to = "outputs/epps_pulley_statistic.png"
plt.savefig(save_to, dpi=150)
plt.close()
def project_fourier(random: bool = True) -> tuple[torch.Tensor, torch.Tensor]:
if random:
# Pick N random unit vectors.
V = torch.normal(0, 1, (N, D))
V = V / V.norm(p=2, dim=-1, keepdim=True)
else:
# Project onto the basis vectors.
V = torch.eye(D)
# Project X onto these vectors
P = X @ V.T # (B N)
P_t = P[..., None] * t # (B N t)
# Real and imaginary parts of fourier transform
# Mean over "batch" to get empirical characteristic function
re_F = P_t.cos().mean(0) # (N t)
im_F = P_t.sin().mean(0) # (N t)
return re_F, im_F
while True:
re_F, im_F = project_fourier()
err = (re_F - target).pow(2) + im_F.pow(2)
statistic = B * err @ weights # (N)
loss = statistic.mean()
loss.backward()
# Plot for the basis vector projections.
with torch.no_grad():
basis_re_F, basis_im_F = project_fourier(random=False)
plot(
t.numpy(),
X.detach().numpy(),
basis_re_F.numpy(),
basis_im_F.numpy(),
target.numpy(),
)
print("STEP:", step)
print(f"Mean: {X.mean().item()}, Std: {X.std().item()}")
print("Loss:", loss.item())
print("Grad magnitude:", X.grad.pow(2).sum().sqrt().item())
# Gradient step
X = torch.nn.Parameter(X.detach() - X.grad * 0.5)
step += 1
time.sleep(0.1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment