Skip to content

Instantly share code, notes, and snippets.

@wassname
Created April 3, 2026 08:55
Show Gist options
  • Select an option

  • Save wassname/2548570298bacd982f847c581a1ceb40 to your computer and use it in GitHub Desktop.

Select an option

Save wassname/2548570298bacd982f847c581a1ceb40 to your computer and use it in GitHub Desktop.
classify_linear_sublayers as residual read or write
def classify_linear_sublayers(
    model,
    block_layers: list[str],
) -> dict[str, list[str]]:
    """Classify the Linear sublayers of each block by residual-stream role.

    ``d_model`` is read from ``model.config.hidden_size`` (falling back to
    ``model.config.d_model``).  Every ``torch.nn.Linear`` found under each
    block is then bucketed by its weight shape ``[out_features, in_features]``:

    - ``out_features == in_features == d_model`` (square, so the shape alone
      cannot tell the direction) is resolved by substring match on the module
      name relative to its block, write patterns being checked first:
        * ``o_proj`` / ``out_proj`` / ``down_proj`` / ``d_proj``
          -> ``'ambiguous_write'``
        * ``q_proj`` / ``k_proj`` / ``v_proj`` / ``up_proj`` / ``gate_proj`` /
          ``in_proj`` / ``qkv`` / ``query`` / ``key`` / ``value``
          -> ``'ambiguous_read'``
        * anything else -> ``'ambiguous_unknown'``
    - otherwise ``out_features == d_model`` -> ``'write'`` (writes TO the
      residual stream)
    - otherwise ``in_features == d_model`` -> ``'read'`` (reads FROM the
      residual stream)
    - layers where neither dimension equals ``d_model`` are ignored.

    Args:
        model: Object exposing ``config`` and ``get_submodule(name)``.
        block_layers: Dotted submodule paths of the blocks to inspect.  Paths
            for which ``get_submodule`` raises ``AttributeError`` are skipped.

    Returns:
        Dict with keys ``'write'``, ``'read'``, ``'ambiguous_write'``,
        ``'ambiguous_read'`` and ``'ambiguous_unknown'``, each mapping to a
        list of dotted module paths (block path joined with the relative name).

    Raises:
        AssertionError: If neither ``hidden_size`` nor ``d_model`` yields a
            truthy value on ``model.config``.
    """
    # Local import keeps the function usable when pasted on its own; it is a
    # cheap no-op when torch is already imported by the surrounding module.
    import torch

    WRITE_PATTERNS = ("o_proj", "out_proj", "down_proj", "d_proj")
    READ_PATTERNS = (
        "q_proj", "k_proj", "v_proj", "up_proj", "gate_proj",
        "in_proj", "qkv", "query", "key", "value",
    )

    # Infer d_model from the model config.
    cfg = model.config
    d_model = getattr(cfg, "hidden_size", None) or getattr(cfg, "d_model", None)
    if not d_model:
        # Raised explicitly (rather than via ``assert``) so the check survives
        # ``python -O``; otherwise d_model=None would silently match nothing.
        raise AssertionError(f"Cannot infer d_model from model.config: {cfg}")

    result: dict[str, list[str]] = {
        "write": [],
        "read": [],
        "ambiguous_write": [],
        "ambiguous_read": [],
        "ambiguous_unknown": [],
    }

    for block in block_layers:
        try:
            block_module = model.get_submodule(block)
        except AttributeError:
            # Best-effort: a block path that does not exist is skipped.
            continue

        for name, m in block_module.named_modules():
            if not isinstance(m, torch.nn.Linear):
                continue

            # named_modules() yields the block itself under the name "", and
            # the block path may itself be "" (the root model); drop empty
            # parts so the result never has a leading or trailing dot.
            full_name = ".".join(part for part in (block, name) if part)
            out_f, in_f = m.weight.shape

            if out_f == in_f == d_model:
                # Ambiguous square — resolve by name.
                if any(p in name for p in WRITE_PATTERNS):
                    result["ambiguous_write"].append(full_name)
                elif any(p in name for p in READ_PATTERNS):
                    result["ambiguous_read"].append(full_name)
                else:
                    result["ambiguous_unknown"].append(full_name)
            elif out_f == d_model:
                result["write"].append(full_name)
            elif in_f == d_model:
                result["read"].append(full_name)
            # else: neither dim matches d_model (e.g. head projections) — ignore

    return result
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment