Created
April 3, 2026 08:55
-
-
Save wassname/2548570298bacd982f847c581a1ceb40 to your computer and use it in GitHub Desktop.
classify_linear_sublayers as residual read or write
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def classify_linear_sublayers(
    model,
    block_layers: list[str],
) -> dict[str, list[str]]:
    """Classify all Linear sublayers in each block by their residual-stream role.

    For a Linear layer mapping ``in_features -> out_features``:

    - residual_write : out_features == d_model (writes TO the residual stream)
    - residual_read  : in_features == d_model (reads FROM the residual stream)
    - ambiguous      : square [d_model, d_model] -- resolved by substring match
      on the module name relative to its block. Write patterns are checked
      first, then read patterns:

        o_proj / out_proj / down_proj / d_proj             -> ambiguous_write
        q_proj / k_proj / v_proj / up_proj / gate_proj /
        in_proj / qkv / query / key / value                -> ambiguous_read
        everything else                                    -> ambiguous_unknown

    Layers where neither dimension equals d_model are ignored.

    Args:
        model: Object exposing ``config`` (with a truthy ``hidden_size`` or
            ``d_model`` attribute) and ``get_submodule(path)``.
        block_layers: Dotted submodule paths of the blocks to inspect. Paths
            for which ``get_submodule`` raises ``AttributeError`` are skipped.

    Returns:
        Dict with keys 'write', 'read', 'ambiguous_write', 'ambiguous_read',
        'ambiguous_unknown', each mapping to a list of fully qualified module
        names (block path joined with the name relative to the block).

    Raises:
        AssertionError: If d_model cannot be inferred from ``model.config``.
    """
    WRITE_PATTERNS = ("o_proj", "out_proj", "down_proj", "d_proj")
    READ_PATTERNS = ("q_proj", "k_proj", "v_proj", "up_proj", "gate_proj",
                     "in_proj", "qkv", "query", "key", "value")

    # Infer d_model from the model config.
    cfg = model.config
    d_model = getattr(cfg, "hidden_size", None) or getattr(cfg, "d_model", None)
    # Raised explicitly (not via `assert`) so the check survives `python -O`;
    # the exception type is kept as AssertionError for backward compatibility.
    if not d_model:
        raise AssertionError(f"Cannot infer d_model from model.config: {cfg}")

    result: dict[str, list[str]] = {
        "write": [], "read": [],
        "ambiguous_write": [], "ambiguous_read": [], "ambiguous_unknown": [],
    }

    for block in block_layers:
        try:
            block_module = model.get_submodule(block)
        except AttributeError:
            # Best-effort: a block path that does not resolve is skipped.
            continue

        for name, m in block_module.named_modules():
            if not isinstance(m, torch.nn.Linear):
                continue

            # named_modules() yields "" for the block itself, and the block
            # path may be "" for the model root; join only non-empty parts so
            # the qualified name never gets a leading or trailing dot.
            full_name = ".".join(part for part in (block, name) if part)

            # Use the declared feature sizes rather than unpacking
            # weight.shape, so Linear subclasses whose stored weight is not a
            # plain [out_features, in_features] tensor are still handled.
            out_f, in_f = m.out_features, m.in_features

            if out_f == in_f == d_model:
                # Ambiguous square -- resolve by name (write patterns first).
                if any(p in name for p in WRITE_PATTERNS):
                    result["ambiguous_write"].append(full_name)
                elif any(p in name for p in READ_PATTERNS):
                    result["ambiguous_read"].append(full_name)
                else:
                    result["ambiguous_unknown"].append(full_name)
            elif out_f == d_model:
                result["write"].append(full_name)
            elif in_f == d_model:
                result["read"].append(full_name)
            # else: neither dim matches d_model (e.g. head projections) -- ignore

    return result
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment