Created
May 16, 2026 19:02
-
-
Save lastforkbender/cf1887c72db65edb8019bb692f8de219 to your computer and use it in GitHub Desktop.
NN with advanced IPI/IPE residual compression reasoning
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # nn_ipe4.py | |
| import math | |
| import time | |
| from typing import Tuple, Optional | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| # --------------------------- | |
| # Utilities | |
| # --------------------------- | |
def batched_topk_perm_fallback(group_idx: torch.Tensor, scores: torch.Tensor):
    """Return a permutation ordering items by group id, then by descending score.

    Two stable argsorts are composed. The first orders items by score (descending).
    The second orders them by group id. Because the second sort is stable, items
    inside each group keep the descending-score order set by the first sort.

    Args:
        group_idx: integer group id per item, shape (N,).
        scores: score per item, shape (N,).

    Returns:
        LongTensor ``perm`` of shape (N,). ``group_idx[perm]`` is non-decreasing, and
        within each group ``scores[perm]`` is non-increasing.
    """
    if scores.size(0) == 0:
        return torch.empty(0, dtype=torch.long, device=scores.device)
    by_score = torch.argsort(scores, descending=True, stable=True)
    regroup = torch.argsort(group_idx[by_score], stable=True)
    return by_score[regroup]
def batched_topk_mask_hard(group_idx: torch.Tensor, scores: torch.Tensor, topk: int) -> torch.Tensor:
    """Boolean mask selecting the ``topk`` highest-scoring items of every group.

    Items are sorted by (group, descending score). Each item's rank inside its group
    is its sorted position minus the group's start offset. Items with rank < ``topk``
    are kept. Groups with fewer than ``topk`` items keep all of their items.

    Args:
        group_idx: integer group id per item, shape (N,).
        scores: score per item, shape (N,).
        topk: maximum number of items kept per group.

    Returns:
        Bool tensor of shape (N,), in the original item order.
    """
    dev = scores.device
    n_items = scores.size(0)
    if n_items == 0:
        return torch.zeros(0, dtype=torch.bool, device=dev)
    n_groups = int(group_idx.max().item()) + 1
    order = batched_topk_perm_fallback(group_idx, scores)
    sorted_groups = group_idx[order]
    # Exclusive prefix sum of group sizes = start offset of each group in sorted order.
    offsets = torch.zeros(n_groups + 1, dtype=torch.long, device=dev)
    offsets[1:] = torch.bincount(sorted_groups, minlength=n_groups).cumsum(0)
    rank_in_group = torch.arange(n_items, device=dev) - offsets[sorted_groups]
    # Scatter the sorted-order decision back to the original item positions.
    keep = torch.zeros(n_items, dtype=torch.bool, device=dev)
    keep[order] = rank_in_group < topk
    return keep
| def soft_group_attention(group_idx: torch.Tensor, scores: torch.Tensor, topk: int, temp: float = 1.0, hard_at_eval: bool = True): | |
| """ | |
| Differentiable soft attention per item. Hard topk mask returned at eval if hard_at_eval True. | |
| """ | |
| device = scores.device | |
| N = scores.size(0) | |
| if N == 0: | |
| return torch.zeros(0, dtype=scores.dtype, device=device) | |
| if (not torch.is_grad_enabled()) and hard_at_eval: | |
| return batched_topk_mask_hard(group_idx, scores, topk).to(scores.dtype) | |
| G = int(group_idx.max().item()) + 1 | |
| # per-group max | |
| if hasattr(torch, 'scatter_reduce'): | |
| max_per_group = torch.full((G,), -1e9, device=device, dtype=scores.dtype) | |
| try: | |
| max_per_group = max_per_group.scatter_reduce(0, group_idx, scores, reduce='amax', include_self=True) | |
| except Exception: | |
| max_per_group = torch.full((G,), -1e9, device=device, dtype=scores.dtype) | |
| max_per_group = torch.scatter_reduce(max_per_group, 0, group_idx, scores, reduce='amax', include_self=True) | |
| else: | |
| max_vals = [] | |
| for g in range(G): | |
| mask = group_idx == g | |
| if mask.any(): | |
| max_vals.append(scores[mask].max()) | |
| else: | |
| max_vals.append(torch.tensor(-1e9, device=device, dtype=scores.dtype)) | |
| max_per_group = torch.stack(max_vals) | |
| shifted = scores - max_per_group[group_idx] | |
| logits = shifted / (temp if temp > 0.0 else 1e-6) | |
| exp_logits = logits.exp() | |
| sum_exp = torch.zeros(G, device=device, dtype=exp_logits.dtype).index_add_(0, group_idx, exp_logits) | |
| attn = exp_logits / (sum_exp[group_idx] + 1e-12) | |
| return attn | |
| # --------------------------- | |
| # Interval Probability Evolver (differentiable prediction + state buffer) | |
| # --------------------------- | |
class IntervalProbabilityEvolver(nn.Module):
    """
    Stateful per-group "IPI" tracker.

    Keeps a buffer ``buffer_ipi`` of length ``max_groups`` that holds the last committed
    value for each group id. For the groups in a batch it produces the differentiable
    prediction

        ipi_pred[g] = momentum * buffer_ipi[g] + (1 - momentum) * mean_attention[g]

    ``predict`` never mutates the buffer. ``commit`` writes a detached prediction back,
    so the recurrence carries state between steps without keeping the previous step's
    autograd graph alive.
    """

    def __init__(self, max_groups: int = 4096, momentum: float = 0.99, eps: float = 1e-6):
        super().__init__()
        self.max_groups = int(max_groups)
        self.momentum = float(momentum)
        self.eps = float(eps)
        # Uniform initial state. Registered as a buffer (not a parameter), so it moves
        # with model.to(device) and is saved in state_dict but never receives gradients.
        init = torch.full((self.max_groups,), 1.0 / (self.max_groups + 1e-12))
        self.register_buffer('buffer_ipi', init)

    def predict(self, group_idx: torch.Tensor, observed_attn: torch.Tensor) -> torch.Tensor:
        """
        Differentiable per-group IPI prediction.

        Args:
            group_idx: integer group id per item, shape (N,).
            observed_attn: attention weight per item, shape (N,). Gradients flow through
                it into the returned prediction. The stored buffer is a constant.

        Returns:
            Tensor of shape (G,), with G = group_idx.max() + 1, clamped to at least
            ``eps``. Details:
            - Group ids absent from the batch get an observed mean of 0.
            - Ids at or beyond ``max_groups`` use ``eps`` in place of a stored value.
            - Empty input returns an empty tensor on the buffer's device.
        """
        if group_idx.numel() == 0:
            return torch.zeros(0, device=self.buffer_ipi.device, dtype=self.buffer_ipi.dtype)
        G_batch = int(group_idx.max().item()) + 1
        device = observed_attn.device
        # Accumulate in a dtype at least as wide as the buffer. With float16 attention
        # (e.g. under autocast), a ``counts + 1e-12`` guard rounds to ``counts + 0``, so
        # group ids missing from the batch would give 0/0 = NaN. commit() would then
        # persist that NaN in the buffer.
        dtype = torch.promote_types(self.buffer_ipi.dtype, observed_attn.dtype)
        obs = observed_attn.to(dtype)
        obs_sum = torch.zeros(G_batch, device=device, dtype=dtype).index_add_(0, group_idx, obs)
        counts = torch.zeros(G_batch, device=device, dtype=dtype).index_add_(0, group_idx, torch.ones_like(obs))
        # Present groups have count >= 1. Absent groups have sum 0, so their mean is 0.
        obs_avg = obs_sum / counts.clamp(min=1.0)
        # Stored values for the first G_batch groups. Pad with eps past max_groups.
        stored = self.buffer_ipi[:G_batch].to(device=device, dtype=dtype)
        if stored.shape[0] < G_batch:
            pad = torch.full((G_batch - stored.shape[0],), self.eps, device=device, dtype=dtype)
            stored = torch.cat([stored, pad], dim=0)
        ipi_pred = self.momentum * stored + (1.0 - self.momentum) * obs_avg
        # Keep strictly positive for numerical safety downstream.
        return ipi_pred.clamp(min=self.eps)

    @torch.no_grad()
    def commit(self, ipi_pred: torch.Tensor) -> None:
        """
        Write the detached prediction into the buffer for the next step.

        Only the first ``min(len(ipi_pred), max_groups)`` entries are stored. Empty
        input is a no-op.
        """
        if ipi_pred.numel() == 0:
            return
        G_upd = min(ipi_pred.shape[0], self.max_groups)
        self.buffer_ipi[:G_upd] = ipi_pred[:G_upd].detach().to(self.buffer_ipi.device)

    def get_buffer_for_groups(self, group_idx: torch.Tensor):
        """Return buffer entries for ids 0..max(group_idx), truncated at ``max_groups``."""
        if group_idx.numel() == 0:
            return torch.zeros(0, device=self.buffer_ipi.device, dtype=self.buffer_ipi.dtype)
        G = int(group_idx.max().item()) + 1
        G = min(G, self.max_groups)
        return self.buffer_ipi[:G]

    @torch.no_grad()
    def reset_(self, value: Optional[float] = None):
        """Refill the buffer with ``value``, or with the uniform initial value if None."""
        if value is None:
            self.buffer_ipi.fill_(1.0 / (self.max_groups + 1e-12))
        else:
            self.buffer_ipi.fill_(float(value))
| # --------------------------- | |
| # Linear / Complex building blocks | |
| # --------------------------- | |
class CayleyUnitaryComplex(nn.Module):
    """Learnable complex matrix from a (diagonally shifted) Cayley transform.

    A learnable complex matrix ``A = A_real + i * A_imag`` is made skew-Hermitian as
    ``K = A - A^H``. ``forward()`` returns ``U = (I - K - eps * I)^{-1} (I + K)``,
    obtained with ``torch.linalg.solve``. With eps = 0 this is the Cayley transform,
    which is unitary for skew-Hermitian K. The ``eps`` shift perturbs the diagonal,
    so U is only approximately unitary.
    """

    def __init__(self, dim, dtype=torch.complex64, eps=1e-5):
        super().__init__()
        self.dim = dim
        # Small init keeps K near zero, so U starts near the identity.
        self.A_real = nn.Parameter(torch.randn(dim, dim) * 0.01)
        self.A_imag = nn.Parameter(torch.randn(dim, dim) * 0.01)
        self._dtype = dtype
        self.eps = eps

    def forward(self):
        raw = (self.A_real + 1j * self.A_imag).to(self._dtype)
        skew = raw - raw.mH
        eye = torch.eye(self.dim, dtype=skew.dtype, device=skew.device)
        lhs = eye - skew - self.eps * eye
        rhs = eye + skew
        return torch.linalg.solve(lhs, rhs)
class OrthogonalLinear(nn.Module):
    """Square linear map ``x @ W`` whose weight is initialised to an orthogonal matrix.

    Orthogonality is imposed only at initialisation (``nn.init.orthogonal_``). After
    that, ``W`` is an ordinary unconstrained parameter and may drift during training.
    The parameter name and shape (``W``, (dim, dim)) are unchanged, so existing
    state_dicts still load.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.W = nn.Parameter(torch.empty(dim, dim))
        # A scaled Gaussian is only approximately orthogonal (its singular values are
        # spread out). Use an exactly orthogonal init to match the class's contract.
        nn.init.orthogonal_(self.W)

    def forward(self, x):
        return x @ self.W
class LowRankComplexLinear(nn.Module):
    """Low-rank complex linear map acting on split (real, imaginary) tensors.

    Computes ``(x_r + i x_i) @ (Ur + i Ui) @ (Vr + i Vi)`` with real arithmetic only.
    The in_dim x out_dim map is factored through a rank-``rank`` bottleneck.
    """

    def __init__(self, in_dim, out_dim, rank):
        super().__init__()
        in_scale = 1.0 / math.sqrt(max(1, in_dim))
        rank_scale = 1.0 / math.sqrt(max(1, rank))
        self.Ur = nn.Parameter(torch.randn(in_dim, rank) * in_scale)
        self.Ui = nn.Parameter(torch.randn(in_dim, rank) * in_scale)
        self.Vr = nn.Parameter(torch.randn(rank, out_dim) * rank_scale)
        self.Vi = nn.Parameter(torch.randn(rank, out_dim) * rank_scale)

    @staticmethod
    def _complex_matmul(a_re, a_im, b_re, b_im):
        """Return ``(a_re + i a_im) @ (b_re + i b_im)`` as a (real, imag) pair."""
        return a_re @ b_re - a_im @ b_im, a_re @ b_im + a_im @ b_re

    def forward(self, x_r, x_i):
        h_r, h_i = self._complex_matmul(x_r, x_i, self.Ur, self.Ui)
        return self._complex_matmul(h_r, h_i, self.Vr, self.Vi)
| # --------------------------- | |
| # Model with differentiable IPI + gradient modulation hooks | |
| # --------------------------- | |
class CompressedResidualNetV3(nn.Module):
    """Group classifier with attention pooling and an IPI-gated complex residual path.

    Inputs are items ("occurrences") tagged with a group id. For each group the model
    produces one row of class logits from two parts:

    * an attention-weighted mean of the encoded items. Attention is soft while
      gradients are enabled and a hard top-k 0/1 mask otherwise
      (see ``soft_group_attention``).
    * a residual built from pairwise products of attention-weighted residual
      projections, with real and imaginary parts handled separately. It is passed
      through an OrthogonalLinear map (or, with ``use_cayley``, a CayleyUnitaryComplex
      matrix) and a low-rank complex projector.

    With ``enable_ipi``, an IntervalProbabilityEvolver prediction additionally scales
    the residual per group. Gradient hooks on every parameter use a scalar derived
    from that prediction.
    """

    def __init__(
        self,
        feat_dim=128,
        bottleneck=64,
        residual_dim=16,
        lowrank=8,
        topk=8,
        n_classes=10,
        use_cayley=False,
        enable_ipi: bool = True,
        ipi_max_groups: int = 4096,
        ipi_momentum: float = 0.99,
        grad_modulation_strength: float = 0.5,  # strength of gradient modulation from IPI
    ):
        super().__init__()
        self.feat_dim = feat_dim
        self.bottleneck = bottleneck
        self.residual_dim = residual_dim
        self.topk = topk
        self.use_cayley = use_cayley
        self.enable_ipi = bool(enable_ipi)
        self.grad_mod_strength = float(grad_modulation_strength)
        self._grad_scale = 1.0  # runtime scalar read by the gradient hooks; set in forward
        # encoder + normalization
        self.encoder = nn.Sequential(
            nn.Linear(feat_dim, bottleneck),
            nn.ReLU(),
            nn.Linear(bottleneck, bottleneck),
        )
        self.encoder_ln = nn.LayerNorm(bottleneck)
        self.score_head = nn.Linear(bottleneck, 1)
        # residual projection (real / imaginary parts)
        self.B_real = nn.Parameter(torch.randn(feat_dim, residual_dim) * 0.01)
        self.B_imag = nn.Parameter(torch.randn(feat_dim, residual_dim) * 0.01)
        self.alpha_ln = nn.LayerNorm(residual_dim)
        self.alpha_mlp = nn.Sequential(
            nn.Linear(residual_dim, residual_dim),
            nn.ReLU(),
            nn.Linear(residual_dim, residual_dim),
        )
        # lowrank complex projector
        self.lowrank = LowRankComplexLinear(in_dim=residual_dim, out_dim=bottleneck, rank=lowrank)
        # unitary (Cayley) or orthogonal-style mixing of the pairwise residual
        self.unitary = CayleyUnitaryComplex(residual_dim) if use_cayley else None
        self.orth_proj = OrthogonalLinear(residual_dim) if not use_cayley else None
        # learned global gate on the residual contribution (sigmoid applied in forward)
        self.res_gate = nn.Parameter(torch.tensor(0.0))
        # fusion and classifier
        self.fusion = nn.Linear(bottleneck, bottleneck)
        self.classifier = nn.Sequential(
            nn.Linear(bottleneck, bottleneck),
            nn.ReLU(),
            nn.Linear(bottleneck, n_classes),
        )
        # IPI module
        self.ipi = IntervalProbabilityEvolver(max_groups=ipi_max_groups, momentum=ipi_momentum) if self.enable_ipi else None
        # register gradient modulation hooks (one-time)
        self._hooks_registered = False
        self._register_grad_hooks()

    def _register_grad_hooks(self):
        """Attach one gradient hook to every trainable parameter (idempotent).

        Each hook multiplies the incoming gradient by
        ``1 + grad_mod_strength * (self._grad_scale - 1)``. It reads ``_grad_scale`` at
        backward time, so the value set by the most recent forward pass is used.
        Parameters created after construction are not hooked.
        """
        if self._hooks_registered:
            return

        def hook(grad):
            if grad is None:
                return None
            scale = float(self._grad_scale)
            mod = 1.0 + self.grad_mod_strength * (scale - 1.0)
            return grad * mod

        for _name, p in self.named_parameters():
            if p.requires_grad:
                p.register_hook(hook)
        self._hooks_registered = True

    def compute_residual(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Project items onto the residual space; returns (real, imag) parts."""
        r_real = x @ self.B_real
        r_imag = x @ self.B_imag
        return r_real, r_imag

    def forward(self, occurrences: torch.Tensor, group_idx: torch.Tensor, temp: float = 1.0):
        """
        Args:
            occurrences: item features, used as (N, feat_dim).
            group_idx: group id per item, shape (N,).
            temp: softmax temperature for the soft attention path.

        Returns:
            ``(logits, ipi_pred)`` when ``enable_ipi`` is True, else ``logits``.
            ``logits`` has shape (G, n_classes), with G = group_idx.max() + 1.
            ``ipi_pred`` has shape (G,) and is differentiable through the attention
            weights.

        Side effects:
            * ``self._grad_scale`` is updated; the gradient hooks read it during backward.
            * The IPI buffer is committed with the detached prediction in training
              mode only. Eval-mode calls leave the recurrent state untouched, so
              running evaluation does not alter subsequent training.

        NOTE(review): with grad disabled, attention is an unnormalised 0/1 top-k mask.
        The residual path multiplies by these raw weights, so residual magnitudes
        differ between soft (training) and hard (no-grad) passes.
        """
        device = occurrences.device
        N = occurrences.size(0)
        if N == 0:
            empty_logits = torch.empty((0, self.classifier[-1].out_features), device=device)
            if self.enable_ipi:
                return empty_logits, torch.empty((0,), device=device)
            return empty_logits
        G = int(group_idx.max().item()) + 1

        z = self.encoder_ln(self.encoder(occurrences))
        scores = self.score_head(z).squeeze(-1)
        # differentiable attention (grad enabled) / hard top-k mask (grad disabled)
        attn_weights = soft_group_attention(group_idx, scores, self.topk, temp=temp, hard_at_eval=True)  # (N,)

        # Attention-weighted mean per group. Under autocast, z (LayerNorm output) and
        # the attention (derived from a Linear output) can have different dtypes.
        # index_add_ requires the source dtype to match the destination, so cast the
        # weights to z's dtype first.
        attn_z = attn_weights.to(z.dtype)
        sum_weighted = torch.zeros((G, z.size(1)), device=device, dtype=z.dtype).index_add_(
            0, group_idx, z * attn_z.unsqueeze(-1)
        )
        sum_weights = torch.zeros(G, device=device, dtype=z.dtype).index_add_(0, group_idx, attn_z)
        comp_vecs = sum_weighted / sum_weights.clamp(min=1e-6).unsqueeze(-1)

        # Residual path: magnitude-driven rescaling of each complex residual entry.
        attn_f = attn_weights.unsqueeze(-1)
        r_r_all, r_i_all = self.compute_residual(occurrences)
        a_mag = torch.sqrt(r_r_all * r_r_all + r_i_all * r_i_all + 1e-9)
        alpha = self.alpha_mlp(self.alpha_ln(a_mag))
        a_r = alpha * (r_r_all / (a_mag + 1e-9))
        a_i = alpha * (r_i_all / (a_mag + 1e-9))
        a_r = a_r * attn_f
        a_i = a_i * attn_f
        sum_a_r = torch.zeros((G, self.residual_dim), device=device, dtype=a_r.dtype).index_add_(0, group_idx, a_r)
        sum_a_i = torch.zeros((G, self.residual_dim), device=device, dtype=a_i.dtype).index_add_(0, group_idx, a_i)
        sum_sq_r = torch.zeros((G, self.residual_dim), device=device, dtype=a_r.dtype).index_add_(0, group_idx, a_r * a_r)
        sum_sq_i = torch.zeros((G, self.residual_dim), device=device, dtype=a_i.dtype).index_add_(0, group_idx, a_i * a_i)
        # Sum over item pairs i<j of a_i * a_j, via ((sum a)^2 - sum a^2) / 2.
        pair_r = 0.5 * (sum_a_r * sum_a_r - sum_sq_r)
        pair_i = 0.5 * (sum_a_i * sum_a_i - sum_sq_i)
        if self.use_cayley and self.unitary is not None:
            U = self.unitary()
            pair_complex = pair_r + 1j * pair_i
            pair_trans = pair_complex @ U
            pair_r = pair_trans.real
            pair_i = pair_trans.imag
        elif self.orth_proj is not None:
            pair_r = self.orth_proj(pair_r)
            pair_i = self.orth_proj(pair_i)
        out_r, out_i = self.lowrank(pair_r, pair_i)
        res_real = out_r

        if self.enable_ipi and self.ipi is not None:
            # Differentiable per-group IPI prediction (shape (G,)) scales the residual.
            ipi_pred = self.ipi.predict(group_idx, attn_weights)
            ipi_f = ipi_pred.unsqueeze(-1).to(res_real.dtype)
            gated_res = torch.sigmoid(self.res_gate) * (res_real * ipi_f)
            # The hooks need a plain Python float, so take a detached mean.
            with torch.no_grad():
                mean_ipi = float(ipi_pred.mean().item()) if ipi_pred.numel() > 0 else 0.0
            self._grad_scale = 1.0 + self.grad_mod_strength * mean_ipi
            if self.training:
                # Advance the recurrent state only during training (commit detaches).
                self.ipi.commit(ipi_pred)
        else:
            gated_res = torch.sigmoid(self.res_gate) * res_real
            ipi_pred = None
            self._grad_scale = 1.0

        fused = self.fusion(comp_vecs + gated_res)
        logits = self.classifier(fused)
        if self.enable_ipi:
            return logits, ipi_pred
        return logits

    def compute_ipi_loss(self, group_idx: torch.Tensor, attn_weights: torch.Tensor, reduction: str = 'mean'):
        """
        Auxiliary MSE between the predicted IPI and the observed per-group mean attention.

        Returns a zero tensor when IPI is disabled or ``group_idx`` is empty.
        """
        if not self.enable_ipi or self.ipi is None or group_idx.numel() == 0:
            return torch.tensor(0.0, device=attn_weights.device)
        G = int(group_idx.max().item()) + 1
        ipi_pred = self.ipi.predict(group_idx, attn_weights)
        # Accumulate in a float dtype at least as wide as the prediction. In float16, a
        # ``counts + 1e-12`` guard rounds away and empty group ids would give 0/0 = NaN.
        work_dtype = torch.promote_types(attn_weights.dtype, ipi_pred.dtype)
        obs = attn_weights.to(work_dtype)
        obs_sum = torch.zeros(G, device=obs.device, dtype=work_dtype).index_add_(0, group_idx, obs)
        counts = torch.zeros(G, device=obs.device, dtype=work_dtype).index_add_(0, group_idx, torch.ones_like(obs))
        obs_avg = obs_sum / counts.clamp(min=1.0)
        return F.mse_loss(ipi_pred, obs_avg.to(ipi_pred.dtype), reduction=reduction)
| # --------------------------- | |
| # Synthetic dataset and training helpers | |
| # --------------------------- | |
def make_synthetic_batch(G=128, occ_per_group=16, feat_dim=128, n_classes=10, device='cuda', teacher_seed: Optional[int] = 0):
    """Generate a toy grouped classification batch.

    Each of the G groups gets a random base feature vector. Its ``occ_per_group``
    occurrences are noisy copies (Gaussian noise scaled by 0.1). A group's label is
    the argmax of ``base @ teacher``, where ``teacher`` is a (feat_dim, n_classes)
    linear matrix.

    The teacher comes from a dedicated CPU generator seeded with ``teacher_seed``.
    Every batch made with the same seed and shapes therefore shares one labelling
    rule, and the task is learnable across batches.

    Pass ``teacher_seed=None`` to draw a fresh teacher from the global RNG on every
    call. In that mode each batch has an unrelated labelling rule.

    Returns:
        occurrences: (G * occ_per_group, feat_dim) float tensor.
        group_idx: (G * occ_per_group,) long tensor; ids 0..G-1 in contiguous runs.
        targets: (G,) long tensor of class ids.
    """
    N = G * occ_per_group
    base = torch.randn(G, feat_dim, device=device)
    occurrences = base.repeat_interleave(occ_per_group, dim=0) + 0.1 * torch.randn(N, feat_dim, device=device)
    group_idx = torch.arange(G, device=device).repeat_interleave(occ_per_group)
    if teacher_seed is None:
        weight = torch.randn(feat_dim, n_classes, device=device)
    else:
        # Separate generator: the teacher is fixed and does not consume global RNG state.
        gen = torch.Generator().manual_seed(int(teacher_seed))
        weight = torch.randn(feat_dim, n_classes, generator=gen).to(device)
    targets = (base @ weight).argmax(dim=1)
    return occurrences, group_idx, targets
def train_epoch(model, optim, device, scaler: Optional[torch.cuda.amp.GradScaler] = None, batches=50, G=128, occ_per_group=16, ipi_loss_scale: float = 0.0):
    """Train ``model`` for one epoch on freshly generated synthetic batches.

    Args:
        model: a CompressedResidualNetV3-like module. When the IPI loss is enabled it
            must expose ``encoder``, ``encoder_ln``, ``score_head``, ``topk`` and
            ``compute_ipi_loss``.
        optim: optimizer over ``model``'s parameters.
        device: device on which batches are generated.
        scaler: GradScaler. When given, the forward pass runs under CUDA autocast and
            backward/step go through the scaler.
        batches: number of optimisation steps.
        G: groups per batch.
        occ_per_group: items per group.
        ipi_loss_scale: weight of the auxiliary IPI coupling loss (0 disables it).

    Returns:
        (mean loss per group, accuracy over groups, elapsed seconds). Loss and accuracy
        are 0.0 when ``batches`` is 0.
    """
    model.train()
    total_loss = 0.0
    total_correct = 0
    total_examples = 0
    t0 = time.time()
    use_amp = scaler is not None
    # Match the data to the model's input/output sizes instead of relying on
    # make_synthetic_batch's defaults.
    feat_dim = getattr(model, 'feat_dim', 128)
    classifier = getattr(model, 'classifier', None)
    n_classes = classifier[-1].out_features if classifier is not None else 10
    for b in range(batches):
        occ, grp, targets = make_synthetic_batch(
            G=G, occ_per_group=occ_per_group, feat_dim=feat_dim, n_classes=n_classes, device=device
        )
        # Temperature anneals from 1.0 toward a floor of 0.1 within the epoch.
        temp = max(0.1, 1.0 - 0.001 * b)
        # torch.autocast('cuda', ...) is the documented replacement for the deprecated
        # torch.cuda.amp.autocast(...).
        with torch.autocast(device_type='cuda', enabled=use_amp):
            out = model(occ, grp, temp=temp)
            if isinstance(out, tuple):
                logits, ipi_pred = out
            else:
                logits, ipi_pred = out, None
            loss = F.cross_entropy(logits, targets)
            if ipi_loss_scale > 0.0 and ipi_pred is not None:
                # Recompute differentiable soft attention (hard_at_eval=False) for the
                # coupling loss. This repeats the encoder forward pass.
                z = model.encoder_ln(model.encoder(occ))
                scores = model.score_head(z).squeeze(-1)
                attn = soft_group_attention(grp, scores, model.topk, temp=temp, hard_at_eval=False)
                ipi_loss = model.compute_ipi_loss(grp, attn, reduction='mean')
                loss = loss + ipi_loss_scale * ipi_loss
        optim.zero_grad(set_to_none=True)
        if use_amp:
            scaler.scale(loss).backward()
            scaler.step(optim)
            scaler.update()
        else:
            loss.backward()
            optim.step()
        total_loss += loss.item() * G
        total_correct += (logits.argmax(dim=1) == targets).sum().item()
        total_examples += G
    elapsed = time.time() - t0
    if total_examples == 0:
        return 0.0, 0.0, elapsed
    return total_loss / total_examples, total_correct / total_examples, elapsed
def run_training():
    """Train a CompressedResidualNetV3 on synthetic data, then evaluate it once.

    Uses CUDA with GradScaler-based mixed precision when available; otherwise CPU in
    full precision. Prints per-epoch metrics and the final eval accuracy, and returns
    the trained model.
    """
    has_cuda = torch.cuda.is_available()
    device = 'cuda' if has_cuda else 'cpu'
    model_cfg = dict(
        feat_dim=128,
        bottleneck=64,
        residual_dim=16,
        lowrank=8,
        topk=8,
        n_classes=10,
        use_cayley=False,
        enable_ipi=True,
        ipi_max_groups=1024,
        ipi_momentum=0.98,
        grad_modulation_strength=0.5,
    )
    model = CompressedResidualNetV3(**model_cfg).to(device)
    optim = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
    scaler = torch.cuda.amp.GradScaler() if has_cuda else None

    epochs = 5
    for ep in range(epochs):
        loss, acc, sec = train_epoch(
            model, optim, device, scaler=scaler, batches=30, G=256, occ_per_group=8, ipi_loss_scale=0.05
        )
        print(f"Epoch {ep+1}/{epochs} — loss: {loss:.4f}, acc: {acc:.4f}, epoch_sec: {sec:.2f}")

    model.eval()
    with torch.no_grad():
        occ, grp, targets = make_synthetic_batch(G=256, occ_per_group=8, device=device)
        out = model(occ, grp)
        logits = out[0] if isinstance(out, tuple) else out
        acc = (logits.argmax(dim=1) == targets).float().mean().item()
    print("Eval acc:", acc)
    return model
if "__main__" == __name__:
    # Script entry point: run the synthetic training + evaluation demo.
    run_training()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment