Skip to content

Instantly share code, notes, and snippets.

@lastforkbender
Created May 16, 2026 19:02
Show Gist options
  • Select an option

  • Save lastforkbender/cf1887c72db65edb8019bb692f8de219 to your computer and use it in GitHub Desktop.

Select an option

Save lastforkbender/cf1887c72db65edb8019bb692f8de219 to your computer and use it in GitHub Desktop.
NN with advanced IPI/IPE residual compression reasoning
# nn_ipe4.py
import math
import time
from typing import Tuple, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
# ---------------------------
# Utilities
# ---------------------------
def batched_topk_perm_fallback(group_idx: torch.Tensor, scores: torch.Tensor):
    """
    Build a permutation that lists items group by group, best score first inside each group.

    Two stable sorts are chained: items are first ordered by descending score, then
    stably re-ordered by group id, so the score order survives inside every group.

    Returns an int64 index tensor of length ``scores.size(0)`` (empty when there are no items).
    """
    if scores.size(0) == 0:
        return torch.empty(0, dtype=torch.long, device=scores.device)
    # Pass 1: global ranking by score, highest first.
    by_score = torch.argsort(scores, descending=True, stable=True)
    # Pass 2: stable sort on group id keeps the pass-1 ranking within each group.
    by_group = torch.argsort(group_idx[by_score], stable=True)
    return by_score[by_group]
def batched_topk_mask_hard(group_idx: torch.Tensor, scores: torch.Tensor, topk: int) -> torch.Tensor:
    """
    Mark, for every group, the ``topk`` items with the highest scores.

    Returns a boolean tensor aligned with ``scores``: True where the item's rank
    inside its own group (0 = best) is below ``topk``.
    """
    device = scores.device
    n_items = scores.size(0)
    if n_items == 0:
        return torch.zeros(0, dtype=torch.bool, device=device)
    n_groups = int(group_idx.max().item()) + 1
    # Items laid out group by group, best score first within each group.
    order = batched_topk_perm_fallback(group_idx, scores)
    sorted_groups = group_idx[order]
    sizes = torch.bincount(sorted_groups, minlength=n_groups)
    # offsets[g] is the position of group g's first item in the sorted layout.
    offsets = torch.cat(
        [torch.zeros(1, dtype=torch.long, device=device), torch.cumsum(sizes, dim=0)]
    )
    rank_in_group = torch.arange(n_items, device=device) - offsets[sorted_groups]
    selected = torch.zeros(n_items, dtype=torch.bool, device=device)
    # Scatter the decision back to the original item order.
    selected[order] = rank_in_group < topk
    return selected
def _per_group_max(group_idx: torch.Tensor, scores: torch.Tensor, num_groups: int) -> torch.Tensor:
    """Return a ``(num_groups,)`` tensor with the largest score found in each group."""
    dtype = scores.dtype
    # Seed with the lowest finite value of the score dtype rather than a fixed -1e9:
    # a fixed sentinel wins the max when every score in a group lies below it, and
    # -1e9 is outside the finite range of half precision.
    floor = torch.finfo(dtype).min if dtype.is_floating_point else torch.iinfo(dtype).min
    if hasattr(torch, 'scatter_reduce'):
        seed = torch.full((num_groups,), floor, device=scores.device, dtype=dtype)
        return seed.scatter_reduce(0, group_idx, scores, reduce='amax', include_self=True)
    # Fallback for torch builds without scatter_reduce: one masked max per group.
    maxima = []
    for g in range(num_groups):
        members = group_idx == g
        if members.any():
            maxima.append(scores[members].max())
        else:
            maxima.append(torch.tensor(floor, device=scores.device, dtype=dtype))
    return torch.stack(maxima)


def soft_group_attention(group_idx: torch.Tensor, scores: torch.Tensor, topk: int, temp: float = 1.0, hard_at_eval: bool = True):
    """
    Per-item attention weights, normalised inside each group.

    Soft path (default): a temperature-scaled softmax over the items that share a
    group id, so the weights of every non-empty group sum to 1. ``topk`` is not used here.

    Hard path: when gradient tracking is disabled and ``hard_at_eval`` is True, the
    0/1 top-k mask from ``batched_topk_mask_hard`` is returned instead, cast to the
    dtype of ``scores``.

    Args:
        group_idx: (N,) integer group id of every item.
        scores: (N,) score of every item.
        topk: number of items kept per group on the hard path.
        temp: softmax temperature; values <= 0 are replaced by 1e-6.
        hard_at_eval: enable the hard path when gradients are disabled.

    Returns:
        (N,) tensor of weights; an empty tensor when there are no items.
    """
    device = scores.device
    if scores.size(0) == 0:
        return torch.zeros(0, dtype=scores.dtype, device=device)
    if hard_at_eval and not torch.is_grad_enabled():
        return batched_topk_mask_hard(group_idx, scores, topk).to(scores.dtype)
    num_groups = int(group_idx.max().item()) + 1
    # Subtract each group's max before exponentiating so the largest logit is exactly 0.
    max_per_group = _per_group_max(group_idx, scores, num_groups)
    safe_temp = temp if temp > 0.0 else 1e-6
    exp_logits = ((scores - max_per_group[group_idx]) / safe_temp).exp()
    sum_exp = torch.zeros(num_groups, device=device, dtype=exp_logits.dtype).index_add_(0, group_idx, exp_logits)
    return exp_logits / (sum_exp[group_idx] + 1e-12)
# ---------------------------
# Interval Probability Evolver (differentiable prediction + state buffer)
# ---------------------------
class IntervalProbabilityEvolver(nn.Module):
    """
    Keeps one running "IPI" value per group id in a non-trainable buffer.

    ``predict`` blends the stored value of each group with the average attention
    observed for that group in the current batch; the result is differentiable with
    respect to the observed attention. ``commit`` writes a detached prediction back
    into the buffer so the next step starts from it without keeping the old graph.
    """

    def __init__(self, max_groups: int = 4096, momentum: float = 0.99, eps: float = 1e-6):
        super().__init__()
        self.max_groups = int(max_groups)
        self.momentum = float(momentum)
        self.eps = float(eps)
        # Registered as a buffer (not a parameter) so it follows model.to(device)
        # and is never touched by the optimizer. Starts at ~1 / max_groups.
        init = torch.full((self.max_groups,), 1.0 / (self.max_groups + 1e-12))
        self.register_buffer('buffer_ipi', init)

    def predict(self, group_idx: torch.Tensor, observed_attn: torch.Tensor) -> torch.Tensor:
        """
        Predict the IPI of group ids ``0 .. max(group_idx)``.

        For groups that have at least one item in the batch:
            ipi_pred = momentum * stored + (1 - momentum) * mean(observed_attn of the group)
        Groups with no item in the batch keep their stored value: a missing
        observation is not treated as an observation of 0.

        Stored values are constants in the graph; gradients flow only through
        ``observed_attn``. Group ids beyond ``max_groups`` use ``eps`` as stored value.

        Returns:
            Tensor of shape (G,), G = max(group_idx) + 1, clamped to at least ``eps``.
            An empty tensor when ``group_idx`` is empty.
        """
        device = self.buffer_ipi.device
        if group_idx.numel() == 0:
            return torch.zeros(0, device=device, dtype=self.buffer_ipi.dtype)
        G_batch = int(group_idx.max().item()) + 1
        obs_sum = torch.zeros(G_batch, device=device, dtype=observed_attn.dtype).index_add_(0, group_idx, observed_attn)
        counts = torch.zeros(G_batch, device=device, dtype=obs_sum.dtype).index_add_(0, group_idx, torch.ones_like(observed_attn))
        present = counts > 0
        # Clamp the divisor instead of adding a tiny epsilon: an epsilon can underflow
        # in low precision and turn absent groups into 0 / 0.
        obs_avg = obs_sum / counts.clamp(min=1)
        # Stored state for the first G_batch ids; pad with eps if the buffer is shorter.
        stored = self.buffer_ipi[:G_batch]
        if stored.shape[0] < G_batch:
            pad = torch.full((G_batch - stored.shape[0],), self.eps, device=device, dtype=stored.dtype)
            stored = torch.cat([stored, pad], dim=0)
        blended = self.momentum * stored + (1.0 - self.momentum) * obs_avg
        ipi_pred = torch.where(present, blended, stored.to(blended.dtype))
        # Keep numerically safe.
        return ipi_pred.clamp(min=self.eps)

    @torch.no_grad()
    def commit(self, ipi_pred: torch.Tensor):
        """
        Store a detached copy of ``ipi_pred`` in the buffer for the next step.

        Only the first ``max_groups`` entries are written; an empty prediction is ignored.
        """
        if ipi_pred.numel() == 0:
            return
        G_upd = min(ipi_pred.shape[0], self.max_groups)
        self.buffer_ipi[:G_upd] = ipi_pred[:G_upd].detach().to(self.buffer_ipi.device)

    def get_buffer_for_groups(self, group_idx: torch.Tensor):
        """Return the stored values for ids ``0 .. max(group_idx)``, capped at ``max_groups`` entries."""
        if group_idx.numel() == 0:
            return torch.zeros(0, device=self.buffer_ipi.device, dtype=self.buffer_ipi.dtype)
        G = min(int(group_idx.max().item()) + 1, self.max_groups)
        return self.buffer_ipi[:G]

    @torch.no_grad()
    def reset_(self, value: Optional[float] = None):
        """Fill the whole buffer in place with ``value``, or with the initial ~1 / max_groups when None."""
        if value is None:
            self.buffer_ipi.fill_(1.0 / (self.max_groups + 1e-12))
        else:
            self.buffer_ipi.fill_(float(value))
# ---------------------------
# Linear / Complex building blocks
# ---------------------------
class CayleyUnitaryComplex(nn.Module):
    """
    Learnable complex ``dim x dim`` matrix built with a Cayley-style transform.

    The two real parameters form a complex matrix A; K = A - A^H is skew-Hermitian
    by construction, and ``forward`` returns the solution U of
    ``(I - K - eps * I) U = I + K``. Because of the ``eps`` shift on the left-hand
    side, U is close to, but not exactly, unitary.
    """

    def __init__(self, dim, dtype=torch.complex64, eps=1e-5):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self._dtype = dtype
        # Small random init keeps K near zero, so U starts near the identity.
        self.A_real = nn.Parameter(torch.randn(dim, dim) * 0.01)
        self.A_imag = nn.Parameter(torch.randn(dim, dim) * 0.01)

    def forward(self):
        """Return the current (dim, dim) complex matrix U."""
        raw = (self.A_real + 1j * self.A_imag).to(self._dtype)
        skew = raw - raw.conj().transpose(-1, -2)
        eye = torch.eye(self.dim, dtype=skew.dtype, device=skew.device)
        lhs = eye - skew - self.eps * eye
        rhs = eye + skew
        return torch.linalg.solve(lhs, rhs)
class OrthogonalLinear(nn.Module):
    """
    Bias-free square linear map ``x -> x @ W``.

    NOTE(review): despite the name, nothing in this class constrains ``W`` to be
    orthogonal; it is a plain learnable matrix with 1/sqrt(dim)-scaled normal init.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        init_scale = 1.0 / math.sqrt(dim)
        self.W = nn.Parameter(torch.randn(dim, dim) * init_scale)

    def forward(self, x):
        """Right-multiply ``x`` by the weight matrix."""
        return torch.matmul(x, self.W)
class LowRankComplexLinear(nn.Module):
    """
    Complex linear map factored as ``(in_dim x rank) @ (rank x out_dim)``.

    Complex numbers are carried as separate real and imaginary tensors; both
    factors U = Ur + i*Ui and V = Vr + i*Vi are learnable.
    """

    def __init__(self, in_dim, out_dim, rank):
        super().__init__()
        scale_in = 1.0 / math.sqrt(max(1, in_dim))
        scale_rank = 1.0 / math.sqrt(max(1, rank))
        self.Ur = nn.Parameter(torch.randn(in_dim, rank) * scale_in)
        self.Ui = nn.Parameter(torch.randn(in_dim, rank) * scale_in)
        self.Vr = nn.Parameter(torch.randn(rank, out_dim) * scale_rank)
        self.Vi = nn.Parameter(torch.randn(rank, out_dim) * scale_rank)

    @staticmethod
    def _complex_matmul(a_r, a_i, b_r, b_i):
        """(a_r + i*a_i) @ (b_r + i*b_i), returned as a (real, imag) pair."""
        real = a_r @ b_r - a_i @ b_i
        imag = a_r @ b_i + a_i @ b_r
        return real, imag

    def forward(self, x_r, x_i):
        """Apply U then V to the complex input ``x_r + i*x_i``; returns (real, imag)."""
        mid_r, mid_i = self._complex_matmul(x_r, x_i, self.Ur, self.Ui)
        return self._complex_matmul(mid_r, mid_i, self.Vr, self.Vi)
# ---------------------------
# Model with differentiable IPI + gradient modulation hooks
# ---------------------------
class CompressedResidualNetV3(nn.Module):
    """
    Group-level classifier that fuses an attention-pooled encoding with a pairwise residual term.

    Each forward pass:
      1. encodes every occurrence and scores it with ``score_head``;
      2. turns the scores into per-item weights (``soft_group_attention``) and pools
         the encodings into one vector per group;
      3. builds a second per-group vector from a complex-valued projection of the raw
         inputs, using the pairwise term ``0.5 * (sum^2 - sum of squares)``, passed
         through either the Cayley matrix or ``OrthogonalLinear`` and then the
         low-rank complex projector (only the real output is used);
      4. when IPI is enabled, scales that residual per group by the predicted IPI,
         refreshes the scalar read by the gradient hooks, and commits the prediction
         to the IPI buffer;
      5. adds both vectors, applies ``fusion`` and classifies each group.
    """

    def __init__(
        self,
        feat_dim=128,
        bottleneck=64,
        residual_dim=16,
        lowrank=8,
        topk=8,
        n_classes=10,
        use_cayley=False,
        enable_ipi: bool = True,
        ipi_max_groups: int = 4096,
        ipi_momentum: float = 0.99,
        grad_modulation_strength: float = 0.5,  # strength of gradient modulation from IPI
    ):
        super().__init__()
        self.feat_dim = feat_dim
        self.bottleneck = bottleneck
        self.residual_dim = residual_dim
        self.topk = topk
        self.use_cayley = use_cayley
        self.enable_ipi = bool(enable_ipi)
        self.grad_mod_strength = float(grad_modulation_strength)
        self._grad_scale = 1.0  # runtime scalar used in hooks; updated each forward
        # encoder + normalization
        self.encoder = nn.Sequential(
            nn.Linear(feat_dim, bottleneck),
            nn.ReLU(),
            nn.Linear(bottleneck, bottleneck),
        )
        self.encoder_ln = nn.LayerNorm(bottleneck)
        self.score_head = nn.Linear(bottleneck, 1)
        # residual projection (real and imaginary parts of a complex projection)
        self.B_real = nn.Parameter(torch.randn(feat_dim, residual_dim) * 0.01)
        self.B_imag = nn.Parameter(torch.randn(feat_dim, residual_dim) * 0.01)
        self.alpha_ln = nn.LayerNorm(residual_dim)
        self.alpha_mlp = nn.Sequential(
            nn.Linear(residual_dim, residual_dim),
            nn.ReLU(),
            nn.Linear(residual_dim, residual_dim),
        )
        # lowrank complex projector
        self.lowrank = LowRankComplexLinear(in_dim=residual_dim, out_dim=bottleneck, rank=lowrank)
        # exactly one of the two transforms below is instantiated
        self.unitary = CayleyUnitaryComplex(residual_dim) if use_cayley else None
        self.orth_proj = OrthogonalLinear(residual_dim) if not use_cayley else None
        # learned gate to scale residual contribution (global); sigmoid(0) = 0.5 at init
        self.res_gate = nn.Parameter(torch.tensor(0.0))
        # fusion and classifier
        self.fusion = nn.Linear(bottleneck, bottleneck)
        self.classifier = nn.Sequential(
            nn.Linear(bottleneck, bottleneck),
            nn.ReLU(),
            nn.Linear(bottleneck, n_classes),
        )
        # IPI module
        self.ipi = IntervalProbabilityEvolver(max_groups=ipi_max_groups, momentum=ipi_momentum) if self.enable_ipi else None
        # register gradient modulation hooks (one-time)
        self._hooks_registered = False
        self._register_grad_hooks()

    def _register_grad_hooks(self):
        """
        Attach one gradient hook to every trainable parameter (idempotent).

        The hook multiplies the gradient by ``1 + strength * (self._grad_scale - 1)``.
        ``_grad_scale`` is read when the hook fires, i.e. during backward, so it
        reflects the value left by the most recent forward pass.
        """
        if self._hooks_registered:
            return

        def modulate(grad):
            if grad is None:
                return None
            scale = float(self._grad_scale)
            return grad * (1.0 + self.grad_mod_strength * (scale - 1.0))

        # The hook does not depend on the parameter, so a single closure is shared.
        for p in self.parameters():
            if p.requires_grad:
                p.register_hook(modulate)
        self._hooks_registered = True

    def compute_residual(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Project ``x`` with ``B_real`` and ``B_imag``; returns the (real, imag) pair."""
        return x @ self.B_real, x @ self.B_imag

    def _residual_branch(self, occurrences: torch.Tensor, group_idx: torch.Tensor, attn_f: torch.Tensor, num_groups: int) -> torch.Tensor:
        """Return the real part of the per-group residual vector fed into the fusion layer."""
        device = occurrences.device
        r_r_all, r_i_all = self.compute_residual(occurrences)
        # Magnitude of the complex residual; the phase is kept by dividing by it,
        # and a learned magnitude (alpha) is applied instead.
        a_mag = torch.sqrt(r_r_all * r_r_all + r_i_all * r_i_all + 1e-9)
        alpha = self.alpha_mlp(self.alpha_ln(a_mag))
        a_r = alpha * (r_r_all / (a_mag + 1e-9)) * attn_f
        a_i = alpha * (r_i_all / (a_mag + 1e-9)) * attn_f

        def group_sum(values):
            acc = torch.zeros((num_groups, self.residual_dim), device=device, dtype=values.dtype)
            return acc.index_add_(0, group_idx, values)

        sum_a_r = group_sum(a_r)
        sum_a_i = group_sum(a_i)
        # 0.5 * ((sum a)^2 - sum a^2) equals the sum over unordered item pairs of a_j * a_k.
        pair_r = 0.5 * (sum_a_r * sum_a_r - group_sum(a_r * a_r))
        pair_i = 0.5 * (sum_a_i * sum_a_i - group_sum(a_i * a_i))
        if self.use_cayley and self.unitary is not None:
            pair_trans = (pair_r + 1j * pair_i) @ self.unitary()
            pair_r = pair_trans.real
            pair_i = pair_trans.imag
        elif self.orth_proj is not None:
            pair_r = self.orth_proj(pair_r)
            pair_i = self.orth_proj(pair_i)
        out_r, _ = self.lowrank(pair_r, pair_i)
        return out_r

    def forward(self, occurrences: torch.Tensor, group_idx: torch.Tensor, temp: float = 1.0):
        """
        occurrences: (N, feat_dim)
        group_idx: (N,)
        returns:
        - if enable_ipi: (logits (G, n_classes), ipi_pred (G,)) where ipi_pred is the differentiable predicted IPI
        - else: logits (G, n_classes)
        Behavior:
        - predicted IPI is differentiable (depends on observed attention which depends on model outputs)
        - internal buffer is updated with detached ipi_pred for the next step (commit)
        - a global grad-scale scalar is computed from the detached mean of ipi_pred and stored;
          the registered hooks read it to modulate gradients during backward
        - with no occurrences, empty tensors are returned and no state is touched
        """
        device = occurrences.device
        if occurrences.size(0) == 0:
            empty_logits = torch.empty((0, self.classifier[-1].out_features), device=device)
            if self.enable_ipi:
                return empty_logits, torch.empty((0,), device=device)
            return empty_logits
        G = int(group_idx.max().item()) + 1
        z = self.encoder_ln(self.encoder(occurrences))
        scores = self.score_head(z).squeeze(-1)
        # differentiable attention when gradients are enabled, hard top-k mask otherwise
        attn_weights = soft_group_attention(group_idx, scores, self.topk, temp=temp, hard_at_eval=True)  # (N,)
        # index_add_ needs source and destination dtypes to match, so align the
        # weights with z before pooling (a no-op when they already agree).
        attn_f = attn_weights.to(z.dtype).unsqueeze(-1)
        # compressed vectors: attention-weighted mean of the encodings per group
        sum_weighted = torch.zeros((G, z.size(1)), device=device, dtype=z.dtype).index_add_(0, group_idx, z * attn_f)
        sum_weights = torch.zeros(G, device=device, dtype=z.dtype).index_add_(0, group_idx, attn_f.squeeze(-1))
        comp_vecs = sum_weighted / sum_weights.clamp(min=1e-6).unsqueeze(-1)
        res_real = self._residual_branch(occurrences, group_idx, attn_f, G)
        gate = torch.sigmoid(self.res_gate)
        if self.enable_ipi and self.ipi is not None:
            # differentiable per-group IPI, shape (G,), used to scale the residual
            ipi_pred = self.ipi.predict(group_idx, attn_weights)
            gated_res = gate * (res_real * ipi_pred.unsqueeze(-1).to(res_real.dtype))
            # The hooks only ever see a plain Python float, so no graph is kept.
            mean_ipi = float(ipi_pred.detach().mean().item()) if ipi_pred.numel() > 0 else 0.0
            self._grad_scale = 1.0 + self.grad_mod_strength * mean_ipi
            # NOTE: this also runs in eval mode / under no_grad, so evaluation
            # batches update the buffer as well.
            self.ipi.commit(ipi_pred)
        else:
            gated_res = gate * res_real
            ipi_pred = None
            self._grad_scale = 1.0
        logits = self.classifier(self.fusion(comp_vecs + gated_res))
        if self.enable_ipi:
            return logits, ipi_pred
        return logits

    def compute_ipi_loss(self, group_idx: torch.Tensor, attn_weights: torch.Tensor, reduction: str = 'mean'):
        """
        Auxiliary MSE between the predicted IPI (recomputed from the current buffer)
        and the observed average per-group attention.

        Returns a zero scalar when IPI is disabled. For empty input it returns a zero
        scalar, or an empty tensor when ``reduction='none'``.
        """
        device = attn_weights.device
        if not self.enable_ipi or self.ipi is None:
            return torch.tensor(0.0, device=device)
        if group_idx.numel() == 0:
            # max() of an empty tensor raises, so short-circuit here.
            if reduction == 'none':
                return torch.zeros(0, device=device, dtype=attn_weights.dtype)
            return torch.tensor(0.0, device=device)
        G = int(group_idx.max().item()) + 1
        obs_sum = torch.zeros(G, device=device, dtype=attn_weights.dtype).index_add_(0, group_idx, attn_weights)
        counts = torch.zeros(G, device=device, dtype=obs_sum.dtype).index_add_(0, group_idx, torch.ones_like(attn_weights))
        obs_avg = obs_sum / (counts + 1e-12)
        ipi_pred = self.ipi.predict(group_idx, attn_weights)  # shape (G,)
        return F.mse_loss(ipi_pred, obs_avg, reduction=reduction)
# ---------------------------
# Synthetic dataset and training helpers
# ---------------------------
def make_synthetic_batch(G=128, occ_per_group=16, feat_dim=128, n_classes=10, device=None):
    """
    Generate one random batch of grouped occurrences with group-level labels.

    Every group has a random base vector; its occurrences are that vector plus
    Gaussian noise (std 0.1). The label of a group is the argmax of its base vector
    under a random linear map that is drawn anew on every call.

    Args:
        G: number of groups.
        occ_per_group: occurrences generated per group.
        feat_dim: feature dimension of every occurrence.
        n_classes: number of target classes.
        device: target device; ``None`` selects CUDA when available, else CPU, so
            the default also works on machines without a GPU.

    Returns:
        occurrences (G * occ_per_group, feat_dim), group_idx (G * occ_per_group,), targets (G,).
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    N = G * occ_per_group
    base = torch.randn(G, feat_dim, device=device)
    occurrences = base.repeat_interleave(occ_per_group, dim=0) + 0.1 * torch.randn(N, feat_dim, device=device)
    group_idx = torch.arange(G, device=device).repeat_interleave(occ_per_group)
    weight = torch.randn(feat_dim, n_classes, device=device)
    targets = (base @ weight).argmax(dim=1)
    return occurrences, group_idx, targets
def train_epoch(model, optim, device, scaler: Optional[torch.cuda.amp.GradScaler] = None, batches=50, G=128, occ_per_group=16, ipi_loss_scale: float = 0.0):
    """
    Train ``model`` for ``batches`` freshly generated synthetic batches.

    The attention temperature decays linearly with the batch index
    (``max(0.1, 1 - 0.001 * b)``). When ``ipi_loss_scale`` is positive and the model
    returns an IPI prediction, the attention is recomputed with a second encoder pass
    and ``ipi_loss_scale * model.compute_ipi_loss(...)`` is added to the loss.
    Mixed precision is used exactly when a ``scaler`` is supplied.

    Returns:
        (mean loss per group, accuracy, elapsed seconds). With ``batches == 0`` the
        loss and accuracy are reported as 0.0 instead of dividing by zero.
    """
    model.train()
    total_loss = 0.0
    total_correct = 0
    total_examples = 0
    t0 = time.time()
    use_amp = scaler is not None
    for b in range(batches):
        # One schedule value shared by the forward pass and the auxiliary loss.
        temp = max(0.1, 1.0 - 0.001 * b)
        occ, grp, targets = make_synthetic_batch(G=G, occ_per_group=occ_per_group, device=device)
        with torch.cuda.amp.autocast(enabled=use_amp):
            out = model(occ, grp, temp=temp)
            if isinstance(out, tuple):
                logits, ipi_pred = out
            else:
                logits, ipi_pred = out, None
            loss = F.cross_entropy(logits, targets)
            if ipi_loss_scale > 0.0 and ipi_pred is not None:
                # forward() does not return the attention weights, so recompute them
                # (differentiably) to build the IPI coupling loss.
                z = model.encoder_ln(model.encoder(occ))
                scores = model.score_head(z).squeeze(-1)
                attn = soft_group_attention(grp, scores, model.topk, temp=temp, hard_at_eval=False)
                loss = loss + ipi_loss_scale * model.compute_ipi_loss(grp, attn, reduction='mean')
        optim.zero_grad(set_to_none=True)
        if use_amp:
            scaler.scale(loss).backward()
            scaler.step(optim)
            scaler.update()
        else:
            loss.backward()
            optim.step()
        total_loss += loss.item() * G
        total_correct += (logits.argmax(dim=1) == targets).sum().item()
        total_examples += G
    elapsed = time.time() - t0
    if total_examples == 0:
        # Nothing was trained; avoid 0 / 0.
        return 0.0, 0.0, elapsed
    return total_loss / total_examples, total_correct / total_examples, elapsed
def run_training():
    """
    Demo driver: train a ``CompressedResidualNetV3`` on synthetic data for five
    epochs, print per-epoch metrics, evaluate on one fresh batch and return the model.

    Runs on CUDA with a gradient scaler when a GPU is available, otherwise on CPU
    without mixed precision.
    """
    has_cuda = torch.cuda.is_available()
    device = 'cuda' if has_cuda else 'cpu'
    config = dict(
        feat_dim=128,
        bottleneck=64,
        residual_dim=16,
        lowrank=8,
        topk=8,
        n_classes=10,
        use_cayley=False,
        enable_ipi=True,
        ipi_max_groups=1024,
        ipi_momentum=0.98,
        grad_modulation_strength=0.5,
    )
    model = CompressedResidualNetV3(**config).to(device)
    optim = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
    scaler = torch.cuda.amp.GradScaler() if has_cuda else None

    epochs = 5
    for ep in range(epochs):
        loss, acc, sec = train_epoch(
            model,
            optim,
            device,
            scaler=scaler,
            batches=30,
            G=256,
            occ_per_group=8,
            ipi_loss_scale=0.05,
        )
        print(f"Epoch {ep+1}/{epochs} — loss: {loss:.4f}, acc: {acc:.4f}, epoch_sec: {sec:.2f}")

    # Single-batch evaluation; the model returns a tuple when IPI is enabled.
    model.eval()
    with torch.no_grad():
        occ, grp, targets = make_synthetic_batch(G=256, occ_per_group=8, device=device)
        out = model(occ, grp)
        logits = out[0] if isinstance(out, tuple) else out
        acc = (logits.argmax(dim=1) == targets).float().mean().item()
    print("Eval acc:", acc)
    return model
# Script entry point: run the demo training loop only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    run_training()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment