Created
May 16, 2026 02:42
-
-
Save lastforkbender/59b6754c4d9a008b391c36635399b4b3 to your computer and use it in GitHub Desktop.
BNN-AGI with lower/upper complex number inference
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import numpy as np | |
| import math | |
class MetaController(nn.Module):
    """Refine per-target ``phase``, ``amp`` and ``bspline_logits`` tensors.

    An MLP encoder maps ``[node_feats, curvature]`` to a hidden vector from
    which three delta heads and a scalar gate are predicted.  For each of
    ``meta_steps`` rounds the deltas are scaled by ``gate * gamma`` and added
    to the running parameters, where ``gamma`` combines

    * a spectral score (fraction of rFFT energy in the lowest ``spec_k`` bins
      of the absolute deltas), and
    * a likeness score (LAS): the normalised real inner product between two
      learned complex projections of ``amp * exp(i * phase)``.

    Shape constraint: ``params['phase']`` and ``params['amp']`` must have
    feature width ``hidden_dim``, because they are multiplied by the
    ``(hidden_dim, hidden_dim)`` projection matrices and summed with the
    outputs of the ``hidden_dim``-wide delta heads.

    Args:
        node_feature_dim: width of ``node_feats`` passed to ``forward``.
        hidden_dim: encoder width and output width of the phase/amp heads.
        window_size: output width of the B-spline delta head.
        num_targets: stored on the module; not read by this class.
        meta_steps: number of refinement rounds in ``forward``.
        use_layernorm: append ``nn.LayerNorm`` to the encoder when true.
        spec_k: number of low-frequency rFFT bins counted by the spectral score.
        spec_alpha: lower bound of the ``gamma`` multiplier.
        layer_linear_attenuation: linearly shrink ``gamma`` with the step index.
        alpha_likeness: stored on the module; not read by this class.
        R_max: upper bound of the per-target inner recursion count.
        beta: exponent applied to LAS when deriving the recursion count.
        decay_factor: geometric decay of the update size across recursions.
    """

    def __init__(self, node_feature_dim, hidden_dim, window_size, num_targets,
                 meta_steps=3, use_layernorm=True,
                 spec_k=4, spec_alpha=0.1, layer_linear_attenuation=True,
                 alpha_likeness=0.05, R_max=3, beta=2.0, decay_factor=0.6):
        super().__init__()
        self.node_feature_dim = node_feature_dim
        self.hidden_dim = hidden_dim
        self.window_size = window_size
        self.num_targets = num_targets
        self.meta_steps = meta_steps

        # "+ 1" accounts for the curvature column concatenated in forward().
        enc_layers = [
            nn.Linear(node_feature_dim + 1, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        ]
        if use_layernorm:
            enc_layers.append(nn.LayerNorm(hidden_dim))
        self.encoder = nn.Sequential(*enc_layers)

        self.delta_phase_head = nn.Linear(hidden_dim, hidden_dim)
        self.delta_amp_head = nn.Linear(hidden_dim, hidden_dim)
        self.delta_bspline_head = nn.Linear(hidden_dim, window_size)
        self.gate = nn.Linear(hidden_dim, 1)

        # Complex projection params (real/imag pairs implemented as real tensors).
        h = hidden_dim
        self.Wu_r = nn.Parameter(torch.randn(h, h) * 0.01)
        self.Wu_i = nn.Parameter(torch.randn(h, h) * 0.01)
        self.Wl_r = nn.Parameter(torch.randn(h, h) * 0.01)
        self.Wl_i = nn.Parameter(torch.randn(h, h) * 0.01)
        self.Pr = nn.Parameter(torch.randn(h, h) * 0.01)
        self.Pi = nn.Parameter(torch.randn(h, h) * 0.01)

        # Control hyperparams.
        self.spec_k = spec_k
        self.spec_alpha = spec_alpha
        self.layer_linear_attenuation = layer_linear_attenuation
        self.alpha_likeness = alpha_likeness
        self.R_max = R_max
        self.beta = beta
        self.decay_factor = decay_factor

    def _complex_matmul_pair(self, x_real, x_imag, wr, wi):
        """Multiply ``(x_real + i*x_imag)`` by ``(wr + i*wi)`` using real tensors.

        Returns:
            Tuple ``(real, imag)`` of the complex matrix product.
        """
        real = x_real @ wr - x_imag @ wi
        imag = x_real @ wi + x_imag @ wr
        return real, imag

    def _spectral_score(self, t, axis=1, k=None, eps=1e-12):
        """Return the share of spectral energy held by the lowest ``k`` rFFT bins.

        Args:
            t: real tensor; the FFT is taken along ``axis``.
            axis: dimension of ``t`` treated as the signal dimension.
            k: number of leading rFFT bins (DC included) counted as "low";
                defaults to ``self.spec_k``.
            eps: added to the total energy so an all-zero signal scores 0
                instead of dividing by zero.

        Returns:
            Tensor with the shape of ``t`` minus ``axis`` plus a trailing
            singleton dimension, with values clamped to ``[0, 1]``.
        """
        if k is None:
            k = self.spec_k
        if axis != -1:
            t = t.movedim(axis, -1)
        psd = torch.fft.rfft(t, dim=-1).abs() ** 2
        # Reduce over the frequency axis so every item keeps its own score.
        # Reducing over the leading dims instead would mix items together and
        # leave `low` (k bins) and `total` (all bins) with incompatible shapes.
        low = psd[..., :k].sum(dim=-1)
        total = psd.sum(dim=-1) + eps
        return (low / total).clamp(0.0, 1.0).unsqueeze(-1)

    def forward(self, node_feats, curvature, params):
        """Run ``meta_steps`` refinement rounds and return the updated params.

        Args:
            node_feats: ``(K, node_feature_dim)`` real tensor.
            curvature: ``(K, 1)`` real tensor, concatenated to ``node_feats``.
            params: dict with keys ``'phase'`` and ``'amp'`` (both
                ``(K, hidden_dim)``) and ``'bspline_logits'``
                (``(K, window_size)``).

        Returns:
            A new dict with the same three keys holding the refined tensors.
            The input dict is not modified.
        """
        phase = params['phase']
        amp = params['amp']
        bs_logits = params['bspline_logits']

        # Everything derived from node_feats/curvature is identical in every
        # meta step (neither input changes inside the loop), so compute once.
        h = self.encoder(torch.cat([node_feats, curvature], dim=1))
        gate = torch.sigmoid(self.gate(h))      # (K,1)
        d_phase = self.delta_phase_head(h)      # (K,H)
        d_amp = self.delta_amp_head(h)          # (K,H)
        d_bs = self.delta_bspline_head(h)       # (K,W)

        # Spectral pre-score averaged over the three delta magnitudes.
        spec_score = (
            self._spectral_score(d_phase.abs(), axis=1)
            + self._spectral_score(d_amp.abs(), axis=1)
            + self._spectral_score(d_bs.abs(), axis=1)
        ) / 3.0                                  # (K,1)

        for step in range(self.meta_steps):
            # Complex proxy s = amp * exp(i*phase), kept as a real/imag pair.
            s_real = amp * torch.cos(phase)
            s_imag = amp * torch.sin(phase)

            # Project to the "upper" (Wu) and "lower" (Wl) spaces, then map
            # the lower projection through P.
            su_r, su_i = self._complex_matmul_pair(s_real, s_imag, self.Wu_r, self.Wu_i)
            sl_r, sl_i = self._complex_matmul_pair(s_real, s_imag, self.Wl_r, self.Wl_i)
            p_r, p_i = self._complex_matmul_pair(sl_r, sl_i, self.Pr, self.Pi)

            # LAS: Re<su, p> / (|su| |p|), rescaled from [-1,1] to [0,1].
            inner = torch.sum(su_r * p_r + su_i * p_i, dim=1, keepdim=True)
            su_norm = torch.sqrt((su_r ** 2 + su_i ** 2).sum(dim=1, keepdim=True).clamp(min=1e-12))
            p_norm = torch.sqrt((p_r ** 2 + p_i ** 2).sum(dim=1, keepdim=True).clamp(min=1e-12))
            LAS = ((inner / (su_norm * p_norm + 1e-12)) + 1.0) / 2.0
            # Rounding can push the cosine marginally outside [-1,1]; clamp so
            # LAS.pow(beta) never sees a negative base (NaN for fractional beta).
            LAS = LAS.clamp(0.0, 1.0)

            gamma = self.spec_alpha + (1.0 - self.spec_alpha) * (spec_score * LAS)  # (K,1)

            # Depth attenuation: scale goes linearly from 1.0 to 0.5 over the steps.
            if self.layer_linear_attenuation:
                depth_scale = 1.0 - (step / max(1, self.meta_steps - 1)) * 0.5
            else:
                depth_scale = 1.0
            gamma = (gamma * depth_scale).clamp(self.spec_alpha, 1.0)

            # Adaptive recursion depth: low likeness -> more inner iterations.
            R = torch.ceil(self.R_max * (1.0 - LAS.pow(self.beta))).long().clamp(min=1)
            # .max() raises on an empty tensor, so treat "no targets" as zero rounds.
            max_R = int(R.max().item()) if R.numel() > 0 else 0

            gated = gate * gamma                 # (K,1)
            d_phase_full = d_phase * gated
            d_amp_full = d_amp * gated
            d_bs_full = d_bs * gated

            # Run max_R rounds and mask out targets whose own R is smaller.
            # The deltas are not recomputed between rounds; only the step
            # size decays geometrically.
            for r in range(max_R):
                decay = self.decay_factor ** r
                mask = (R > r).to(phase.dtype)   # (K,1)
                step_scale = decay * mask
                phase = phase + 0.01 * (step_scale * d_phase_full)
                amp = amp + 0.01 * (step_scale * d_amp_full)
                bs_logits = bs_logits + 0.05 * (step_scale * d_bs_full)

        return {'phase': phase, 'amp': amp, 'bspline_logits': bs_logits}
class DiscreteBSplineAGIModuleWithMeta(nn.Module):
    """Sequence module with windowed node routing and stacked complex layers.

    Pipeline of ``forward``:

    1. Project the input to ``hidden_dim`` real features.
    2. Optionally refine a subset of the per-node phase / amplitude /
       B-spline logits through the ``MetaController`` instances.
    3. Aggregate a sliding window of hidden states with each node's softmaxed
       B-spline weights and rotate the result by the node's complex factor
       ``amp * exp(i * phase)``.
    4. Route every time step over the nodes (soft or straight-through hard
       selection) by similarity to learned complex prototypes.
    5. Pass the routed state through ``num_layers`` complex linear layers
       with a learned phase rotation and a commutator-weighted skip term.
    6. Project the magnitude of the final state back to ``input_dim``.

    Args:
        input_dim: feature width of the input and of the output.
        hidden_dim: width of the hidden state.
        num_layers: number of complex linear layers.
        num_nodes: number of routing nodes.
        window_size: length of the sliding window / B-spline weight vector.
        angle_bins: length of the ``angle_thresholds`` buffer (the buffer is
            registered but not read in ``forward``).
        softmax_temp: temperature of both softmaxes (floored at 1e-6).
        straight_through_hard: use hard one-hot routing in the forward pass
            with the soft routing's gradient.
        meta_target_node_indices: node indices refined by the meta
            controllers; defaults to all nodes.
        meta_hidden: ``hidden_dim`` handed to each ``MetaController``.
        meta_steps: ``meta_steps`` handed to each ``MetaController``.
        num_meta_controllers: number of controllers applied in sequence.
        use_meta_layernorm: ``use_layernorm`` handed to each controller.
        spec_k, spec_alpha, R_max, beta, decay_factor: forwarded unchanged to
            each ``MetaController``.
    """

    def __init__(self, input_dim, hidden_dim=128, num_layers=8,
                 num_nodes=6, window_size=5, angle_bins=16, softmax_temp=0.5,
                 straight_through_hard=False,
                 meta_target_node_indices=None, meta_hidden=128, meta_steps=3,
                 num_meta_controllers=1, use_meta_layernorm=True,
                 # advanced hyperparams (exposed)
                 spec_k=4, spec_alpha=0.1, R_max=3, beta=2.0, decay_factor=0.6):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.num_nodes = num_nodes
        self.window_size = window_size
        self.angle_bins = angle_bins
        self.softmax_temp = float(softmax_temp)
        self.straight_through = straight_through_hard

        self.input_proj = nn.Linear(input_dim, hidden_dim)

        # Per-node parameters.
        self.bspline_logits = nn.Parameter(torch.randn(num_nodes, window_size) * 0.1)
        self.node_rot_real = nn.Parameter(torch.randn(num_nodes, hidden_dim) * 0.01)
        self.node_rot_imag = nn.Parameter(torch.randn(num_nodes, hidden_dim) * 0.01)
        self.node_amp = nn.Parameter(torch.ones(num_nodes, hidden_dim) * 0.5)
        self.routing_prototypes_real = nn.Parameter(torch.randn(num_nodes, hidden_dim) * 0.01)
        self.routing_prototypes_imag = nn.Parameter(torch.randn(num_nodes, hidden_dim) * 0.01)

        # Complex linear layers stored as real/imag pairs.
        self.tsl_weights_real = nn.ParameterList()
        self.tsl_weights_imag = nn.ParameterList()
        self.tsl_bias_real = nn.ParameterList()
        self.tsl_bias_imag = nn.ParameterList()
        for _ in range(num_layers):
            self.tsl_weights_real.append(nn.Parameter(torch.randn(hidden_dim, hidden_dim) * 0.01))
            self.tsl_weights_imag.append(nn.Parameter(torch.randn(hidden_dim, hidden_dim) * 0.01))
            self.tsl_bias_real.append(nn.Parameter(torch.zeros(hidden_dim)))
            self.tsl_bias_imag.append(nn.Parameter(torch.zeros(hidden_dim)))
        self.layer_phase = nn.ParameterList([nn.Parameter(torch.zeros(1)) for _ in range(num_layers)])

        self.output_proj = nn.Linear(hidden_dim, input_dim)
        self.register_buffer("angle_thresholds", torch.linspace(-math.pi, math.pi, angle_bins))

        if meta_target_node_indices is None:
            meta_target_node_indices = list(range(num_nodes))
        self.meta_target_node_indices = meta_target_node_indices
        self.num_meta_controllers = num_meta_controllers

        self.meta_controllers = nn.ModuleList()
        # The controllers receive the real routing prototypes as node features.
        node_feat_dim = hidden_dim
        for _ in range(num_meta_controllers):
            self.meta_controllers.append(
                MetaController(node_feature_dim=node_feat_dim,
                               hidden_dim=meta_hidden,
                               window_size=window_size,
                               num_targets=len(meta_target_node_indices),
                               meta_steps=meta_steps,
                               use_layernorm=use_meta_layernorm,
                               spec_k=spec_k,
                               spec_alpha=spec_alpha,
                               R_max=R_max,
                               beta=beta,
                               decay_factor=decay_factor)
            )

    def _to_complex(self, real, imag):
        """Combine two real tensors into one complex tensor."""
        return torch.complex(real, imag)

    def _complex_matmul(self, x, wr, wi):
        """Multiply complex ``x`` by the complex matrix ``wr + i*wi``."""
        xr = x.real
        xi = x.imag
        real = xr @ wr - xi @ wi
        imag = xr @ wi + xi @ wr
        return torch.complex(real, imag)

    def _commutator_weight(self, A, B):
        """Return ``1 - |<A_n, B_n>|`` per row, clamped to ``[0, 1]``.

        ``A_n`` and ``B_n`` are the row-normalised inputs, so the weight is 0
        when the rows are parallel up to a complex phase and 1 when they are
        orthogonal.  Output shape is ``(rows, 1)`` and real.
        """
        A_norm = torch.norm(A, dim=1, keepdim=True).clamp(min=1e-8)
        B_norm = torch.norm(B, dim=1, keepdim=True).clamp(min=1e-8)
        An = A / A_norm
        Bn = B / B_norm
        dot = torch.sum(An * torch.conj(Bn), dim=1, keepdim=True)
        # abs() of a complex tensor is already real.
        return 1.0 - torch.clamp(torch.abs(dot), 0.0, 1.0)

    def _perp_validator(self, x, target_length=1.0):
        """Draw complex Gaussian noise orthogonal to each row of ``x``.

        The projection of the noise onto ``x`` is removed row by row and the
        remainder is rescaled to norm ``target_length``.
        """
        real = torch.randn_like(x.real)
        imag = torch.randn_like(x.real)
        rand = torch.complex(real, imag)
        x_norm_sq = (torch.sum(x * torch.conj(x), dim=1, keepdim=True)).real.clamp(min=1e-8)
        proj = (torch.sum(rand * torch.conj(x), dim=1, keepdim=True) / x_norm_sq) * x
        perp = rand - proj
        pn = torch.norm(perp, dim=1, keepdim=True).clamp(min=1e-8)
        return (target_length / pn) * perp

    def forward(self, x, mask=None):
        """Map ``x`` of shape ``(batch, seq_len, input_dim)`` to the same shape.

        Args:
            x: real input tensor ``(B, T, input_dim)``.
            mask: accepted for interface compatibility; not used.

        Returns:
            Real tensor ``(B, T, input_dim)``.
        """
        batch, seq_len, _ = x.shape
        temp = max(1e-6, self.softmax_temp)

        # Hidden features are real here; the state only becomes complex once
        # it is multiplied by the complex node rotations below.
        hid_real = self.input_proj(x)  # (B,T,H)

        # NOTE(review): clamping the real part to >= 1e-6 collapses every
        # negative real component onto the right half-plane, so the phase is
        # confined to about (-pi/2, pi/2).  Kept as in the original; confirm
        # whether a plain atan2 was intended.
        node_phase = torch.atan2(self.node_rot_imag, self.node_rot_real.clamp_min(1e-6))

        # Curvature proxy: mean absolute second difference of the phase
        # along the hidden dimension (needs at least three entries).
        if self.hidden_dim >= 3:
            curvature = torch.diff(node_phase, n=2, dim=1).abs().mean(dim=1, keepdim=True)
        else:
            curvature = torch.zeros((self.num_nodes, 1), device=node_phase.device,
                                    dtype=node_phase.dtype)

        prot = torch.complex(self.routing_prototypes_real, self.routing_prototypes_imag)

        target_idxs = torch.tensor(self.meta_target_node_indices, dtype=torch.long,
                                   device=node_phase.device)
        params = {
            'phase': node_phase[target_idxs],
            'amp': self.node_amp[target_idxs],
            'bspline_logits': self.bspline_logits[target_idxs],
        }
        # With no target nodes there is nothing to refine.
        if target_idxs.numel() > 0:
            targ_feats = self.routing_prototypes_real[target_idxs]
            targ_curv = curvature[target_idxs]
            for ctrl in self.meta_controllers:
                params = ctrl(targ_feats, targ_curv, params)

        # Write the refined rows back into full-size copies; the learnable
        # parameters themselves are never modified in place.
        node_phase_full = node_phase.clone()
        node_amp_full = self.node_amp.clone()
        bs_logits_full = self.bspline_logits.clone()
        node_phase_full[target_idxs] = params['phase']
        node_amp_full[target_idxs] = params['amp']
        bs_logits_full[target_idxs] = params['bspline_logits']

        node_rot = node_amp_full * torch.exp(1j * node_phase_full)  # (N,H) complex
        bs_w = F.softmax(bs_logits_full / temp, dim=1)              # (N,W)

        # Sliding-window aggregation.  Pad so that exactly T windows come
        # out for odd *and* even window sizes.
        win = bs_w.shape[1]
        left = win // 2
        right = win - 1 - left
        padded = F.pad(hid_real.permute(0, 2, 1), (left, right))    # (B,H,T+W-1)
        windows = padded.unfold(dimension=2, size=win, step=1)      # (B,H,T,W)
        ws = torch.einsum('bhtw,nw->bnth', windows, bs_w.to(windows.dtype))  # (B,N,T,H)

        node_contribs_all = ws * node_rot.unsqueeze(0).unsqueeze(2)  # (B,N,T,H) complex

        # Routing similarity: |<contribution, prototype>| per node and step.
        prot_b = prot.unsqueeze(0).unsqueeze(2)                      # (1,N,1,H)
        sims = torch.sum(node_contribs_all * torch.conj(prot_b), dim=-1).abs()  # (B,N,T)
        node_sel_soft = F.softmax(sims / temp, dim=1)

        if self.straight_through:
            hard_idx = sims.argmax(dim=1)                            # (B,T)
            node_sel_hard = F.one_hot(hard_idx, num_classes=self.num_nodes)
            node_sel_hard = node_sel_hard.permute(0, 2, 1).to(node_sel_soft.dtype)
            # Straight-through: hard value forward, soft gradient backward.
            node_sel = (node_sel_hard - node_sel_soft).detach() + node_sel_soft
        else:
            node_sel = node_sel_soft

        combined = torch.sum(node_contribs_all * node_sel.unsqueeze(-1), dim=1)  # (B,T,H)

        # `combined` is already complex; torch.complex() only accepts real
        # inputs, so it must not be wrapped again.
        state = combined.reshape(batch * seq_len, self.hidden_dim)

        for idx in range(self.num_layers):
            bias = torch.complex(self.tsl_bias_real[idx], self.tsl_bias_imag[idx])
            out_state = self._complex_matmul(
                state, self.tsl_weights_real[idx], self.tsl_weights_imag[idx]) + bias
            out_state = out_state * torch.exp(1j * self.layer_phase[idx])
            # Skip connection weighted by how far the layer turned the state.
            cw = self._commutator_weight(state, out_state)
            state = out_state + cw * state
            # Every fourth layer (not the first) inject a small orthogonal perturbation.
            if idx % 4 == 0 and idx > 0:
                perp = self._perp_validator(state, target_length=0.05)
                state = state + 0.01 * perp

        state = state.reshape(batch, seq_len, self.hidden_dim)
        out_mag = torch.abs(state)
        return self.output_proj(out_mag)
# Smoke test: build a small model and push one random batch through it.
if __name__ == "__main__":
    if torch.cuda.is_available():
        run_device = torch.device("cuda")
    else:
        run_device = torch.device("cpu")

    model_kwargs = dict(
        input_dim=64,
        hidden_dim=64,
        num_layers=6,
        num_nodes=6,
        window_size=5,
        softmax_temp=0.7,
        straight_through_hard=True,
        meta_target_node_indices=[1, 3, 4],
        meta_hidden=64,
        meta_steps=2,
        num_meta_controllers=2,
        spec_k=3,
        spec_alpha=0.08,
        R_max=3,
        beta=2.0,
        decay_factor=0.7,
    )
    model = DiscreteBSplineAGIModuleWithMeta(**model_kwargs).to(run_device)

    # (batch=2, seq_len=16, input_dim=64)
    batch_in = torch.randn(2, 16, 64, device=run_device)
    batch_out = model(batch_in)
    print("in:", batch_in.shape, "out:", batch_out.shape)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment