Skip to content

Instantly share code, notes, and snippets.

@lastforkbender
Created May 16, 2026 02:42
Show Gist options
  • Select an option

  • Save lastforkbender/59b6754c4d9a008b391c36635399b4b3 to your computer and use it in GitHub Desktop.

Select an option

Save lastforkbender/59b6754c4d9a008b391c36635399b4b3 to your computer and use it in GitHub Desktop.
# BNN-AGI with lower/upper complex number inference
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
class MetaController(nn.Module):
    """Iteratively nudge per-target phase, amplitude and B-spline logits.

    An MLP encodes each target's features plus a curvature scalar and emits
    proposed deltas for the three parameter groups.  The deltas are scaled by
    a sigmoid gate and by ``gamma``, which combines a low-frequency spectral
    score of the deltas with a likeness score (LAS) between two learned
    complex projections ("upper"/"lower") of ``amp * exp(i * phase)``.

    Shape contract (K = number of targets):
        * ``node_feats``: ``(K, node_feature_dim)``
        * ``curvature``: ``(K, 1)``
        * ``params['phase']`` / ``params['amp']``: ``(K, hidden_dim)`` -- the
          width must equal this controller's ``hidden_dim`` because the
          values are multiplied by the ``(hidden_dim, hidden_dim)``
          projection matrices and added to the delta-head outputs.
        * ``params['bspline_logits']``: ``(K, window_size)``

    ``num_targets`` and ``alpha_likeness`` are stored but not read by any
    method of this class.
    """

    def __init__(self, node_feature_dim, hidden_dim, window_size, num_targets,
                 meta_steps=3, use_layernorm=True,
                 spec_k=4, spec_alpha=0.1, layer_linear_attenuation=True,
                 alpha_likeness=0.05, R_max=3, beta=2.0, decay_factor=0.6):
        super().__init__()
        self.node_feature_dim = node_feature_dim
        self.hidden_dim = hidden_dim
        self.window_size = window_size
        self.num_targets = num_targets
        self.meta_steps = meta_steps

        # Encoder input is the node features concatenated with one curvature
        # scalar, hence the "+ 1".
        enc_layers = [
            nn.Linear(node_feature_dim + 1, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        ]
        if use_layernorm:
            enc_layers.append(nn.LayerNorm(hidden_dim))
        self.encoder = nn.Sequential(*enc_layers)

        self.delta_phase_head = nn.Linear(hidden_dim, hidden_dim)
        self.delta_amp_head = nn.Linear(hidden_dim, hidden_dim)
        self.delta_bspline_head = nn.Linear(hidden_dim, window_size)
        self.gate = nn.Linear(hidden_dim, 1)

        # Complex projection params (real/imag pairs implemented as real
        # tensors).  Creation order is kept stable so seeded initialisation
        # is reproducible.
        h = hidden_dim
        self.Wu_r = nn.Parameter(torch.randn(h, h) * 0.01)
        self.Wu_i = nn.Parameter(torch.randn(h, h) * 0.01)
        self.Wl_r = nn.Parameter(torch.randn(h, h) * 0.01)
        self.Wl_i = nn.Parameter(torch.randn(h, h) * 0.01)
        self.Pr = nn.Parameter(torch.randn(h, h) * 0.01)
        self.Pi = nn.Parameter(torch.randn(h, h) * 0.01)

        # control hyperparams
        self.spec_k = spec_k
        self.spec_alpha = spec_alpha
        self.layer_linear_attenuation = layer_linear_attenuation
        self.alpha_likeness = alpha_likeness
        self.R_max = R_max
        self.beta = beta
        self.decay_factor = decay_factor

    def _complex_matmul_pair(self, x_real, x_imag, wr, wi):
        """Return ``(real, imag)`` of ``(x_real + i x_imag) @ (wr + i wi)``."""
        # x: (K,H), wr/wi: (H,h) -> out (K,h) real/imag
        real = x_real @ wr - x_imag @ wi
        imag = x_real @ wi + x_imag @ wr
        return real, imag

    def _spectral_score(self, t, axis=1, k=None, eps=1e-12):
        """Fraction of spectral energy in the lowest ``k`` rFFT bins, per item.

        Args:
            t: real tensor; ``axis`` is the dimension that is transformed.
            axis: dimension of ``t`` to run the rFFT over.
            k: number of low-frequency bins; defaults to ``self.spec_k``.
            eps: added to the total energy to avoid division by zero.

        Returns:
            Tensor in ``[0, 1]`` with the transformed axis replaced by a
            trailing singleton, e.g. ``(K, L) -> (K, 1)``.
        """
        if k is None:
            k = self.spec_k
        # Make the transformed axis the last one.
        if axis != -1:
            t = t.movedim(axis, -1)
        psd = torch.fft.rfft(t, dim=-1).abs() ** 2
        # Reduce over the *frequency* axis so each item keeps its own score.
        # (Reducing over the item axes instead yields mismatched (k,) and
        # (F,) tensors that cannot be divided.)
        low = psd[..., :k].sum(dim=-1)
        total = psd.sum(dim=-1) + eps
        score = (low / total).clamp(0.0, 1.0)
        return score.unsqueeze(-1)

    def _likeness(self, phase, amp):
        """Return LAS in [0, 1]: cosine likeness of the upper/lower projections."""
        # Complex state s = amp * exp(i * phase), kept as a real/imag pair.
        s_real = amp * torch.cos(phase)
        s_imag = amp * torch.sin(phase)
        su_r, su_i = self._complex_matmul_pair(s_real, s_imag, self.Wu_r, self.Wu_i)
        sl_r, sl_i = self._complex_matmul_pair(s_real, s_imag, self.Wl_r, self.Wl_i)
        # Map the lower projection through P before comparing.
        p_r, p_i = self._complex_matmul_pair(sl_r, sl_i, self.Pr, self.Pi)
        # inner = sum_k Re(su_k * conj(p_k)) = sum(su_r * p_r + su_i * p_i)
        inner = torch.sum(su_r * p_r + su_i * p_i, dim=1, keepdim=True)
        su_norm = torch.sqrt((su_r ** 2 + su_i ** 2).sum(dim=1, keepdim=True).clamp(min=1e-12))
        p_norm = torch.sqrt((p_r ** 2 + p_i ** 2).sum(dim=1, keepdim=True).clamp(min=1e-12))
        cosine = inner / (su_norm * p_norm + 1e-12)
        # Clamp so rounding cannot push LAS outside [0, 1]; a slightly
        # negative base would turn LAS.pow(beta) into NaN for fractional beta.
        return ((cosine + 1.0) / 2.0).clamp(0.0, 1.0)

    def forward(self, node_feats, curvature, params):
        """Run ``meta_steps`` refinement passes over ``params``.

        Args:
            node_feats: ``(K, node_feature_dim)`` target features.
            curvature: ``(K, 1)`` curvature scalar per target.
            params: dict with keys ``'phase'``, ``'amp'`` and
                ``'bspline_logits'`` (see class docstring for shapes).

        Returns:
            A new dict with the same three keys holding the updated tensors;
            the input dict and its tensors are not modified in place.
        """
        phase = params['phase']                # (K, H)
        amp = params['amp']                    # (K, H)
        bs_logits = params['bspline_logits']   # (K, W)

        if self.meta_steps <= 0:
            return {'phase': phase, 'amp': amp, 'bspline_logits': bs_logits}

        # The encoder input never changes between meta steps, so the gate,
        # the proposed deltas and their spectral score are computed once.
        h = self.encoder(torch.cat([node_feats, curvature], dim=1))
        gate = torch.sigmoid(self.gate(h))     # (K,1)
        d_phase = self.delta_phase_head(h)     # (K,H)
        d_amp = self.delta_amp_head(h)         # (K,H)
        d_bs = self.delta_bspline_head(h)      # (K,W)
        spec_score = (
            self._spectral_score(d_phase.abs(), axis=1)
            + self._spectral_score(d_amp.abs(), axis=1)
            + self._spectral_score(d_bs.abs(), axis=1)
        ) / 3.0                                # (K,1)

        for step in range(self.meta_steps):
            LAS = self._likeness(phase, amp)   # (K,1)

            # combine spec_score and LAS into gamma multiplier
            gamma = self.spec_alpha + (1.0 - self.spec_alpha) * (spec_score * LAS)
            # depth attenuation: linearly shrink from 1.0 to 0.5 over the steps
            if self.layer_linear_attenuation:
                depth_scale = 1.0 - (step / max(1, self.meta_steps - 1)) * 0.5
            else:
                depth_scale = 1.0
            gamma = (gamma * depth_scale).clamp(self.spec_alpha, 1.0)

            # adaptive recursion depth R based on LAS (lower likeness -> more
            # repetitions), at least one per target
            R = torch.ceil(self.R_max * (1.0 - LAS.pow(self.beta))).long().clamp(min=1)
            max_R = int(R.max().item())

            gated = gate * gamma               # (K,1)
            d_phase_full = d_phase * gated
            d_amp_full = d_amp * gated
            d_bs_full = d_bs * gated

            # Run the largest depth and mask out targets whose own depth is
            # already exhausted; each repetition is geometrically decayed.
            for r in range(max_R):
                decay = self.decay_factor ** r
                active = (R > r).float()       # (K,1)
                step_scale = decay * active
                phase = phase + 0.01 * (step_scale * d_phase_full)
                amp = amp + 0.01 * (step_scale * d_amp_full)
                bs_logits = bs_logits + 0.05 * (step_scale * d_bs_full)

        return {'phase': phase, 'amp': amp, 'bspline_logits': bs_logits}
class DiscreteBSplineAGIModuleWithMeta(nn.Module):
    """Sequence module with windowed node aggregation and complex layers.

    Pipeline of ``forward``:
        1. Project the input to ``hidden_dim``.
        2. Derive per-node phase / curvature, optionally let the meta
           controllers adjust phase, amplitude and window logits of the
           target nodes.
        3. Aggregate a sliding window of the projected sequence with each
           node's softmax window weights and rotate it by the node's complex
           factor ``amp * exp(i * phase)``.
        4. Route between nodes by similarity to complex prototypes (soft, or
           hard with a straight-through estimator).
        5. Pass the result through ``num_layers`` complex linear layers and
           project the magnitude back to ``input_dim``.

    ``angle_bins`` and the ``angle_thresholds`` buffer are registered but not
    read by ``forward``.  The meta controllers operate on tensors of width
    ``hidden_dim``, so ``meta_hidden`` must equal ``hidden_dim`` whenever
    ``num_meta_controllers > 0``.
    """

    def __init__(self, input_dim, hidden_dim=128, num_layers=8,
                 num_nodes=6, window_size=5, angle_bins=16, softmax_temp=0.5,
                 straight_through_hard=False,
                 meta_target_node_indices=None, meta_hidden=128, meta_steps=3,
                 num_meta_controllers=1, use_meta_layernorm=True,
                 # advanced hyperparams (exposed)
                 spec_k=4, spec_alpha=0.1, R_max=3, beta=2.0, decay_factor=0.6):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.num_nodes = num_nodes
        self.window_size = window_size
        self.angle_bins = angle_bins
        self.softmax_temp = float(softmax_temp)
        self.straight_through = straight_through_hard

        self.input_proj = nn.Linear(input_dim, hidden_dim)
        self.bspline_logits = nn.Parameter(torch.randn(num_nodes, window_size) * 0.1)
        self.node_rot_real = nn.Parameter(torch.randn(num_nodes, hidden_dim) * 0.01)
        self.node_rot_imag = nn.Parameter(torch.randn(num_nodes, hidden_dim) * 0.01)
        self.node_amp = nn.Parameter(torch.ones(num_nodes, hidden_dim) * 0.5)
        self.routing_prototypes_real = nn.Parameter(torch.randn(num_nodes, hidden_dim) * 0.01)
        self.routing_prototypes_imag = nn.Parameter(torch.randn(num_nodes, hidden_dim) * 0.01)

        # Per-layer complex weights/biases stored as real/imag pairs.  The
        # interleaved creation order is kept so seeded init is unchanged.
        self.tsl_weights_real = nn.ParameterList()
        self.tsl_weights_imag = nn.ParameterList()
        self.tsl_bias_real = nn.ParameterList()
        self.tsl_bias_imag = nn.ParameterList()
        for _ in range(num_layers):
            self.tsl_weights_real.append(nn.Parameter(torch.randn(hidden_dim, hidden_dim) * 0.01))
            self.tsl_weights_imag.append(nn.Parameter(torch.randn(hidden_dim, hidden_dim) * 0.01))
            self.tsl_bias_real.append(nn.Parameter(torch.zeros(hidden_dim)))
            self.tsl_bias_imag.append(nn.Parameter(torch.zeros(hidden_dim)))
        self.layer_phase = nn.ParameterList(
            [nn.Parameter(torch.zeros(1)) for _ in range(num_layers)]
        )
        self.output_proj = nn.Linear(hidden_dim, input_dim)
        self.register_buffer("angle_thresholds", torch.linspace(-np.pi, np.pi, angle_bins))

        if meta_target_node_indices is None:
            meta_target_node_indices = list(range(num_nodes))
        self.meta_target_node_indices = meta_target_node_indices
        self.num_meta_controllers = num_meta_controllers
        self.meta_controllers = nn.ModuleList()
        node_feat_dim = hidden_dim
        for _ in range(num_meta_controllers):
            self.meta_controllers.append(
                MetaController(node_feature_dim=node_feat_dim,
                               hidden_dim=meta_hidden,
                               window_size=window_size,
                               num_targets=len(meta_target_node_indices),
                               meta_steps=meta_steps,
                               use_layernorm=use_meta_layernorm,
                               spec_k=spec_k,
                               spec_alpha=spec_alpha,
                               R_max=R_max,
                               beta=beta,
                               decay_factor=decay_factor)
            )

    def _to_complex(self, real, imag):
        """Combine two real tensors into one complex tensor."""
        return torch.complex(real, imag)

    def _complex_matmul(self, x, wr, wi):
        """Return complex ``x @ (wr + i wi)`` using real matmuls."""
        xr = x.real
        xi = x.imag
        real = xr @ wr - xi @ wi
        imag = xr @ wi + xi @ wr
        return torch.complex(real, imag)

    def _commutator_weight(self, A, B):
        """Return ``1 - |<A_hat, B_hat>|`` per row, in [0, 1], shape (N, 1)."""
        a_unit = A / torch.norm(A, dim=1, keepdim=True).clamp(min=1e-8)
        b_unit = B / torch.norm(B, dim=1, keepdim=True).clamp(min=1e-8)
        overlap = torch.sum(a_unit * torch.conj(b_unit), dim=1, keepdim=True).abs()
        return 1.0 - overlap.clamp(0.0, 1.0)

    def _perp_validator(self, x, target_length=1.0):
        """Return a random complex vector per row, orthogonal to ``x``.

        A Gaussian complex sample has its component along ``x`` removed and
        is rescaled to norm ``target_length``.
        """
        noise = torch.complex(torch.randn_like(x.real), torch.randn_like(x.real))
        x_norm_sq = torch.sum(x * torch.conj(x), dim=1, keepdim=True).real.clamp(min=1e-8)
        coeff = torch.sum(noise * torch.conj(x), dim=1, keepdim=True) / x_norm_sq
        perp = noise - coeff * x
        perp_norm = torch.norm(perp, dim=1, keepdim=True).clamp(min=1e-8)
        return (target_length / perp_norm) * perp

    def forward(self, x, mask=None):
        """Map ``x`` of shape ``(batch, seq_len, input_dim)`` to the same shape.

        Args:
            x: real input tensor ``(batch, seq_len, input_dim)``.
            mask: accepted for interface compatibility; currently unused.

        Returns:
            Real tensor ``(batch, seq_len, input_dim)``.
        """
        batch, seq_len, _ = x.shape
        temp = max(1e-6, self.softmax_temp)

        # Real projection; the "complex hidden" state has zero imaginary part
        # at this point, so the window aggregation below is done on reals.
        hid = self.input_proj(x)                               # (B,T,H)

        # NOTE(review): clamp_min on the real part confines the phase to
        # roughly [-pi/2, pi/2]; atan2 itself needs no such guard.  Kept as
        # in the original -- confirm whether this is intended.
        node_phase = torch.atan2(self.node_rot_imag, self.node_rot_real.clamp_min(1e-6))

        # curvature proxy: mean absolute second difference of the phase
        if self.hidden_dim >= 3:
            curvature = torch.diff(node_phase, n=2, dim=1).abs().mean(dim=1, keepdim=True)
        else:
            curvature = node_phase.new_zeros((self.num_nodes, 1))

        prot = torch.complex(self.routing_prototypes_real, self.routing_prototypes_imag)

        target_idxs = torch.tensor(self.meta_target_node_indices,
                                   dtype=torch.long, device=x.device)
        params = {
            'phase': node_phase[target_idxs],
            'amp': self.node_amp[target_idxs],
            'bspline_logits': self.bspline_logits[target_idxs],
        }
        targ_feats = self.routing_prototypes_real[target_idxs]  # (K,H)
        targ_curv = curvature[target_idxs]                      # (K,1)
        for ctrl in self.meta_controllers:
            params = ctrl(targ_feats, targ_curv, params)

        # assemble full tensors and reconstruct rotations
        node_phase_full = node_phase.clone()
        node_amp_full = self.node_amp.clone()
        bs_logits_full = self.bspline_logits.clone()
        node_phase_full[target_idxs] = params['phase']
        node_amp_full[target_idxs] = params['amp']
        bs_logits_full[target_idxs] = params['bspline_logits']
        node_rot = node_amp_full * torch.exp(1j * node_phase_full)  # (N,H) complex
        bs_w = F.softmax(bs_logits_full / temp, dim=1)              # (N,W)

        # Vectorized window aggregation.  Asymmetric padding keeps exactly
        # seq_len windows for even window sizes too (for odd W it equals the
        # symmetric W // 2 padding).
        W = bs_w.shape[1]
        pad_left = W // 2
        pad_right = W - 1 - pad_left
        padded = F.pad(hid.permute(0, 2, 1), (pad_left, pad_right))  # (B,H,T+W-1)
        windows = padded.unfold(2, W, 1)                             # (B,H,T,W)
        ws = torch.einsum('bhtw,nw->bnth', windows, bs_w.to(windows.dtype))  # (B,N,T,H)

        # real * complex promotes to complex
        node_contribs_all = ws * node_rot.unsqueeze(0).unsqueeze(2)  # (B,N,T,H)
        prot_b = prot.unsqueeze(0).unsqueeze(2)                      # (1,N,1,H)
        sims = torch.sum(node_contribs_all * torch.conj(prot_b), dim=-1).abs()  # (B,N,T)

        node_sel_soft = F.softmax(sims / temp, dim=1)
        if self.straight_through:
            # Hard one-hot in the forward pass, soft gradients in backward.
            hard_idx = sims.argmax(dim=1)                            # (B,T)
            node_sel_hard = F.one_hot(hard_idx, num_classes=self.num_nodes)
            node_sel_hard = node_sel_hard.permute(0, 2, 1).to(node_sel_soft.dtype)
            node_sel = (node_sel_hard - node_sel_soft).detach() + node_sel_soft
        else:
            node_sel = node_sel_soft

        node_sel_exp = node_sel.permute(0, 2, 1).unsqueeze(-1)       # (B,T,N,1)
        contribs_perm = node_contribs_all.permute(0, 2, 1, 3)        # (B,T,N,H)
        combined = torch.sum(contribs_perm * node_sel_exp, dim=2)    # (B,T,H) complex

        # Already complex: torch.complex() only accepts real inputs, so the
        # tensor is used as-is instead of being re-wrapped.
        state = combined.reshape(batch * seq_len, self.hidden_dim)

        for idx in range(self.num_layers):
            bias = torch.complex(self.tsl_bias_real[idx], self.tsl_bias_imag[idx])
            out_state = self._complex_matmul(
                state, self.tsl_weights_real[idx], self.tsl_weights_imag[idx]
            ) + bias
            layer_phase = self.layer_phase[idx].to(out_state.dtype)
            out_state = out_state * torch.exp(1j * layer_phase)
            # Residual weighted by how far the new state turned away from the
            # old one.
            cw = self._commutator_weight(state, out_state)
            state = out_state + cw * state
            # Every 4th layer (excluding the first) add a small orthogonal
            # random perturbation.
            if idx % 4 == 0 and idx > 0:
                perp = self._perp_validator(state, target_length=0.05)
                state = state + 0.01 * perp

        state = state.reshape(batch, seq_len, self.hidden_dim)
        return self.output_proj(torch.abs(state))
# Smoke test: build a small model, push one random batch through it and
# print the input/output shapes.
if __name__ == "__main__":
    device = torch.device("cpu" if not torch.cuda.is_available() else "cuda")
    m = DiscreteBSplineAGIModuleWithMeta(
        input_dim=64,
        hidden_dim=64,
        num_layers=6,
        num_nodes=6,
        window_size=5,
        softmax_temp=0.7,
        straight_through_hard=True,
        meta_target_node_indices=[1, 3, 4],
        meta_hidden=64,
        meta_steps=2,
        num_meta_controllers=2,
        spec_k=3,
        spec_alpha=0.08,
        R_max=3,
        beta=2.0,
        decay_factor=0.7,
    ).to(device)
    # batch of 2 sequences, 16 steps each, 64 features per step
    x = torch.randn(2, 16, 64, device=device)
    y = m(x)
    print("in:", x.shape, "out:", y.shape)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment