Forked from VikashLoomba/requirements_gemma4_additive_moe_reviewed_v10_b300_24h.txt
Created
May 29, 2026 19:45
-
-
Save ichim-david/7aedb44c8122606802094a938a9e36b2 to your computer and use it in GitHub Desktop.
Gemma4 additive MoE reviewed v10 B300 24h training script and requirements
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # Install PyTorch separately for the B300 / Blackwell CUDA stack, for example: | |
| # pip install --pre --upgrade torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu128 | |
| # | |
| # Then install this file: | |
| # pip install --upgrade -r requirements_gemma4_additive_moe_reviewed_v10_b300_24h.txt | |
| # | |
| # FlashAttention 2 is optional. The script defaults to SDPA for maximum | |
| # out-of-the-box compatibility. To use FlashAttention 2 after installing it: | |
| # pass --attn_implementation flash_attention_2. | |
| # To install FlashAttention explicitly: | |
| # pip install flash-attn --no-build-isolation | |
| transformers @ git+https://github.com/huggingface/transformers.git | |
| datasets>=3.0.0 | |
| accelerate>=1.0.0 | |
| safetensors>=0.4.5 | |
| huggingface_hub>=0.30.0 | |
| hf_transfer>=0.1.9 | |
| tqdm>=4.66.0 | |
| sentencepiece>=0.2.0 | |
| protobuf>=4.25.0 | |
| packaging>=24.0 | |
| ninja>=1.11.1 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python3 | |
| """ | |
| End-to-end additive-MoE upcycling for google/gemma-4-31B-it. | |
| What this script does: | |
| 1. Loads the dense google/gemma-4-31B-it checkpoint. | |
| 2. Mutates its Gemma 4 text config to enable the built-in additive MoE block. | |
| 3. Adds fresh routers, expert banks, and MoE-specific norms through the native | |
| Transformers Gemma4 implementation. | |
| 4. Freezes the original dense/multimodal backbone. | |
| 5. Trains only router/expert/MoE-norm parameters on a streamed HF data mixture. | |
| 6. Applies model-specific stabilization so the new full-width additive branch | |
| starts near-zero instead of corrupting the dense model during free generation. | |
| 7. Validates that checkpoint loading only initialized the new MoE keys and that | |
| the dense backbone is frozen. | |
| 8. Saves a full Hugging Face model directory. The default dense-preserving | |
| bridge writes a small custom model file, so inference should load with | |
| trust_remote_code=True. | |
| Designed default target: single NVIDIA B300-class GPU, frozen bf16 backbone, | |
| fp32 trainable MoE parameters, PyTorch SDPA attention by default, sequence | |
| length 4096, 8 experts, top-2 routing, expert intermediate size 512. | |
| Revision v10-b300-24h: keeps the v9 robust streaming mixer, adds explicit | |
| single-B300 wall-clock/token-budget controls, prevents accidental no-checkpoint | |
| OOM relaunches on one GPU, and adds a high-signal 24h data-mix preset. | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import contextlib | |
| import dataclasses | |
| import gc | |
| import json | |
| import math | |
| import os | |
| import random | |
| import re | |
| import signal | |
| import threading | |
| import time | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Any, Iterable, Iterator, Optional | |
| from urllib.parse import quote | |
| from urllib.request import Request, urlopen | |
| import torch | |
| import torch.nn.functional as F | |
| from datasets import interleave_datasets, load_dataset | |
| from huggingface_hub import HfApi | |
| from torch import nn | |
| from torch.utils.data import DataLoader, IterableDataset, get_worker_info | |
| from tqdm.auto import tqdm | |
| from transformers import AutoConfig, AutoProcessor, AutoTokenizer | |
| try: | |
| # Prefer the exact model class because google/gemma-4-31B-it is a multimodal | |
| # Gemma4ForConditionalGeneration checkpoint, not a plain text-only CausalLM. | |
| from transformers import Gemma4ForConditionalGeneration as Gemma4ModelClass | |
| except Exception: # pragma: no cover - older/local Transformers builds | |
| Gemma4ModelClass = None | |
| try: | |
| from transformers import AutoModelForImageTextToText as AutoModelClass | |
| except Exception: # pragma: no cover - fallback for unusual Transformers builds | |
| from transformers import AutoModelForCausalLM as AutoModelClass | |
def resolve_model_class():
    """Choose the class used to load the Gemma 4 checkpoint.

    Returns the dedicated ``Gemma4ForConditionalGeneration`` class when the
    import at module load time succeeded; otherwise returns the auto-model
    fallback class that was imported instead.
    """
    if Gemma4ModelClass is None:
        return AutoModelClass
    return Gemma4ModelClass
| # ----------------------------- | |
| # Dataset mix | |
| # ----------------------------- | |
@dataclass(frozen=True)
class DatasetSpec:
    """One weighted source in the streamed training data mixture.

    ``config`` may be a single builder config name or a tuple of names; see
    ``expand_dataset_specs`` for how a tuple is split into several specs.
    """

    # Hub repo id, or a builder name such as "json"/"parquet" for raw files.
    name: str
    # Split to stream, e.g. "train".
    split: str
    # Relative sampling weight; the expanded mixture is renormalized later.
    weight: float
    # Formatter key that selects the row renderer in ``format_row``.
    kind: str
    # Builder config name, or a tuple of config names sharing ``weight``.
    config: Optional[str | tuple[str, ...]] = dataclasses.field(default=None)
    # Optional explicit loader for repos published as raw JSONL/text/Parquet
    # rather than through a canonical dataset builder.
    loader: Optional[str] = dataclasses.field(default=None)
    # Optional data_files argument paired with ``loader`` (a URL or a split->URL dict).
    data_files: Optional[Any] = dataclasses.field(default=None)
    # Optional Hub revision to pin.
    revision: Optional[str] = dataclasses.field(default=None)
# Default training mixture. Weights are relative: expand_dataset_specs()
# renormalizes them after splitting multi-config entries, so they need not
# sum to 1.
BASE_DATASET_MIX: tuple[DatasetSpec, ...] = (
    # Current public encyclopedia knowledge (structured English Wikipedia).
    DatasetSpec(
        name="wikimedia/structured-wikipedia",
        kind="wikipedia",
        split="train",
        weight=0.18,
        config="enwiki_namespace_0",
    ),
    # Latest clean English Wikipedia dump; the "latest.en" config tracks the
    # monthly refresh.
    DatasetSpec(
        name="omarkamali/wikipedia-monthly",
        kind="wikipedia_monthly",
        split="train",
        weight=0.12,
        config="latest.en",
    ),
    # Broad open-licensed background corpus (code/government/science/legal).
    DatasetSpec(
        name="PleIAs/common_corpus",
        kind="common_corpus",
        split="train",
        weight=0.12,
    ),
    # Permissively licensed function-level code.
    DatasetSpec(
        name="Samip/Scotch",
        kind="scotch",
        split="train",
        weight=0.14,
        config="all",
        loader="hf_parquet",
    ),
    # Open-license GitHub issues/PRs/comments for SWE and code-discussion style.
    DatasetSpec(
        name="common-pile/github_archive_filtered",
        kind="github_archive",
        split="train",
        weight=0.06,
    ),
    # Large Apache-2.0 agentic coding trajectories (mini-swe-agent format).
    DatasetSpec(
        name="AlienKevin/SWE-ZERO-12M-trajectories",
        kind="swe_zero",
        split="train",
        weight=0.24,
    ),
    # Function calling / JSON mode; expand_dataset_specs() splits this weight
    # evenly across the five configs.
    DatasetSpec(
        name="NousResearch/hermes-function-calling-v1",
        kind="hermes_tool",
        split="train",
        weight=0.08,
        config=(
            "func_calling_singleturn",
            "func_calling",
            "glaive_func_calling",
            "json_mode_agentic",
            "json_mode_singleturn",
        ),
    ),
    # On-device function calling, loaded directly from the raw JSONL file.
    DatasetSpec(
        name="json",
        kind="mobile_actions",
        split="train",
        weight=0.04,
        loader="json",
        data_files="https://huggingface.co/datasets/google/mobile-actions/resolve/main/dataset.jsonl",
    ),
    # Decision data for when to call a tool versus answering directly.
    DatasetSpec(
        name="nvidia/When2Call",
        kind="when2call_tool",
        split="train",
        weight=0.04,
        config="train_sft",
    ),
    # Multi-turn code-agent/tool-use traces (verified Tau2 airline env). Loaded
    # via the explicit Parquet artifact: the generic streaming loader can stall
    # while resolving Hub metadata on some training nodes.
    DatasetSpec(
        name="snorkelai/Tau2-Bench-Verified-Airline-With-Code-Agents",
        kind="tau2_tool",
        split="train",
        weight=0.08,
        loader="parquet",
        data_files={
            "train": "https://huggingface.co/datasets/snorkelai/Tau2-Bench-Verified-Airline-With-Code-Agents/resolve/main/data/train-00000-of-00001.parquet"
        },
    ),
)
# "high_signal_24h" preset: same sources as the default mix minus the broad
# common-corpus background data, with the limited 24h token budget shifted
# toward current knowledge, code, SWE-agent trajectories, and tool use. The
# frozen dense backbone already carries broad language capability. Weights are
# relative and renormalized by expand_dataset_specs().
HIGH_SIGNAL_24H_DATASET_MIX: tuple[DatasetSpec, ...] = (
    DatasetSpec(
        name="wikimedia/structured-wikipedia",
        kind="wikipedia",
        split="train",
        weight=0.18,
        config="enwiki_namespace_0",
    ),
    DatasetSpec(
        name="omarkamali/wikipedia-monthly",
        kind="wikipedia_monthly",
        split="train",
        weight=0.10,
        config="latest.en",
    ),
    DatasetSpec(
        name="Samip/Scotch",
        kind="scotch",
        split="train",
        weight=0.15,
        config="all",
        loader="hf_parquet",
    ),
    DatasetSpec(
        name="common-pile/github_archive_filtered",
        kind="github_archive",
        split="train",
        weight=0.05,
    ),
    DatasetSpec(
        name="AlienKevin/SWE-ZERO-12M-trajectories",
        kind="swe_zero",
        split="train",
        weight=0.30,
    ),
    DatasetSpec(
        name="NousResearch/hermes-function-calling-v1",
        kind="hermes_tool",
        split="train",
        weight=0.12,
        config=(
            "func_calling_singleturn",
            "func_calling",
            "glaive_func_calling",
            "json_mode_agentic",
            "json_mode_singleturn",
        ),
    ),
    DatasetSpec(
        name="json",
        kind="mobile_actions",
        split="train",
        weight=0.04,
        loader="json",
        data_files="https://huggingface.co/datasets/google/mobile-actions/resolve/main/dataset.jsonl",
    ),
    DatasetSpec(
        name="nvidia/When2Call",
        kind="when2call_tool",
        split="train",
        weight=0.04,
        config="train_sft",
    ),
    DatasetSpec(
        name="snorkelai/Tau2-Bench-Verified-Airline-With-Code-Agents",
        kind="tau2_tool",
        split="train",
        weight=0.06,
        loader="parquet",
        data_files={
            "train": "https://huggingface.co/datasets/snorkelai/Tau2-Bench-Verified-Airline-With-Code-Agents/resolve/main/data/train-00000-of-00001.parquet"
        },
    ),
)
# Opt-in direct SWE-agent trajectories, appended by build_dataset_mix() only
# when include_nebius_swe_agent is set. Off by default because the dataset card
# carries a model-output licensing notice; enable only if your release policy
# allows it.
OPTIONAL_NEBIUS_SWE_AGENT = DatasetSpec(
    name="nebius/SWE-agent-trajectories",
    kind="nebius_swe_agent",
    split="train",
    weight=0.08,
)
| def _parse_kind_csv(value: str | None) -> set[str]: | |
| if not value: | |
| return set() | |
| return {part.strip() for part in value.split(",") if part.strip()} | |
def build_dataset_mix(args: argparse.Namespace) -> tuple[DatasetSpec, ...]:
    """Assemble the dataset mixture selected by the CLI arguments.

    Reads (all optional on ``args``): ``data_mix_preset`` ("default" or
    "high_signal_24h"), ``include_tau2_tool`` (default True),
    ``include_nebius_swe_agent`` (default False), and the comma-separated
    ``only_dataset_kinds`` / ``exclude_dataset_kinds`` filters.

    Raises:
        ValueError: for an unknown preset, or if the filters leave no datasets.
    """
    preset = getattr(args, "data_mix_preset", "default")
    if preset == "default":
        base_mix = BASE_DATASET_MIX
    elif preset == "high_signal_24h":
        base_mix = HIGH_SIGNAL_24H_DATASET_MIX
    else:
        raise ValueError(f"Unknown data_mix_preset={preset!r}")

    keep_tau2 = getattr(args, "include_tau2_tool", True)
    chosen = [spec for spec in base_mix if keep_tau2 or spec.kind != "tau2_tool"]
    if getattr(args, "include_nebius_swe_agent", False):
        chosen.append(OPTIONAL_NEBIUS_SWE_AGENT)

    wanted = _parse_kind_csv(getattr(args, "only_dataset_kinds", ""))
    unwanted = _parse_kind_csv(getattr(args, "exclude_dataset_kinds", ""))
    # An empty allow-list means "allow every kind".
    chosen = [
        spec
        for spec in chosen
        if (not wanted or spec.kind in wanted) and spec.kind not in unwanted
    ]
    if not chosen:
        raise ValueError("Dataset mix is empty after applying include/exclude filters.")
    return tuple(chosen)
# Maps lowercase source-dataset speaker labels to the four chat roles the
# renderer understands. Built from role -> aliases groups; insertion order is
# human, user, gpt, ai, assistant, model, system, developer, tool,
# observation, environment.
ROLE_MAP = {
    alias: canonical_role
    for canonical_role, aliases in (
        ("user", ("human", "user")),
        ("assistant", ("gpt", "ai", "assistant", "model")),
        ("system", ("system", "developer")),
        ("tool", ("tool", "observation", "environment")),
    )
    for alias in aliases
}
# Column names probed, in priority order, by the generic fallback renderer in
# format_row(). Grouped by the kind of field they usually hold.
TEXT_KEYS_PREFERENCE = (
    # Document/prose bodies.
    "text", "content", "document", "article", "body", "markdown",
    # Prompt/response style fields.
    "prompt", "response", "completion", "answer", "question",
    # Short descriptive fields.
    "abstract", "summary", "description",
    # Agent traces and chat transcripts.
    "trajectory", "trace", "rollout", "messages", "conversations",
    "tool_calls", "tools",
    # SWE-agent outputs.
    "generated_patch", "eval_logs",
    # Code corpora.
    "function", "code", "context", "docstring",
    # Metadata and Tau2-style environment state.
    "metadata", "user_scenario", "policy", "policy_md",
    "db_diff", "db_diff_verbose",
    # Identifying fields, last.
    "repo", "name", "title", "url",
)
| # ----------------------------- | |
| # Generic formatting utilities | |
| # ----------------------------- | |
| def _json_loads_maybe(value: Any) -> Any: | |
| if isinstance(value, str): | |
| s = value.strip() | |
| if (s.startswith("{") and s.endswith("}")) or (s.startswith("[") and s.endswith("]")): | |
| try: | |
| return json.loads(s) | |
| except Exception: | |
| return value | |
| return value | |
| def _json_dumps(value: Any) -> str: | |
| try: | |
| return json.dumps(value, ensure_ascii=False, default=str) | |
| except Exception: | |
| return str(value) | |
def _stringify(value: Any, *, max_chars: int = 80_000) -> str:
    """Render an arbitrary, possibly JSON-encoded, value as plain text.

    JSON-looking strings are decoded first (via ``_json_loads_maybe``).
    ``None`` becomes ``""``. Strings are truncated to ``max_chars``.
    int/float/bool are returned via ``str`` without truncation.

    Lists become newline-joined renderings of their non-empty items. Dicts
    become newline-joined ``"key: value"`` lines, skipping ``None`` values.
    Collection rendering stops once the accumulated text exceeds
    ``max_chars``, and the result is truncated to ``max_chars``. Any other
    type is rendered with ``str`` and truncated.
    """
    value = _json_loads_maybe(value)
    if value is None:
        return ""
    if isinstance(value, str):
        return value[:max_chars]
    if isinstance(value, (int, float, bool)):
        return str(value)
    if isinstance(value, list):
        parts: list[str] = []
        total = 0  # running length of ``parts``; avoids an O(n^2) re-sum per item
        for item in value:
            s = _stringify(item, max_chars=max_chars)
            if s:
                parts.append(s)
                total += len(s)
            if total > max_chars:
                break
        return "\n".join(parts)[:max_chars]
    if isinstance(value, dict):
        parts = []
        total = 0
        for k, v in value.items():
            if v is None:
                continue
            sv = _stringify(v, max_chars=max_chars)
            if sv:
                line = f"{k}: {sv}"
                parts.append(line)
                total += len(line)
            if total > max_chars:
                break
        return "\n".join(parts)[:max_chars]
    return str(value)[:max_chars]
| def _clean_text(text: str) -> str: | |
| text = text.replace("\x00", "") | |
| text = re.sub(r"[ \t]+\n", "\n", text) | |
| text = re.sub(r"\n{5,}", "\n\n\n", text) | |
| return text.strip() | |
def normalize_tool_call(value: Any) -> Optional[dict[str, Any]]:
    """Normalize OpenAI/Gemma-style tool calls for Gemma 4's chat template.

    Gemma 4's official template expects assistant messages to carry
    ``tool_calls=[{"function": {"name": ..., "arguments": ...}}]``. Do not
    stringify tool calls into content; doing so trains the wrong special-token
    protocol for this model.

    Accepts a dict or a JSON string of one. The call may either nest a
    ``function`` dict or keep ``name``/``arguments`` at the top level.
    JSON-string arguments are decoded, and missing/``None`` arguments become
    ``{}``. ``id`` (stringified) and ``type`` are carried over when present.
    Returns None for non-dict input.
    """
    value = _json_loads_maybe(value)
    if not isinstance(value, dict):
        return None

    nested = value.get("function")
    if isinstance(nested, dict):
        name = nested.get("name") or value.get("name") or value.get("tool_name") or "unknown"
        raw_arguments = nested.get("arguments", value.get("arguments", {}))
    else:
        name = value.get("name") or value.get("tool_name") or value.get("function_name") or "unknown"
        raw_arguments = value.get("arguments", value.get("args", value.get("parameters", {})))

    arguments = _json_loads_maybe(raw_arguments)
    call: dict[str, Any] = {
        "function": {
            "name": str(name),
            "arguments": {} if arguments is None else arguments,
        }
    }
    call_id = value.get("id")
    if call_id is not None:
        call["id"] = str(call_id)
    call_type = value.get("type")
    if call_type is not None:
        call["type"] = call_type
    return call
def normalize_tool_calls(value: Any) -> list[dict[str, Any]]:
    """Normalize one tool call or a list of them, dropping entries that are not dicts.

    ``None`` (or a JSON string decoding to ``null``) yields an empty list.
    """
    decoded = _json_loads_maybe(value)
    if decoded is None:
        return []
    candidates = decoded if isinstance(decoded, list) else [decoded]
    normalized = (normalize_tool_call(raw) for raw in candidates)
    return [call for call in normalized if call is not None]
def normalize_tool_response(value: Any) -> dict[str, Any]:
    """Coerce a tool result into a ``{"name": ..., "response": ...}`` dict.

    For dict input:
    - The name comes from ``name``/``tool_name``/``function_name``, defaulting to "unknown".
    - The payload is the first present key among ``response``, ``content``,
      ``observation`` and ``result``, or the whole dict when none is present.
    - A JSON-encoded payload is decoded.

    Any other input is wrapped as-is under the name "unknown".
    """
    value = _json_loads_maybe(value)
    if not isinstance(value, dict):
        return {"name": "unknown", "response": value}
    name = value.get("name") or value.get("tool_name") or value.get("function_name") or "unknown"
    payload = value
    for key in ("response", "content", "observation", "result"):
        if key in value:
            payload = value[key]
            break
    return {"name": str(name), "response": _json_loads_maybe(payload)}
def normalize_tool_schema(value: Any) -> Optional[dict[str, Any]]:
    """Normalize common HF tool schemas to Gemma's function declaration shape.

    Accepts a dict (or its JSON string) in two forms: an OpenAI-style
    ``{"type": "function", "function": {...}}`` wrapper, or a bare function
    dict.

    Returns ``{"type": "function", "function": {"name", "description",
    "parameters"}}``. Parameters are taken from ``parameters``/``schema``/
    ``input_schema`` (JSON strings decoded, default ``{}``). Returns None when
    the input is not a dict or has no name.
    """
    value = _json_loads_maybe(value)
    if not isinstance(value, dict):
        return None
    if value.get("type") == "function" and isinstance(value.get("function"), dict):
        fn = dict(value["function"])
    else:
        fn = dict(value)
    name = fn.get("name") or fn.get("tool_name") or fn.get("function_name")
    if not name:
        return None
    parameters = fn.get("parameters") or fn.get("schema") or fn.get("input_schema") or {}
    # An explicit ``"description": null`` (which can appear e.g. when schemas
    # are decoded from columnar struct data) would otherwise reach the chat
    # template as None; normalize it to an empty string like a missing key.
    description = fn.get("description")
    return {
        "type": "function",
        "function": {
            "name": str(name),
            "description": "" if description is None else description,
            "parameters": _json_loads_maybe(parameters),
        },
    }
def normalize_tools(value: Any) -> list[dict[str, Any]]:
    """Normalize one tool schema or a list of them, skipping entries that cannot be parsed.

    ``None`` (or a JSON string decoding to ``null``) yields an empty list.
    """
    decoded = _json_loads_maybe(value)
    if decoded is None:
        return []
    if not isinstance(decoded, list):
        decoded = [decoded]
    schemas: list[dict[str, Any]] = []
    for raw in decoded:
        schema = normalize_tool_schema(raw)
        if schema is not None:
            schemas.append(schema)
    return schemas
def _content_parts_text(content: Any) -> Optional[str]:
    """Join the text of a text-only content-part list; otherwise return None.

    Handles the ``[{"type": "text", "text": ...}, ...]`` message form, the same
    form ``_gemma4_text_content_messages`` produces. Returns None for strings,
    dicts, empty lists, and lists containing any non-text part. In those cases
    the caller keeps the generic ``_stringify`` rendering.
    """
    content = _json_loads_maybe(content)
    if not isinstance(content, list) or not content:
        return None
    texts: list[str] = []
    for part in content:
        if not (isinstance(part, dict) and part.get("type") == "text" and isinstance(part.get("text"), str)):
            return None
        texts.append(part["text"])
    return "\n".join(texts)


def normalize_messages(value: Any) -> list[dict[str, Any]]:
    """Normalize heterogeneous chat rows while preserving structured tool calls.

    Accepts a list (or JSON string of a list) of message dicts in various
    schemas: ``role``/``from``/``speaker``/``author`` for the speaker, and
    ``content``/``value``/``text``/``message``/``system_prompt``/``observation``
    for the body.

    Roles are mapped through ``ROLE_MAP``. Unknown roles default to system
    (if ``system_prompt`` is present), tool (if ``observation`` is present),
    or user.

    Tool calls, tool responses, ``tool_call_id``, ``name`` and reasoning fields
    are carried over. Messages with no content are dropped unless they are
    tool turns or carry tool calls/responses.
    """
    value = _json_loads_maybe(value)
    if not isinstance(value, list):
        return []
    messages: list[dict[str, Any]] = []
    for item in value:
        item = _json_loads_maybe(item)
        if not isinstance(item, dict):
            continue
        raw_role = item.get("role") or item.get("from") or item.get("speaker") or item.get("author")
        role = ROLE_MAP.get(str(raw_role).lower()) if raw_role is not None else None
        if role is None:
            if "system_prompt" in item:
                role = "system"
            elif "observation" in item:
                role = "tool"
            else:
                role = "user"
        # First body field whose value is not None (None if there is none).
        content = next(
            (
                item[key]
                for key in ("content", "value", "text", "message", "system_prompt", "observation")
                if item.get(key) is not None
            ),
            None,
        )
        # Text-only content-part lists would otherwise be rendered through
        # _stringify's dict path as "type: text\ntext: ..."; join the raw text
        # instead (capped like _stringify's default of 80_000 characters).
        parts_text = _content_parts_text(content)
        if parts_text is not None:
            content_s = _clean_text(parts_text[:80_000])
        else:
            content_s = _clean_text(_stringify(content))
        msg: dict[str, Any] = {"role": role, "content": content_s}
        tool_calls = normalize_tool_calls(item.get("tool_calls") or item.get("tool_call"))
        if tool_calls:
            msg["tool_calls"] = tool_calls
        if item.get("tool_call_id") is not None:
            msg["tool_call_id"] = str(item["tool_call_id"])
        if item.get("name") is not None:
            msg["name"] = str(item["name"])
        if item.get("reasoning") is not None:
            msg["reasoning"] = _stringify(item["reasoning"])
        if item.get("reasoning_content") is not None:
            msg["reasoning_content"] = _stringify(item["reasoning_content"])
        if item.get("tool_responses") is not None:
            responses = _json_loads_maybe(item["tool_responses"])
            raw_responses = responses if isinstance(responses, list) else [responses]
            msg["tool_responses"] = [normalize_tool_response(r) for r in raw_responses]
        # Keep assistant tool-call-only messages and tool responses even when content is empty.
        if msg["content"] or msg.get("tool_calls") or msg.get("tool_responses") or role == "tool":
            messages.append(msg)
    return messages
def _message_with_flattened_tool_calls(message: dict[str, Any]) -> dict[str, str]:
    """Inline a message's tool calls/responses into its text as tagged JSON blocks.

    The result has only ``role`` and ``content``. The content is the original
    text followed, newline-separated, by ``<tool_call>`` and then
    ``<tool_response>`` blocks for whichever fields are non-empty.
    """
    segments = [_stringify(message.get("content", ""))]
    for key, tag in (("tool_calls", "tool_call"), ("tool_responses", "tool_response")):
        if message.get(key):
            segments.append(f"<{tag}>\n{_json_dumps(message[key])}\n</{tag}>")
    body = "\n".join(segment for segment in segments if segment)
    return {"role": str(message.get("role", "user")), "content": _clean_text(body)}
| def _gemma4_text_content_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]: | |
| """Convert text-only chat messages to Gemma 4's multimodal content-list form. | |
| Gemma 4's processor examples use ``content=[{"type": "text", "text": ...}]`` | |
| even for text-only prompts. Try that form first so tool/chat samples follow | |
| the same path as inference, then fall back to plain strings for older templates. | |
| """ | |
| out: list[dict[str, Any]] = [] | |
| for msg in messages: | |
| m = dict(msg) | |
| content = m.get("content", "") | |
| if isinstance(content, str): | |
| m["content"] = [{"type": "text", "text": content}] | |
| out.append(m) | |
| return out | |
def render_messages_with_template(
    chat_renderer: Any,
    messages: list[dict[str, Any]],
    *,
    tools: Optional[Any] = None,
    enable_thinking: bool = False,
) -> str:
    """Render messages using Gemma 4's official chat/function-calling template.

    The preferred path passes structured ``tools`` and message ``tool_calls`` to
    ``apply_chat_template``. Fallbacks only flatten tool fields if the row schema
    is malformed or a downstream tokenizer build rejects the structured form.

    Each attempt in the table below tries content-list messages first, then
    plain-string messages. The first ``apply_chat_template`` call that does not
    raise wins. If every attempt fails, a manual ``<start_of_turn>`` rendering
    is returned.

    Args:
        chat_renderer: a tokenizer or processor exposing ``apply_chat_template``.
        messages: normalized chat messages; empty input returns ``""``.
        tools: raw tool schemas, normalized with ``normalize_tools``.
        enable_thinking: forwarded to templates that accept it.
    """
    if not messages:
        return ""
    norm_tools = normalize_tools(tools)

    def sanitized(convert_tool_to_user: bool, preserve_tool_calls: bool) -> list[dict[str, Any]]:
        # Coerce roles to system/user/assistant/tool. Optionally rewrite tool
        # turns as user turns and/or flatten tool payloads into text.
        out: list[dict[str, Any]] = []
        for raw in messages:
            m = dict(raw)
            role = str(m.get("role", "user"))
            if role == "developer":
                role = "system"
            if role not in {"system", "user", "assistant", "tool"}:
                role = "user"
            if role == "tool" and convert_tool_to_user:
                flat = _message_with_flattened_tool_calls(m)
                name = m.get("name") or m.get("tool_call_id") or "unknown"
                out.append({
                    "role": "user",
                    "content": "<tool_response>\n" + _json_dumps({"name": name, "response": flat["content"]}) + "\n</tool_response>",
                    "tool_calls": [],
                })
                continue
            m["role"] = role
            m["content"] = _stringify(m.get("content", ""))
            if role != "tool":
                # Gemma's current template indexes message['tool_calls'] directly.
                m.setdefault("tool_calls", [])
            if preserve_tool_calls:
                out.append(m)
            else:
                flat = _message_with_flattened_tool_calls(m)
                flat["role"] = role
                out.append(flat)
        return out

    # (convert role:tool -> user, keep structured tool_calls, pass tools=, pass enable_thinking=)
    attempts = (
        (False, True, True, True),    # native Gemma tool declarations + assistant tool_calls
        (False, True, True, False),
        (True, True, True, True),     # convert role:tool if the template rejects it
        (True, True, True, False),
        (True, True, False, False),   # keep assistant tool_calls without tools=
        (True, False, False, False),  # final template attempt: flatten tool payloads
    )
    for convert_tool, preserve_tool_calls, include_tools, include_thinking in attempts:
        base_messages = sanitized(convert_tool, preserve_tool_calls)
        kwargs: dict[str, Any] = {"tokenize": False, "add_generation_prompt": False}
        if include_tools and norm_tools:
            kwargs["tools"] = norm_tools
        if include_thinking:
            kwargs["enable_thinking"] = enable_thinking
        for candidate_messages in (_gemma4_text_content_messages(base_messages), base_messages):
            try:
                return chat_renderer.apply_chat_template(candidate_messages, **kwargs)
            except Exception:
                # TypeError from templates that reject tools=/enable_thinking=/
                # content lists, or any template error: try the next form.
                continue

    # Manual last resort when no template attempt succeeded.
    tok = getattr(chat_renderer, "tokenizer", chat_renderer)
    bos = getattr(tok, "bos_token", None) or ""
    turns: list[str] = []
    if norm_tools:
        turns.append("<start_of_turn>system\nAvailable tools:\n<tools>\n" + _json_dumps(norm_tools) + "\n</tools><end_of_turn>")
    for m in sanitized(True, False):
        turns.append(f"<start_of_turn>{m['role']}\n{m['content']}<end_of_turn>")
    # BOS is prepended directly rather than newline-joined with the turns, so no
    # stray "\n" token precedes the first turn. NOTE(review): assumed to match
    # how Gemma templates place BOS — confirm against the tokenizer's template.
    return bos + "\n".join(turns)
def flatten_wikipedia_sections(sections: Any, *, max_chars: int = 100_000) -> str:
    """Flatten a nested Wikipedia section tree into Markdown-ish plain text.

    Lists are walked in order. A dict node contributes three things:
    - a ``#`` heading (depth-capped at 4) from its string ``name``/``title``/``heading``;
    - its ``text``/``content``/``value``/``paragraph``/``paragraphs`` fields;
    - its children under ``sections``/``subsections``/``children``/``items``/``has_parts``.

    Other values are stringified. Traversal stops adding text once the
    accumulated length exceeds ``max_chars``, and the result is truncated to
    ``max_chars`` and cleaned.
    """
    sections = _json_loads_maybe(sections)
    parts: list[str] = []
    total = 0  # running length of ``parts``; avoids re-summing on every visit

    def emit(text: str) -> None:
        nonlocal total
        parts.append(text)
        total += len(text)

    def visit(x: Any, depth: int = 0) -> None:
        if total > max_chars:
            return
        x = _json_loads_maybe(x)
        if isinstance(x, list):
            for item in x:
                visit(item, depth)
            return
        if isinstance(x, dict):
            heading = x.get("name") or x.get("title") or x.get("heading") or x.get("toclevel")
            if heading and isinstance(heading, str):
                emit("#" * min(depth + 1, 4) + " " + heading)
            for key in ("text", "content", "value", "paragraph", "paragraphs"):
                if key in x and x[key]:
                    s = _stringify(x[key], max_chars=max_chars)
                    if s:
                        emit(s)
            # "has_parts" is the nesting key of Wikimedia structured contents.
            # Without it, paragraph values nested under each section are never
            # reached and only headings survive. NOTE(review): assumed schema of
            # wikimedia/structured-wikipedia — verify against the dataset card.
            for key in ("sections", "subsections", "children", "items", "has_parts"):
                if key in x and x[key]:
                    visit(x[key], depth + 1)
            return
        s = _stringify(x, max_chars=max_chars)
        if s:
            emit(s)

    visit(sections)
    return _clean_text("\n".join(parts)[:max_chars])
def format_row(row: dict[str, Any], kind: str, tokenizer: Any, max_chars: int) -> str:
    """Render one raw dataset row as training text for its source ``kind``.

    Chat/tool/agent rows go through ``render_messages_with_template``, and any
    tool schemas found in the row are passed as ``tools=``. Encyclopedia,
    code and GitHub rows use dataset-specific plain-text layouts. Anything
    else falls back to a generic dump of ``TEXT_KEYS_PREFERENCE`` fields.

    Output is truncated to ``max_chars`` and whitespace-cleaned. Returns ``""``
    for rows that should be skipped: mobile-actions eval rows, GitHub items
    with an empty body, and rows with no usable text.
    """
    # Tool/action/chat schemas first. Preserve Gemma 4-native tool schemas by
    # passing `tools=` into the chat template when present, with an in-band fallback.
    tool_fields = (
        row.get("tools")
        or row.get("tool_schemas")
        or row.get("available_tools")
        or row.get("functions")
    )

    def render_chat(messages: list[dict[str, Any]]) -> str:
        # Shared chat path: template render (tool-aware), truncate, tidy.
        text = render_messages_with_template(tokenizer, messages, tools=tool_fields, enable_thinking=False)
        return _clean_text(text[:max_chars])

    if kind == "mobile_actions":
        # Skip rows tagged as the eval split.
        metadata = _json_loads_maybe(row.get("metadata", "train"))
        if isinstance(metadata, str) and metadata.strip().lower() == "eval":
            return ""
        if isinstance(metadata, dict):
            split = str(metadata.get("split") or metadata.get("subset") or metadata.get("partition") or "train").lower()
            if split == "eval":
                return ""
    if kind in {"swe_zero", "mobile_actions", "when2call_tool"} and row.get("messages") is not None:
        return render_chat(normalize_messages(row.get("messages")))
    if kind == "hermes_tool" and row.get("conversations") is not None:
        return render_chat(normalize_messages(row.get("conversations")))
    if kind == "nebius_swe_agent" and row.get("trajectory") is not None:
        messages = normalize_messages(row.get("trajectory"))
        text = render_messages_with_template(tokenizer, messages, tools=tool_fields, enable_thinking=False)
        if row.get("generated_patch"):
            text += "\n\nFinal generated patch:\n```diff\n" + str(row["generated_patch"])[:30_000] + "\n```"
        if row.get("eval_logs"):
            text += "\n\nEvaluation logs:\n```text\n" + str(row["eval_logs"])[:20_000] + "\n```"
        return _clean_text(text[:max_chars])
    if kind == "tau2_tool":
        for key in ("messages", "conversation", "conversations", "trace", "trajectory", "rollout"):
            if row.get(key) is not None:
                messages = normalize_messages(row.get(key))
                if messages:
                    return render_chat(messages)
        # Tau2 variants sometimes store trace/tool state in wide JSON columns.
        parts = [
            f"{key}:\n{_stringify(row[key], max_chars=max_chars)}"
            for key in (
                "scenario",
                "instruction",
                "user_goal",
                "policy",
                "tools",
                "tool_schemas",
                "available_tools",
                "trace",
                "trajectory",
                "transcript",
                "reward_info",
            )
            if key in row and row[key] is not None
        ]
        if parts:
            return _clean_text("\n\n".join(parts)[:max_chars])
    for key in ("messages", "conversations"):
        if row.get(key) is not None:
            messages = normalize_messages(row.get(key))
            if messages:
                return render_chat(messages)

    # Dataset-specific raw text/code renderers.
    if kind == "wikipedia":
        title = row.get("name") or row.get("title") or ""
        url = row.get("url") or ""
        sections = flatten_wikipedia_sections(row.get("sections"), max_chars=max_chars)
        if sections:
            return _clean_text(f"Wikipedia article: {title}\nURL: {url}\n\n{sections}"[:max_chars])
    if kind == "wikipedia_monthly":
        title = row.get("title") or row.get("name") or ""
        url = row.get("url") or row.get("source") or ""
        body = row.get("text") or row.get("content") or row.get("markdown") or row.get("article")
        body_s = _stringify(body, max_chars=max_chars)
        if body_s:
            return _clean_text(f"Wikipedia article: {title}\nURL: {url}\n\n{body_s}"[:max_chars])
    # Only use the Scotch layout when a field it actually renders has content.
    # Rows whose code lives only in ``code``/``text`` previously produced a
    # metadata-only header. They now fall through to the generic renderer,
    # which picks those fields up.
    if kind == "scotch" and any(row.get(k) for k in ("docstring", "context", "function")):
        parts = [
            f"Repository: {row.get('repository_name', '')}",
            f"Path: {row.get('function_path', '')}",
            f"Language: {row.get('language', '')}",
            f"License: {row.get('license', '')}",
        ]
        if row.get("docstring"):
            parts.append("Docstring:\n" + str(row.get("docstring")))
        if row.get("context"):
            parts.append("Code context:\n```\n" + str(row.get("context"))[:40_000] + "\n```")
        if row.get("function"):
            parts.append("Function:\n```\n" + str(row.get("function"))[:40_000] + "\n```")
        return _clean_text("\n\n".join(parts)[:max_chars])
    if kind == "github_archive":
        metadata = _json_loads_maybe(row.get("metadata"))
        metadata = metadata if isinstance(metadata, dict) else {}
        body = row.get("text") or row.get("body") or row.get("content") or ""
        if not str(body).strip():
            # No body: the item would be metadata-only. Return "", the same
            # "skip" value the mobile-actions eval filter uses.
            return ""
        parts = ["GitHub Archive item"]
        source = metadata.get("source") or row.get("source")
        repo = metadata.get("repo") or metadata.get("repository") or row.get("repo") or row.get("repository")
        url = metadata.get("url") or row.get("url")
        license_name = metadata.get("license") or row.get("license")
        created = row.get("created") or row.get("created_at")
        added = row.get("added")
        for label, value in (
            ("Source", source),
            ("Repository", repo),
            ("URL", url),
            ("License", license_name),
            ("Created", created),
            ("Added", added),
        ):
            if value:
                parts.append(f"{label}: {value}")
        parts.append("\n" + str(body)[:max_chars])
        text = _clean_text("\n".join(parts)[:max_chars])
        if len(text) >= 32:
            return text

    # Generic fallback: prefer a single long text field, then concatenate useful fields.
    for key in TEXT_KEYS_PREFERENCE:
        if key in row and isinstance(row[key], str) and len(row[key].strip()) > 200:
            return _clean_text(row[key][:max_chars])
    parts = []
    total = 0  # running length of ``parts``; avoids an O(n^2) re-sum per key
    for key in TEXT_KEYS_PREFERENCE:
        if key in row and row[key] is not None:
            s = _stringify(row[key], max_chars=max_chars)
            if s:
                entry = f"{key}:\n{s}"
                parts.append(entry)
                total += len(entry)
            if total > max_chars:
                break
    return _clean_text("\n\n".join(parts)[:max_chars])
def expand_dataset_specs(specs: Iterable[DatasetSpec]) -> list[DatasetSpec]:
    """Flatten multi-config specs and normalize the mix weights to sum to 1.

    A spec whose ``config`` is a tuple is split into one spec per config, each
    receiving an equal share of the parent's weight. All resulting specs are then
    rescaled by the total weight so the returned weights sum to 1. An empty input
    yields an empty list.

    Raises:
        ValueError: if a spec has an empty config tuple (its weight cannot be
            split), or if the total weight of a non-empty expansion is not
            positive (normalization would divide by zero or flip signs).
    """
    expanded: list[DatasetSpec] = []
    for spec in specs:
        if not isinstance(spec.config, tuple):
            expanded.append(spec)
            continue
        if not spec.config:
            raise ValueError(f"DatasetSpec {spec.name!r} has an empty config tuple; cannot split its weight.")
        per_weight = spec.weight / len(spec.config)
        expanded.extend(dataclasses.replace(spec, config=cfg, weight=per_weight) for cfg in spec.config)
    if not expanded:
        return expanded
    total = sum(s.weight for s in expanded)
    if total <= 0:
        raise ValueError(f"Dataset mix weights must sum to a positive value; got total={total!r}.")
    return [dataclasses.replace(s, weight=s.weight / total) for s in expanded]
| def _is_dataset_script_unsupported_error(exc: BaseException) -> bool: | |
| text = str(exc) | |
| return "Dataset scripts are no longer supported" in text or "trust_remote_code is not supported anymore" in text | |
| @contextlib.contextmanager | |
| def _dataset_load_timeout(seconds: int | float, label: str): | |
| """Interrupt dataset metadata construction that otherwise appears to hang. | |
| This guard wraps only load_dataset construction, not streaming iteration. It | |
| is active on the POSIX main-thread path used by the B300 training run. | |
| """ | |
| seconds = int(seconds or 0) | |
| if seconds <= 0 or threading.current_thread() is not threading.main_thread() or not hasattr(signal, "SIGALRM"): | |
| yield | |
| return | |
| previous_handler = signal.getsignal(signal.SIGALRM) | |
| def _handler(signum, frame): # noqa: ARG001 | |
| raise TimeoutError(f"Timed out after {seconds}s while loading dataset metadata for {label}.") | |
| signal.signal(signal.SIGALRM, _handler) | |
| signal.setitimer(signal.ITIMER_REAL, float(seconds)) | |
| try: | |
| yield | |
| finally: | |
| signal.setitimer(signal.ITIMER_REAL, 0.0) | |
| signal.signal(signal.SIGALRM, previous_handler) | |
def _load_dataset_with_timeout(args: argparse.Namespace, label: str, *load_args, **load_kwargs):
    """Call ``load_dataset`` under the SIGALRM guard and print how long it took.

    The guard's timeout comes from ``args.dataset_load_timeout_sec`` (treated as
    0, i.e. disabled, when the attribute is absent). All positional and keyword
    arguments are forwarded unchanged to ``load_dataset``.
    """
    began = time.monotonic()
    guard = _dataset_load_timeout(getattr(args, "dataset_load_timeout_sec", 0), label)
    with guard:
        dataset = load_dataset(*load_args, **load_kwargs)
    elapsed = time.monotonic() - began
    print(f"[data] loaded {label} in {elapsed:.1f}s")
    return dataset
def _candidate_parquet_revisions(spec: DatasetSpec) -> list[str]:
    """List git revisions to search for converted Parquet files, in priority order.

    The spec's own revision (when set) comes first, followed by the
    dataset-viewer conversion branch and its alias. Duplicates are removed while
    keeping first-occurrence order.
    """
    # HF dataset-viewer converted Parquet files are published on refs/convert/parquet.
    # Some tooling also exposes the same branch as ~parquet, so keep it as a fallback.
    candidates = ([spec.revision] if spec.revision else []) + ["refs/convert/parquet", "~parquet"]
    unique: list[str] = []
    for candidate in candidates:
        if candidate in unique:
            continue
        unique.append(candidate)
    return unique
def _match_parquet_files(files: list[str], spec: DatasetSpec) -> list[str]:
    """Pick the ``.parquet`` paths in *files* that belong to the spec's config/split.

    Matching tiers, tried in order (the first tier with any hit wins):
      1. path prefixes: config-qualified ones (when a config is set), then split-only;
      2. a ``/config/split/`` path-segment match (when a config is set);
      3. a ``/split/`` path-segment match.
    Returns the winning tier's paths sorted, or an empty list.
    """
    split = str(spec.split)
    config = None if spec.config is None else str(spec.config)
    parquet_paths = [path for path in files if path.endswith(".parquet")]

    prefixes: list[str] = []
    if config:
        prefixes += [f"{config}/{split}/", f"{config}/{split}-", f"{config}/data/{split}-"]
    prefixes += [f"{split}/", f"{split}-", f"data/{split}-"]
    by_prefix = [path for path in parquet_paths if path.startswith(tuple(prefixes))]
    if by_prefix:
        return sorted(by_prefix)

    # Last-resort shape match for repos whose converted files use extra nesting.
    # Wrapping each path in slashes lets a needle also match at the very start/end.
    # NOTE(review): when a config is set but tier 2 finds nothing, the split-only
    # needle can also match files that belong to other configs; confirm intended.
    needles = ([f"/{config}/{split}/"] if config else []) + [f"/{split}/"]
    for needle in needles:
        nested = [path for path in parquet_paths if needle in f"/{path}/"]
        if nested:
            return sorted(nested)
    return []
def _hf_parquet_urls_from_dataset_server(spec: DatasetSpec, args: argparse.Namespace) -> tuple[list[str], str] | None:
    """Return auto-converted Parquet URLs from the HF Dataset Viewer API when available.

    Queries ``/parquet?dataset=<name>`` (with a bearer token when
    ``args.hf_token`` is set). It keeps entries whose ``split`` equals
    ``spec.split`` and, when ``spec.config`` is set, whose ``config`` equals it.

    Returns:
        ``(sorted_urls, "dataset-viewer/parquet")`` on success. Returns ``None``
        when the request or JSON decoding fails, the response is not the
        expected JSON object, or no file matches. This is best-effort by design:
        the caller falls back to listing repo branches. Request failures are
        printed so the fallback is diagnosable.
    """
    api_url = f"https://datasets-server.huggingface.co/parquet?dataset={quote(spec.name, safe='')}"
    req = Request(api_url)
    if args.hf_token:
        req.add_header("Authorization", f"Bearer {args.hf_token}")
    try:
        with urlopen(req, timeout=60) as resp:
            payload = json.loads(resp.read().decode("utf-8"))
    except Exception as exc:  # network/HTTP/JSON errors: caller falls back to repo branches
        print(
            f"[data] dataset-viewer parquet lookup failed for {spec.name!r}: {type(exc).__name__}: {exc}",
            flush=True,
        )
        return None
    # A JSON array/string/null would make .get() raise AttributeError outside the try.
    if not isinstance(payload, dict):
        return None
    parquet_files = payload.get("parquet_files") or []
    if not isinstance(parquet_files, list):
        return None
    wanted_config = str(spec.config) if spec.config is not None else None
    wanted_split = str(spec.split)
    # NOTE(review): with spec.config=None, files from every config of the dataset
    # are accepted for the split; confirm that mixing configs is acceptable.
    urls: list[str] = []
    for item in parquet_files:
        if not isinstance(item, dict):
            continue
        if str(item.get("split")) != wanted_split:
            continue
        if wanted_config is not None and str(item.get("config")) != wanted_config:
            continue
        url = item.get("url")
        if isinstance(url, str) and url.endswith(".parquet"):
            urls.append(url)
    if not urls:
        return None
    return sorted(urls), "dataset-viewer/parquet"
def _hf_parquet_urls(spec: DatasetSpec, args: argparse.Namespace) -> tuple[list[str], str]:
    """Resolve converted Parquet file URLs for *spec*.

    The Dataset Viewer API is asked first. If it yields nothing, each candidate
    revision is listed via ``HfApi.list_repo_files``. The function returns
    ``resolve`` URLs for the first revision whose files match the spec's
    config/split.

    Returns:
        ``(urls, source)``, where *source* is the viewer marker or the revision used.

    Raises:
        RuntimeError: no attempt produced matching files. The message lists every
            attempt together with why it failed.
    """
    from_viewer = _hf_parquet_urls_from_dataset_server(spec, args)
    if from_viewer is not None:
        return from_viewer
    api = HfApi(token=args.hf_token or None)
    attempts: list[str] = ["dataset-viewer/parquet: unavailable or no matching files"]
    for revision in _candidate_parquet_revisions(spec):
        try:
            repo_files = api.list_repo_files(repo_id=spec.name, repo_type="dataset", revision=revision)
        except Exception as exc:  # network/auth/revision absence; try next known parquet branch
            attempts.append(f"{revision}: {type(exc).__name__}: {exc}")
            continue
        matched = _match_parquet_files(repo_files, spec)
        if not matched:
            attempts.append(f"{revision}: no parquet files matched config={spec.config!r} split={spec.split!r}")
            continue
        base_url = f"https://huggingface.co/datasets/{spec.name}/resolve/{quote(revision, safe='')}/"
        return [base_url + quote(path, safe="/") for path in matched], revision
    raise RuntimeError(
        f"Could not find converted Parquet files for dataset {spec.name!r} "
        f"config={spec.config!r} split={spec.split!r}. Tried: " + " | ".join(attempts)
    )
def _load_spec_dataset(spec: DatasetSpec, args: argparse.Namespace):
    """Open *spec* as a streaming dataset, picking the loader from ``spec.loader``.

    - ``"hf_parquet"``: stream the Hub's auto-converted Parquet files.
    - ``"parquet"``: stream the explicit ``spec.data_files`` (which must be set).
    - anything else: ``load_dataset(spec.loader or spec.name[, spec.config], ...)``
      with token/data_files/revision when provided. That call may raise
      TimeoutError, or a RuntimeError reporting unsupported dataset scripts. In
      that case, when ``args.allow_hf_parquet_fallback`` is truthy, the spec is
      retried against the converted Parquet files.

    Raises:
        ValueError: ``loader="parquet"`` without ``data_files``.
        RuntimeError / TimeoutError: primary-loader failures that are not
            eligible for, or not allowed to use, the Parquet fallback.
    """
    where = f"{spec.name} config={spec.config!r} split={spec.split!r}"
    label = f"{where} loader={spec.loader!r}"

    def parquet_load(load_label, data_files):
        # Shared by every path that streams through the generic "parquet" builder.
        kwargs = dict(split=spec.split, streaming=True, data_files=data_files)
        if args.hf_token:
            kwargs["token"] = args.hf_token
        return _load_dataset_with_timeout(args, load_label, "parquet", **kwargs)

    if spec.loader == "hf_parquet":
        urls, revision = _hf_parquet_urls(spec, args)
        print(f"[data] {where}: loading {len(urls)} converted parquet file(s) from {revision}")
        return parquet_load(label, {spec.split: urls})

    if spec.loader == "parquet":
        if spec.data_files is None:
            raise ValueError(f"DatasetSpec {spec.name!r} uses loader='parquet' but does not define data_files.")
        print(f"[data] {where}: loading explicit parquet data_files")
        return parquet_load(label, spec.data_files)

    primary_kwargs = dict(split=spec.split, streaming=True)
    if args.hf_token:
        primary_kwargs["token"] = args.hf_token
    if spec.data_files is not None:
        primary_kwargs["data_files"] = spec.data_files
    if spec.revision:
        primary_kwargs["revision"] = spec.revision
    positional = (spec.loader or spec.name,) + ((spec.config,) if spec.config else ())
    try:
        return _load_dataset_with_timeout(args, label, *positional, **primary_kwargs)
    except (RuntimeError, TimeoutError) as exc:
        eligible = isinstance(exc, TimeoutError) or _is_dataset_script_unsupported_error(exc)
        if not (eligible and args.allow_hf_parquet_fallback):
            raise
        urls, revision = _hf_parquet_urls(spec, args)
        print(
            f"[data] {where}: primary loader failed/stalled "
            f"({type(exc).__name__}: {exc}); falling back to {len(urls)} converted parquet file(s) from {revision}"
        )
        return parquet_load(f"{where} loader='hf_parquet_fallback'", {spec.split: urls})
class FormattedTextStream:
    """Lazy row formatter that yields only {"text": str} records.

    Keeping this as a small Python iterable avoids Hugging Face IterableDataset
    feature inference for map/filter/select_columns. That feature inference is
    the source of the StopIteration crash seen before interleave construction
    when any lazily filtered source has no early usable examples.

    Each ``iter()`` call re-iterates ``raw_stream`` from its start. It formats
    each row with ``format_row(row, kind, tokenizer, max_chars)`` and yields
    ``{"text": text}`` only when the result is a string of at least
    ``min_chars`` characters.

    A formatter exception is treated as an empty row; the first
    ``max_format_warnings`` exceptions of a pass are printed. While no usable row
    has been found yet, a progress line is printed at most once per
    ``startup_heartbeat_sec`` (disabled when <= 0).
    """

    # Formatter exceptions printed per pass; later ones are only counted.
    max_format_warnings = 3

    def __init__(
        self,
        raw_stream: Any,
        *,
        kind: str,
        tokenizer: Any,
        max_chars: int,
        min_chars: int,
        label: str,
        startup_heartbeat_sec: float = 30.0,
    ):
        self.raw_stream = raw_stream
        self.kind = kind
        self.tokenizer = tokenizer
        self.max_chars = int(max_chars)
        self.min_chars = int(min_chars)
        self.label = label
        self.startup_heartbeat_sec = float(startup_heartbeat_sec)

    def __iter__(self) -> Iterator[dict[str, str]]:
        rows_seen = 0
        rows_yielded = 0
        rows_short_or_empty = 0
        format_errors = 0
        start = time.monotonic()
        last_log = start
        for row in self.raw_stream:
            rows_seen += 1
            try:
                text = format_row(row, self.kind, self.tokenizer, self.max_chars)
            except Exception as exc:  # one malformed row must not kill the whole stream
                format_errors += 1
                # Gate on the failure count, not the row index, so errors that
                # begin after the first few rows are still reported.
                if format_errors <= self.max_format_warnings:
                    print(
                        f"[data] formatter warning for {self.label}: row={rows_seen} "
                        f"raised {type(exc).__name__}: {exc}",
                        flush=True,
                    )
                text = ""
            if isinstance(text, str) and len(text) >= self.min_chars:
                rows_yielded += 1
                if rows_yielded == 1:
                    print(
                        f"[data] {self.label}: first formatted text row after raw_rows={rows_seen}; "
                        f"chars={len(text)}",
                        flush=True,
                    )
                yield {"text": text}
                continue
            rows_short_or_empty += 1
            now = time.monotonic()
            if rows_yielded == 0 and self.startup_heartbeat_sec > 0 and now - last_log >= self.startup_heartbeat_sec:
                last_log = now
                print(
                    f"[data] {self.label}: still scanning for first usable text row; "
                    f"raw_rows={rows_seen} short_or_empty={rows_short_or_empty} elapsed={now - start:.1f}s",
                    flush=True,
                )
        if rows_yielded == 0:
            print(
                f"[data] {self.label}: exhausted without yielding usable text; "
                f"raw_rows={rows_seen} short_or_empty={rows_short_or_empty} format_errors={format_errors}",
                flush=True,
            )
class WeightedInterleavedTextStream:
    """Weighted lazy interleaver for same-schema {"text": str} streams.

    This intentionally avoids datasets.interleave_datasets(), which calls
    _resolve_features() and may StopIteration-peek each lazy stream before
    training even starts.

    On every draw, a still-active source is chosen with probability
    proportional to its weight (negative weights count as 0). If all active
    weights are 0, the choice is uniform. Iteration is deterministic for a
    given ``seed``.

    When a source runs dry, ``stopping_strategy == "first_exhausted"`` ends the
    whole stream. Any other value drops that source and continues until every
    source is exhausted. NOTE(review): exhausted sources are never restarted
    (no oversampling).
    """

    def __init__(
        self,
        streams: list[Any],
        probabilities: list[float],
        labels: list[str],
        *,
        seed: int,
        stopping_strategy: str,
        startup_heartbeat_sec: float = 30.0,
    ):
        if not (len(streams) == len(probabilities) == len(labels)):
            raise ValueError("streams, probabilities, and labels must have the same length")
        self.streams = streams
        self.probabilities = probabilities
        self.labels = labels
        self.seed = int(seed)
        self.stopping_strategy = stopping_strategy
        self.startup_heartbeat_sec = float(startup_heartbeat_sec)

    def _pick(self, rng: random.Random, active: list[int]) -> int:
        """Return a position in *active*, drawn in proportion to the clamped weights."""
        weights = [max(float(self.probabilities[idx]), 0.0) for idx in active]
        total = sum(weights)
        if total <= 0:
            return rng.randrange(len(active))
        threshold = rng.random() * total
        running = 0.0
        for position, weight in enumerate(weights):
            running += weight
            if threshold <= running:
                return position
        # Floating-point rounding can leave threshold just above the final sum.
        return len(active) - 1

    def __iter__(self) -> Iterator[dict[str, str]]:
        rng = random.Random(self.seed)
        iterators = [iter(stream) for stream in self.streams]
        active = list(range(len(iterators)))
        produced = 0
        dropped = 0
        t0 = time.monotonic()
        last_heartbeat = t0
        print(
            f"[data] custom weighted interleaver started with {len(active)} active source(s); "
            f"stopping_strategy={self.stopping_strategy}",
            flush=True,
        )
        while active:
            now = time.monotonic()
            heartbeat_due = (
                produced == 0
                and self.startup_heartbeat_sec > 0
                and now - last_heartbeat >= self.startup_heartbeat_sec
            )
            if heartbeat_due:
                last_heartbeat = now
                print(
                    f"[data] custom interleaver waiting for first yielded row; "
                    f"active_sources={len(active)} elapsed={now - t0:.1f}s",
                    flush=True,
                )
            position = self._pick(rng, active)
            source_idx = active[position]
            try:
                row = next(iterators[source_idx])
            except StopIteration:
                dropped += 1
                print(
                    f"[data] custom interleaver source exhausted: index={source_idx} "
                    f"label={self.labels[source_idx]!r}; yielded_rows={produced}; "
                    f"remaining_active_before_drop={len(active)}",
                    flush=True,
                )
                if self.stopping_strategy == "first_exhausted":
                    return
                del active[position]
                continue
            produced += 1
            if produced == 1:
                print(
                    f"[data] custom interleaver yielded first row from source index={source_idx} "
                    f"label={self.labels[source_idx]!r}",
                    flush=True,
                )
            yield row
        print(
            f"[data] custom interleaver exhausted all sources; yielded_rows={produced} "
            f"exhausted_sources={dropped}",
            flush=True,
        )
class ShuffleBufferTextStream:
    """Small approximate shuffle buffer for arbitrary Python iterables.

    The first ``buffer_size`` rows fill a pool. After that, every incoming row
    replaces a uniformly chosen pooled row, and the displaced row is emitted.
    When the input ends, the remaining pool is shuffled and flushed.

    With ``buffer_size <= 1``, rows pass through unchanged. Output order is
    deterministic for a given ``seed`` and input order. While the pool fills,
    a progress line is printed at most once per ``startup_heartbeat_sec``
    (disabled when <= 0).
    """

    def __init__(self, stream: Any, *, buffer_size: int, seed: int, startup_heartbeat_sec: float = 30.0):
        self.stream = stream
        self.buffer_size = int(buffer_size)
        self.seed = int(seed)
        self.startup_heartbeat_sec = float(startup_heartbeat_sec)

    def __iter__(self) -> Iterator[dict[str, str]]:
        if self.buffer_size <= 1:
            yield from self.stream
            return
        rng = random.Random(self.seed)
        pool: list[dict[str, str]] = []
        seen = 0
        emitted = 0
        t0 = time.monotonic()
        last_heartbeat = t0
        print(f"[data] custom global shuffle started; buffer={self.buffer_size}", flush=True)
        for incoming in self.stream:
            seen += 1
            if len(pool) >= self.buffer_size:
                slot = rng.randrange(len(pool))
                outgoing, pool[slot] = pool[slot], incoming
                emitted += 1
                if emitted == 1:
                    print(
                        f"[data] custom global shuffle yielded first row after filling buffer={self.buffer_size}",
                        flush=True,
                    )
                yield outgoing
                continue
            pool.append(incoming)
            now = time.monotonic()
            heartbeat_due = (
                emitted == 0
                and self.startup_heartbeat_sec > 0
                and now - last_heartbeat >= self.startup_heartbeat_sec
            )
            if heartbeat_due:
                last_heartbeat = now
                print(
                    f"[data] custom global shuffle filling initial buffer; "
                    f"buffered={len(pool)}/{self.buffer_size} input_rows={seen} elapsed={now - t0:.1f}s",
                    flush=True,
                )
        rng.shuffle(pool)
        for outgoing in pool:
            emitted += 1
            yield outgoing
        print(
            f"[data] custom global shuffle exhausted; input_rows={seen} output_rows={emitted}",
            flush=True,
        )
def probe_source_for_text(raw_stream: Any, spec: DatasetSpec, args: argparse.Namespace, tokenizer: Any, label: str) -> bool:
    """Return True if the formatter can produce at least N usable text rows early.

    N is ``args.source_probe_examples``; when it is 0 or unset the probe is
    skipped and True is returned. The scan is bounded by
    ``args.source_probe_max_scan_rows`` (0/unset means 512, and never fewer
    than N) so one malformed large source cannot hang the job during
    validation. A row is usable when ``format_row`` returns a string of at
    least ``args.min_chars_per_sample`` characters; formatter exceptions count
    as unusable rows, and the first few are printed.

    The raw HF streaming datasets are re-iterable, so this diagnostic pass does
    not consume the training stream used later. NOTE(review): a one-shot
    iterator (e.g. a generator) passed here would lose the rows scanned.
    """
    needed = int(getattr(args, "source_probe_examples", 0) or 0)
    if needed <= 0:
        return True
    max_scan = max(needed, int(getattr(args, "source_probe_max_scan_rows", 512) or 512))
    max_warnings = 3
    found = 0
    raw_seen = 0
    format_errors = 0
    first_preview = None
    start = time.monotonic()
    for raw_seen, row in enumerate(raw_stream, start=1):
        try:
            text = format_row(row, spec.kind, tokenizer, args.max_chars_per_sample)
        except Exception as exc:  # count and keep scanning; the verdict is based on usable rows
            format_errors += 1
            # Gate on the failure count, not the row index, so errors that begin
            # after the first few rows are still reported.
            if format_errors <= max_warnings:
                print(
                    f"[data] source probe formatter warning for {label}: row={raw_seen} "
                    f"raised {type(exc).__name__}: {exc}",
                    flush=True,
                )
            text = ""
        if isinstance(text, str) and len(text) >= args.min_chars_per_sample:
            found += 1
            if first_preview is None:
                first_preview = re.sub(r"\s+", " ", text[:220]).strip()
            if found >= needed:
                print(
                    f"[data] source probe ok for {label}: found={found}/{needed} usable row(s) "
                    f"within raw_rows={raw_seen} in {time.monotonic() - start:.1f}s; "
                    f"first_preview={first_preview!r}",
                    flush=True,
                )
                return True
        if raw_seen >= max_scan:
            break
    print(
        f"[data] source probe failed for {label}: found={found}/{needed} usable row(s) "
        f"within raw_rows={raw_seen} max_scan={max_scan} format_errors={format_errors} "
        f"in {time.monotonic() - start:.1f}s",
        flush=True,
    )
    return False
def load_streaming_mix(args: argparse.Namespace, tokenizer: Any):
    """Build the lazy, weighted, optionally shuffled text stream used for training.

    Each expanded spec goes through these steps:
      1. Open it with ``_load_spec_dataset``.
      2. Run the bounded ``probe_source_for_text`` check.
      3. Wrap it in ``FormattedTextStream``.
      4. For shuffle_scope "per_source"/"both" with a positive buffer, add a
         per-source ``ShuffleBufferTextStream`` seeded with ``args.seed + index``.

    A source that fails the probe is dropped when
    ``args.drop_sources_without_probe_examples`` is truthy (True when unset);
    otherwise a RuntimeError is raised. Surviving weights are renormalized and
    combined with ``WeightedInterleavedTextStream``. For shuffle_scope
    "global"/"both", one global shuffle buffer follows.

    Returns:
        The lazy mixed iterable; rows are pulled only when it is iterated.

    Raises:
        RuntimeError: a source fails the probe while dropping is disabled, or no
            source survives.
    """
    specs = expand_dataset_specs(build_dataset_mix(args))
    kept_streams: list[Any] = []
    kept_weights: list[float] = []
    kept_labels: list[str] = []
    for position, spec in enumerate(specs):
        tag = f"source {position + 1}/{len(specs)}"
        label = f"{tag} {spec.name} config={spec.config!r} kind={spec.kind}"
        print(
            f"[data] loading {tag}: "
            f"name={spec.name!r} config={spec.config!r} split={spec.split!r} weight={spec.weight:.4f} kind={spec.kind}"
        )
        opened_at = time.monotonic()
        raw_ds = _load_spec_dataset(spec, args)
        print(f"[data] {tag} constructed after {time.monotonic() - opened_at:.1f}s; validating formatter", flush=True)

        if not probe_source_for_text(raw_ds, spec, args, tokenizer, label):
            msg = (
                f"[data] {tag} produced no usable text during bounded probe: "
                f"name={spec.name!r} config={spec.config!r} kind={spec.kind!r}."
            )
            if not getattr(args, "drop_sources_without_probe_examples", True):
                raise RuntimeError(msg + " Use --drop_sources_without_probe_examples or fix the formatter/source.")
            print(msg + " Dropping this source and renormalizing the mix.", flush=True)
            continue

        print(f"[data] {tag} applying Python text formatter", flush=True)
        stream = FormattedTextStream(
            raw_ds,
            kind=spec.kind,
            tokenizer=tokenizer,
            max_chars=args.max_chars_per_sample,
            min_chars=args.min_chars_per_sample,
            label=label,
            startup_heartbeat_sec=args.data_startup_heartbeat_sec,
        )
        if args.shuffle_buffer_size > 0 and args.shuffle_scope in {"per_source", "both"}:
            print(
                f"[data] {tag} enabling per-source custom streaming shuffle buffer={args.shuffle_buffer_size}",
                flush=True,
            )
            stream = ShuffleBufferTextStream(
                stream,
                buffer_size=args.shuffle_buffer_size,
                seed=args.seed + position,
                startup_heartbeat_sec=args.data_startup_heartbeat_sec,
            )
        print(f"[data] {tag} ready", flush=True)
        kept_streams.append(stream)
        kept_weights.append(spec.weight)
        kept_labels.append(label)

    if not kept_streams:
        raise RuntimeError("No dataset sources were loaded; every source failed the bounded text probe or was excluded.")
    # Normalize exactly after any diagnostic source drops.
    weight_sum = sum(kept_weights)
    normalized = [weight / weight_sum for weight in kept_weights]
    rounded = json.dumps([round(weight, 4) for weight in normalized])
    print(
        f"[data] all source streams ready; constructing lazy mixed stream for {len(kept_streams)} source(s); "
        f"normalized weights={rounded}",
        flush=True,
    )
    mixed = WeightedInterleavedTextStream(
        kept_streams,
        normalized,
        kept_labels,
        seed=args.seed,
        stopping_strategy=args.interleave_stopping_strategy,
        startup_heartbeat_sec=args.data_startup_heartbeat_sec,
    )
    if args.shuffle_buffer_size > 0 and args.shuffle_scope in {"global", "both"}:
        print(
            f"[data] applying one custom global streaming shuffle after interleave; buffer={args.shuffle_buffer_size}",
            flush=True,
        )
        mixed = ShuffleBufferTextStream(
            mixed,
            buffer_size=args.shuffle_buffer_size,
            seed=args.seed,
            startup_heartbeat_sec=args.data_startup_heartbeat_sec,
        )
    elif args.shuffle_buffer_size <= 0 or args.shuffle_scope == "none":
        print("[data] streaming shuffle disabled", flush=True)
    print(
        "[data] lazy mixed stream constructed; actual examples are pulled only "
        "when the DataLoader requests the first packed batch",
        flush=True,
    )
    return mixed
class PackedTokenDataset(IterableDataset):
    """Pack a lazy ``{"text": str}`` stream into fixed-length token blocks.

    For each row that has non-empty ``"text"``:
      - tokenize it with ``tokenizer.encode(text, add_special_tokens=False)``;
      - truncate to ``max_tokens_per_sample`` tokens when that is > 0;
      - append the EOS id and add the result to a running buffer.
    Every full ``seq_len`` slice of the buffer is yielded as
    ``input_ids`` / ``attention_mask`` (all ones) / ``labels`` (a copy of
    ``input_ids``). Tokens left over when the stream ends are dropped.

    Under a DataLoader with workers, each worker iterates the full text stream
    and keeps rows whose index modulo ``num_workers`` equals its worker id.
    ``seed`` is stored but not used during iteration.
    """

    def __init__(
        self,
        text_stream: Any,
        tokenizer: Any,
        seq_len: int,
        max_tokens_per_sample: int,
        seed: int,
        startup_log_every_examples: int = 128,
        startup_heartbeat_sec: float = 30.0,
    ):
        super().__init__()
        self.text_stream = text_stream
        self.tokenizer = tokenizer
        self.seq_len = seq_len
        self.max_tokens_per_sample = max_tokens_per_sample
        self.seed = seed
        self.startup_log_every_examples = int(startup_log_every_examples)
        self.startup_heartbeat_sec = float(startup_heartbeat_sec)
        eos = getattr(tokenizer, "eos_token_id", None)
        # Fall back to id 1 when the tokenizer does not expose an integer EOS id.
        self.eos_token_id = eos if isinstance(eos, int) else 1

    def __iter__(self) -> Iterator[dict[str, torch.Tensor]]:
        info = get_worker_info()
        if info is None:
            worker_id, num_workers = 0, 1
        else:
            worker_id, num_workers = info.id, info.num_workers
        pending: list[int] = []
        rows_seen = rows_assigned = rows_with_text = rows_tokenized = 0
        tokens_added = blocks_yielded = 0
        began = time.monotonic()
        last_report = began

        def report(reason: str, *, force: bool = False) -> None:
            """Print packer counters; unforced calls are throttled to the heartbeat interval."""
            nonlocal last_report
            now = time.monotonic()
            throttled = self.startup_heartbeat_sec <= 0 or now - last_report < self.startup_heartbeat_sec
            if throttled and not force:
                return
            last_report = now
            print(
                "[data] packer "
                f"worker={worker_id}/{num_workers} {reason}: "
                f"source_rows={rows_seen} assigned_rows={rows_assigned} "
                f"text_rows={rows_with_text} tokenized_rows={rows_tokenized} "
                f"tokens_added={tokens_added} buffered_tokens={len(pending)} "
                f"yielded_blocks={blocks_yielded} elapsed={now - began:.1f}s",
                flush=True,
            )

        report("started; waiting for first row from lazy mixed stream", force=True)
        for rows_seen, row in enumerate(self.text_stream, start=1):
            # Round-robin row sharding across DataLoader workers.
            if (rows_seen - 1) % num_workers != worker_id:
                continue
            rows_assigned += 1
            every = self.startup_log_every_examples
            if rows_assigned == 1:
                report("received first assigned row", force=True)
            elif every > 0 and rows_assigned % every == 0:
                report(f"processed {rows_assigned} assigned rows", force=True)
            else:
                report("still reading/tokenizing")

            text = row.get("text") if isinstance(row, dict) else None
            if not text:
                continue
            rows_with_text += 1
            ids = self.tokenizer.encode(text, add_special_tokens=False)
            limit = self.max_tokens_per_sample
            if limit > 0 and len(ids) > limit:
                ids = ids[:limit]
            ids.append(self.eos_token_id)
            rows_tokenized += 1
            tokens_added += len(ids)
            if rows_tokenized == 1:
                print(
                    "[data] packer first tokenized row: "
                    f"chars={len(text)} tokens_with_eos={len(ids)} seq_len={self.seq_len}",
                    flush=True,
                )

            pending.extend(ids)
            while len(pending) >= self.seq_len:
                block = pending[: self.seq_len]
                del pending[: self.seq_len]
                input_ids = torch.tensor(block, dtype=torch.long)
                blocks_yielded += 1
                if blocks_yielded <= 4:
                    report(f"yielding packed block {blocks_yielded}", force=True)
                yield {
                    "input_ids": input_ids,
                    "attention_mask": torch.ones_like(input_ids, dtype=torch.long),
                    "labels": input_ids.clone(),
                }
        report("underlying stream exhausted", force=True)
class WaitHeartbeat:
    """Print periodic progress while the main thread is blocked in a known slow call.

    Use it as a context manager. When ``interval_sec > 0``, entering starts a
    daemon thread that prints ``"<message> elapsed=<seconds>s"`` every
    ``interval_sec`` seconds. Exiting signals that thread to stop and waits up
    to 1 s for it. Exceptions raised inside the ``with`` body are never
    suppressed.
    """

    def __init__(self, message: str, interval_sec: float):
        self.message = message
        self.interval_sec = float(interval_sec)
        self._stop: Optional[threading.Event] = None
        self._thread: Optional[threading.Thread] = None
        self._start = 0.0

    def _beat(self) -> None:
        """Thread body: print elapsed time until the stop event is set."""
        assert self._stop is not None
        while not self._stop.wait(self.interval_sec):
            print(f"{self.message} elapsed={time.monotonic() - self._start:.1f}s", flush=True)

    def __enter__(self):
        if self.interval_sec > 0:
            self._stop = threading.Event()
            self._start = time.monotonic()
            self._thread = threading.Thread(target=self._beat, daemon=True)
            self._thread.start()
        return self

    def __exit__(self, exc_type, exc, tb):
        stop_event, worker = self._stop, self._thread
        if stop_event is not None:
            stop_event.set()
        if worker is not None:
            worker.join(timeout=1.0)
        return False
# -----------------------------
# Gemma 4 MoE mutation/stabilization
# -----------------------------
# Parameter-name substrings that identify the additive MoE path: router,
# experts, and the extra feed-forward RMSNorms.
MOE_PARAM_MARKERS = (
    "router.",
    "experts.",
    "post_feedforward_layernorm_1.",
    "post_feedforward_layernorm_2.",
    "pre_feedforward_layernorm_2.",
)


def is_new_moe_key(name: str) -> bool:
    """Return True if *name* contains any MoE marker substring (anywhere in the key)."""
    for marker in MOE_PARAM_MARKERS:
        if marker in name:
            return True
    return False
| def _tensor_from_parameter_or_module(obj: Any, label: str) -> torch.Tensor: | |
| if isinstance(obj, nn.Parameter) or torch.is_tensor(obj): | |
| return obj | |
| if isinstance(obj, nn.Module) and hasattr(obj, "weight"): | |
| weight = getattr(obj, "weight") | |
| if isinstance(weight, nn.Parameter) or torch.is_tensor(weight): | |
| return weight | |
| raise TypeError(f"Expected {label} to be a tensor/Parameter or module with .weight; got {type(obj)!r}") | |
| def _router_per_expert_scale(layer: nn.Module) -> Optional[torch.Tensor]: | |
| router = getattr(layer, "router", None) | |
| if router is not None and hasattr(router, "per_expert_scale"): | |
| return router.per_expert_scale | |
| experts = getattr(layer, "experts", None) | |
| if experts is not None and hasattr(experts, "per_expert_scale"): | |
| return experts.per_expert_scale | |
| return None | |
| def _router_scale(layer: nn.Module) -> Optional[torch.Tensor]: | |
| router = getattr(layer, "router", None) | |
| if router is not None and hasattr(router, "scale"): | |
| return router.scale | |
| return None | |
def validate_loading_info(loading_info: Optional[dict[str, Any]]) -> None:
    """Fail closed unless the only newly initialized weights are the added MoE path.

    A falsy *loading_info* (None or empty) is accepted without checks.
    Otherwise, the load is rejected when any of these is non-empty:
    ``error_msgs``, ``unexpected_keys``, ``mismatched_keys``, or a
    ``missing_keys`` entry that ``is_new_moe_key`` does not recognise.

    Raises:
        RuntimeError: with a JSON summary (each list capped) of the offending keys.
    """
    if not loading_info:
        return

    def _as_list(key: str) -> list:
        # Absent keys and None values both become empty lists.
        return list(loading_info.get(key, []) or [])

    missing = _as_list("missing_keys")
    unexpected = _as_list("unexpected_keys")
    mismatched = _as_list("mismatched_keys")
    error_msgs = _as_list("error_msgs")
    bad_missing = [key for key in missing if not is_new_moe_key(key)]
    if not (error_msgs or unexpected or mismatched or bad_missing):
        return
    summary = {
        "bad_missing_keys": bad_missing[:50],
        "unexpected_keys": unexpected[:50],
        "mismatched_keys": [str(entry) for entry in mismatched[:50]],
        "error_msgs": error_msgs[:10],
        "num_missing_keys": len(missing),
    }
    raise RuntimeError(
        "Checkpoint load did not match the intended dense-backbone + fresh-MoE mutation. "
        "The script only permits missing newly added MoE keys. Details: "
        + json.dumps(summary, indent=2)
    )
def get_text_config(config: Any) -> Any:
    """Return ``config.text_config`` when that attribute exists, otherwise *config* itself."""
    try:
        return config.text_config
    except AttributeError:
        return config
class DenseBranchIdentityRMSNormBridge(nn.Module):
    """Identity-preserving replacement for Gemma4's inserted dense-side MoE RMSNorm.

    Native Gemma4 additive MoE runs post_feedforward_layernorm_1 on the dense MLP
    output before adding the expert branch. Gemma4RMSNorm with weight=1 is not an
    identity map, so merely making the expert branch tiny does not preserve the
    frozen dense model. This bridge keeps the same .weight parameter name/shape:

        output = x + alpha * (rms_norm(x) - x)

    Here ``alpha`` is ``weight``, broadcast over the last dimension. alpha=0
    exactly preserves the dense path; alpha remains trainable.

    Args:
        dim: size of the last (hidden) dimension and length of ``weight``.
        eps: added to the mean square before the inverse square root.
        init: initial value for every entry of ``weight``.
    """

    def __init__(self, dim: int, eps: float = 1e-6, init: float = 0.0):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.full((dim,), float(init)))

    def _norm_float(self, x: torch.Tensor) -> torch.Tensor:
        """RMS-normalize a float32 tensor over its last dimension (no learned scale)."""
        mean_squared = x.pow(2).mean(-1, keepdim=True) + self.eps
        return x * torch.pow(mean_squared, -0.5)

    def _norm(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """RMS-normalize in float32 and cast back to the input dtype."""
        return self._norm_float(hidden_states.float()).type_as(hidden_states)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        x = hidden_states.float()
        # Blend entirely in float32. Rounding the normalized branch back to the
        # input dtype (e.g. bf16) before blending only added quantization error.
        normed = self._norm_float(x)
        alpha = self.weight.float().view(*([1] * (hidden_states.ndim - 1)), -1)
        return (x + alpha * (normed - x)).type_as(hidden_states)
def _is_identity_bridge(module: Any) -> bool:
    """Return True if *module* is a DenseBranchIdentityRMSNormBridge.

    The class-name comparison also accepts instances of a different class
    object with the same name. NOTE(review): presumably this covers a re-defined
    or reloaded copy of the class; confirm.
    """
    if isinstance(module, DenseBranchIdentityRMSNormBridge):
        return True
    return module.__class__.__name__ == "DenseBranchIdentityRMSNormBridge"
@torch.no_grad()
def patch_dense_branch_norms_for_upcycling(model: nn.Module, args: argparse.Namespace) -> int:
    """Swap each MoE decoder layer's post_feedforward_layernorm_1 for an identity bridge.

    This only acts when ``args.dense_branch_norm_mode == "identity_residual_rmsnorm"``;
    otherwise it returns 0 and leaves the model untouched. Layers whose norm is
    missing, or is already a bridge, are skipped.

    Each bridge is configured as follows:
      - size: the old norm's ``weight.numel()``;
      - eps: the old norm's ``eps``, falling back to the text config's
        ``rms_norm_eps``, then 1e-6;
      - initial alpha: ``args.dense_branch_norm_init``;
      - device/dtype: those of the old weight.

    Returns:
        The number of layers patched.

    Raises:
        RuntimeError: a norm to be replaced has no ``weight`` attribute.
    """
    if args.dense_branch_norm_mode != "identity_residual_rmsnorm":
        return 0
    patched = 0
    for layer in iter_moe_decoder_layers(model):
        current = getattr(layer, "post_feedforward_layernorm_1", None)
        if current is None or _is_identity_bridge(current):
            continue
        if not hasattr(current, "weight"):
            raise RuntimeError("post_feedforward_layernorm_1 has no .weight; cannot preserve dense path safely.")
        old_weight = current.weight
        hidden = int(old_weight.numel())
        config_eps = getattr(get_text_config(model.config), "rms_norm_eps", 1e-6)
        replacement = DenseBranchIdentityRMSNormBridge(
            hidden,
            eps=float(getattr(current, "eps", config_eps)),
            init=args.dense_branch_norm_init,
        )
        replacement.to(device=old_weight.device, dtype=old_weight.dtype)
        setattr(layer, "post_feedforward_layernorm_1", replacement)
        patched += 1
    return patched
def mutate_config_for_additive_moe(args: argparse.Namespace) -> Any:
    """Load the base model's config and switch on Gemma 4's built-in additive MoE block.

    The expert hyper-parameters are validated first, raising ``ValueError``
    before any config is loaded. Then ``args.base_model``'s config is fetched and
    its text sub-config gets the MoE fields. It also gets the bookkeeping fields
    ``dense_branch_norm_mode``, ``dense_branch_norm_init`` and
    ``moe_branch_norm_init``, so they travel with the config. Vision/audio
    sub-configs are not modified.

    Returns:
        The full (possibly multimodal) config, mutated in place.
    """
    num_experts = args.num_experts
    top_k = args.top_k_experts
    if num_experts <= 0:
        raise ValueError("--num_experts must be positive")
    if not 0 < top_k <= num_experts:
        raise ValueError("--top_k_experts must be in [1, num_experts]")
    if args.moe_intermediate_size <= 0:
        raise ValueError("--moe_intermediate_size must be positive")

    config = AutoConfig.from_pretrained(args.base_model, token=args.hf_token or None)
    text_config = get_text_config(config)

    # Dense Gemma 4 checkpoints ship with enable_moe_block=False and no expert
    # sizes. Setting these Gemma4TextConfig fields is what makes Transformers
    # build the router, experts and MoE-specific RMSNorms in each decoder layer.
    moe_fields = {
        "enable_moe_block": True,
        "num_experts": num_experts,
        "top_k_experts": top_k,
        "moe_intermediate_size": args.moe_intermediate_size,
        "use_double_wide_mlp": False,
        "dense_branch_norm_mode": args.dense_branch_norm_mode,
        "dense_branch_norm_init": args.dense_branch_norm_init,
        "moe_branch_norm_init": args.moe_branch_norm_init,
    }
    for field_name, value in moe_fields.items():
        setattr(text_config, field_name, value)

    if hasattr(text_config, "expert_intermediate_size"):
        # Null the legacy alias so it can never disagree with moe_intermediate_size.
        text_config.expert_intermediate_size = None
    return config
| def _call_from_pretrained(model_cls: Any, base_model: str, load_kwargs: dict[str, Any]): | |
| try: | |
| return model_cls.from_pretrained(base_model, dtype=torch.bfloat16, **load_kwargs) | |
| except TypeError: | |
| # Older Transformers versions use torch_dtype. | |
| return model_cls.from_pretrained(base_model, torch_dtype=torch.bfloat16, **load_kwargs) | |
def load_upcycled_model(args: argparse.Namespace, config: Any) -> nn.Module:
    """Instantiate the MoE-enabled model from the dense checkpoint.

    The new routers, experts and MoE norms are absent from the dense checkpoint
    and are expected to be reported as missing keys. The loading report is
    handed to ``validate_loading_info``.

    If ``flash_attention_2`` was requested and loading raises, the exception is
    printed and the load is retried once with ``sdpa``. Any other failure
    propagates unchanged.
    """
    load_kwargs: dict[str, Any] = {
        "config": config,
        "low_cpu_mem_usage": True,
        "output_loading_info": True,
    }
    if args.hf_token:
        load_kwargs["token"] = args.hf_token
    if args.attn_implementation:
        load_kwargs["attn_implementation"] = args.attn_implementation
    if args.ignore_mismatched_sizes:
        load_kwargs["ignore_mismatched_sizes"] = True

    model_cls = resolve_model_class()
    try:
        loaded = _call_from_pretrained(model_cls, args.base_model, load_kwargs)
    except Exception as exc:
        if args.attn_implementation != "flash_attention_2":
            raise
        print(f"[load] FlashAttention load failed; retrying with sdpa. Error: {exc}")
        load_kwargs["attn_implementation"] = "sdpa"
        loaded = _call_from_pretrained(model_cls, args.base_model, load_kwargs)

    # output_loading_info=True normally yields (model, info); tolerate a bare model.
    model, loading_info = loaded if isinstance(loaded, tuple) else (loaded, None)
    validate_loading_info(loading_info)
    return model
def iter_moe_decoder_layers(model: nn.Module) -> Iterator[nn.Module]:
    """Lazily yield every submodule that looks like a Gemma 4 additive-MoE decoder layer.

    A module qualifies when it has ``router``, ``experts`` and
    ``post_feedforward_layernorm_2`` attributes. Candidates are visited in
    ``model.modules()`` order, root included.
    """
    required = ("router", "experts", "post_feedforward_layernorm_2")
    yield from (
        module for module in model.modules() if all(hasattr(module, attr) for attr in required)
    )
def assert_gemma4_additive_moe_contract(model: nn.Module, args: argparse.Namespace) -> None:
    """Check that the loaded model implements the additive MoE layer we intend to train.

    Checks, in order:
      1. The text config has ``enable_moe_block`` set.
      2. The config matches ``args`` on ``num_experts``, ``top_k_experts`` and
         ``moe_intermediate_size``.
      3. The number of MoE decoder layers equals ``num_hidden_layers`` (when the
         config defines it) and is non-zero.
      4. Every layer has the router, experts and three MoE norms.
      5. Every layer has a router projection with ``num_experts`` outputs, a
         router scale, and a per-expert scale of size ``num_experts``.
      6. Expert tensor shapes contain ``num_experts`` and the intermediate size
         (gate/up may carry ``2 * moe_intermediate_size``).

    Shape checks here only test membership of those sizes in the shape.

    Raises:
        RuntimeError: describing the first violated expectation.
    """
    cfg = get_text_config(getattr(model, "config", None))
    if cfg is None or not getattr(cfg, "enable_moe_block", False):
        raise RuntimeError("Model config does not have text_config.enable_moe_block=True after mutation.")

    for field in ("num_experts", "top_k_experts", "moe_intermediate_size"):
        live_value = getattr(cfg, field, None)
        wanted = getattr(args, field)
        if live_value != wanted:
            raise RuntimeError(f"{field} mismatch: config={live_value} args={wanted}")
    if args.top_k_experts > args.num_experts:
        raise RuntimeError("top_k_experts must be <= num_experts.")

    layers = list(iter_moe_decoder_layers(model))
    expected_layers = getattr(cfg, "num_hidden_layers", None)
    if expected_layers is not None and len(layers) != expected_layers:
        raise RuntimeError(f"Expected {expected_layers} MoE decoder layers; found {len(layers)}.")
    if not layers:
        raise RuntimeError("No Gemma 4 MoE decoder layers were found.")

    required_attrs = (
        "router",
        "experts",
        "post_feedforward_layernorm_1",
        "pre_feedforward_layernorm_2",
        "post_feedforward_layernorm_2",
    )
    n_experts = args.num_experts
    inter = args.moe_intermediate_size
    for i, layer in enumerate(layers):
        missing = next((attr for attr in required_attrs if not hasattr(layer, attr)), None)
        if missing is not None:
            raise RuntimeError(f"MoE layer {i} is missing required attribute {missing!r}.")

        router = layer.router
        if not hasattr(router, "proj") or not hasattr(router.proj, "weight"):
            raise RuntimeError(f"MoE layer {i} router is missing proj.weight.")
        out_dim = router.proj.weight.shape[0]
        if out_dim != n_experts:
            raise RuntimeError(f"MoE layer {i} router proj out dim {out_dim} != num_experts {n_experts}.")
        if _router_scale(layer) is None:
            raise RuntimeError(f"MoE layer {i} router is missing trainable scale.")
        per_expert = _router_per_expert_scale(layer)
        if per_expert is None or per_expert.numel() != n_experts:
            raise RuntimeError(f"MoE layer {i} per_expert_scale is missing or has wrong size.")

        gate_up_dims = tuple(_tensor_from_parameter_or_module(layer.experts.gate_up_proj, f"layer {i} experts.gate_up_proj").shape)
        down_dims = tuple(_tensor_from_parameter_or_module(layer.experts.down_proj, f"layer {i} experts.down_proj").shape)
        if n_experts not in gate_up_dims or n_experts not in down_dims:
            raise RuntimeError(f"MoE layer {i} expert tensors do not include num_experts dimension {n_experts}.")
        if inter not in gate_up_dims and 2 * inter not in gate_up_dims:
            raise RuntimeError(f"MoE layer {i} gate_up tensor shape {gate_up_dims} does not match moe_intermediate_size.")
        if inter not in down_dims:
            raise RuntimeError(f"MoE layer {i} down tensor shape {down_dims} does not match moe_intermediate_size.")
@torch.no_grad()
def stabilize_new_moe_parameters(model: nn.Module, args: argparse.Namespace) -> int:
    """Initialise the new additive-MoE path so it cannot dominate at step zero.

    The expert output is RMS-normalised by ``post_feedforward_layernorm_2``
    before it is added to the dense path. That normalisation removes most of the
    amplitude of small random expert weights, so the effective gate is the scale
    of ``post_feedforward_layernorm_2.weight``. It is set to
    ``args.moe_branch_norm_init``, intended to be small but non-zero so gradients
    still reach the experts.

    Per MoE layer:
      * Expert gate/up and down tensors are drawn from N(0, ``expert_init_std``);
        the router projection from N(0, ``router_init_std``). Draw order is
        gate_up, down, router.
      * The router scale (if present) is set to 1.0.
      * The per-expert scale (if present) is set to ``router_per_expert_scale_init``.
      * ``post_feedforward_layernorm_1`` is set to ``dense_branch_norm_init`` when
        it is the identity bridge, otherwise to 1.0.
      * ``pre_feedforward_layernorm_2`` is set to 1.0.

    Returns:
        Number of MoE layers initialised.
    """
    layers_done = 0
    for layer in iter_moe_decoder_layers(model):
        layers_done += 1
        gate_up = _tensor_from_parameter_or_module(layer.experts.gate_up_proj, "experts.gate_up_proj")
        nn.init.normal_(gate_up, mean=0.0, std=args.expert_init_std)
        down = _tensor_from_parameter_or_module(layer.experts.down_proj, "experts.down_proj")
        nn.init.normal_(down, mean=0.0, std=args.expert_init_std)
        nn.init.normal_(layer.router.proj.weight, mean=0.0, std=args.router_init_std)

        router_scale = _router_scale(layer)
        if router_scale is not None:
            router_scale.fill_(1.0)
        per_expert = _router_per_expert_scale(layer)
        if per_expert is not None:
            per_expert.fill_(args.router_per_expert_scale_init)

        if hasattr(layer, "post_feedforward_layernorm_1"):
            dense_norm = layer.post_feedforward_layernorm_1
            # Bridge: alpha=dense_branch_norm_init; native RMSNorm: unit weight.
            dense_norm.weight.fill_(args.dense_branch_norm_init if _is_identity_bridge(dense_norm) else 1.0)
        if hasattr(layer, "pre_feedforward_layernorm_2"):
            layer.pre_feedforward_layernorm_2.weight.fill_(1.0)
        if hasattr(layer, "post_feedforward_layernorm_2"):
            layer.post_feedforward_layernorm_2.weight.fill_(args.moe_branch_norm_init)
    return layers_done
def is_trainable_moe_parameter(name: str) -> bool:
    """Decide whether the parameter called *name* belongs to the trainable MoE set.

    Delegates entirely to ``is_new_moe_key`` (defined elsewhere in this file), so
    trainability is whatever that helper classifies as a new-MoE key.
    """
    verdict = is_new_moe_key(name)
    return verdict
def assert_only_moe_parameters_trainable(model: nn.Module) -> None:
    """Raise unless exactly the MoE parameters have ``requires_grad=True``.

    Collects trainable non-MoE parameters and frozen MoE parameters in one pass.
    If either list is non-empty, raises ``RuntimeError`` with a JSON report of up
    to 50 names from each list plus both full counts.
    """
    non_moe_trainable: list[str] = []
    moe_frozen: list[str] = []
    for name, param in model.named_parameters():
        is_moe = is_trainable_moe_parameter(name)
        if param.requires_grad and not is_moe:
            non_moe_trainable.append(name)
        elif not param.requires_grad and is_moe:
            moe_frozen.append(name)
    if not non_moe_trainable and not moe_frozen:
        return
    report = {
        "non_moe_trainable": non_moe_trainable[:50],
        "moe_frozen": moe_frozen[:50],
        "non_moe_trainable_count": len(non_moe_trainable),
        "moe_frozen_count": len(moe_frozen),
    }
    raise RuntimeError("Trainable-parameter mask is wrong: " + json.dumps(report, indent=2))
def freeze_dense_unfreeze_moe(model: nn.Module) -> tuple[int, int]:
    """Set ``requires_grad`` on every parameter: True for MoE ones, False otherwise.

    Returns:
        ``(total_elements, trainable_elements)`` summed over all parameters.
    """
    total_elems = trainable_elems = 0
    for name, param in model.named_parameters():
        elems = param.numel()
        total_elems += elems
        wants_grad = is_trainable_moe_parameter(name)
        param.requires_grad_(wants_grad)
        if wants_grad:
            trainable_elems += elems
    return total_elems, trainable_elems
@torch.no_grad()
def cast_trainable_parameters(model: nn.Module, args: argparse.Namespace) -> dict[str, int]:
    """Cast only trainable parameters to the requested dtype; frozen ones keep theirs.

    ``args.trainable_param_dtype == "fp32"`` selects float32; any other value
    selects bfloat16. Parameter storage is replaced in place via ``.data``.

    Returns:
        Element counts of the trainable parameters by final dtype, under the keys
        "fp32", "bf16" and "other".
    """
    if args.trainable_param_dtype == "fp32":
        target_dtype = torch.float32
    else:
        target_dtype = torch.bfloat16
    bucket_of = {torch.float32: "fp32", torch.bfloat16: "bf16"}
    counts: dict[str, int] = {"fp32": 0, "bf16": 0, "other": 0}
    trainable = (param for _name, param in model.named_parameters() if param.requires_grad)
    for param in trainable:
        if param.dtype != target_dtype:
            param.data = param.data.to(dtype=target_dtype)
        counts[bucket_of.get(param.dtype, "other")] += param.numel()
    return counts
def expected_moe_trainable_parameters(text_config: Any, args: argparse.Namespace) -> int:
    """Return the analytic number of trainable elements added by the MoE path.

    Per decoder layer, with H = hidden_size, E = num_experts and
    I = moe_intermediate_size:

        experts.gate_up_proj   E * 2I * H
        experts.down_proj      E * H * I
        router.proj            H * E
        router scale           H
        router per-expert      E
        three MoE RMSNorms     3 * H

    The per-layer total is multiplied by ``num_hidden_layers``.
    """
    hidden = int(text_config.hidden_size)
    num_layers = int(text_config.num_hidden_layers)
    experts = int(args.num_experts)
    inter = int(args.moe_intermediate_size)
    expert_elems = experts * (2 * inter) * hidden + experts * hidden * inter
    router_elems = hidden * experts + hidden + experts
    norm_elems = 3 * hidden
    return num_layers * (expert_elems + router_elems + norm_elems)
def verify_moe_parameter_layout(model: nn.Module, args: argparse.Namespace) -> dict[str, Any]:
    """Fail fast unless the live model has exactly the expected additive-MoE layout.

    Each stage raises ``RuntimeError`` on failure:

      1. ``enable_moe_block`` must be literally ``True``, and at least one MoE
         layer must exist.
      2. Every MoE layer must have these exact shapes (E = num_experts,
         I = moe_intermediate_size, H = hidden_size):
           * gate_up_proj (E, 2I, H)
           * down_proj (E, H, I)
           * router.proj.weight (E, H)
           * (H,) weights on the three MoE norms
         Up to 50 problems are reported.
      3. Only MoE parameters may be trainable, judged by
         ``is_trainable_moe_parameter`` plus a denylist of backbone name fragments.
         The element counts of all "experts." parameters, and of all trainable
         parameters, must equal the analytic expectations.

    Returns:
        A summary dict of the verified counts.
    """
    text_config = get_text_config(getattr(model, "config", None))
    if getattr(text_config, "enable_moe_block", False) is not True:
        raise RuntimeError("Gemma 4 text_config.enable_moe_block is not True after mutation.")
    hidden = int(getattr(text_config, "hidden_size"))
    moe_layers = list(iter_moe_decoder_layers(model))
    if not moe_layers:
        raise RuntimeError("No Gemma 4 MoE decoder layers were instantiated.")

    # Stage 2: exact per-layer tensor shapes.
    bad_shapes: list[str] = []
    for idx, layer in enumerate(moe_layers):
        gate_up = _tensor_from_parameter_or_module(layer.experts.gate_up_proj, f"layer {idx} experts.gate_up_proj")
        down = _tensor_from_parameter_or_module(layer.experts.down_proj, f"layer {idx} experts.down_proj")
        shape_checks = (
            ("gate_up_proj", tuple(gate_up.shape), (args.num_experts, 2 * args.moe_intermediate_size, hidden)),
            ("down_proj", tuple(down.shape), (args.num_experts, hidden, args.moe_intermediate_size)),
            ("router.proj.weight", tuple(layer.router.proj.weight.shape), (args.num_experts, hidden)),
        )
        for label, got, want in shape_checks:
            if got != want:
                bad_shapes.append(f"layer {idx}: {label} {got} != {want}")
        for attr in ("post_feedforward_layernorm_1", "post_feedforward_layernorm_2", "pre_feedforward_layernorm_2"):
            if not hasattr(layer, attr):
                bad_shapes.append(f"layer {idx}: missing {attr}")
                continue
            got = tuple(getattr(layer, attr).weight.shape)
            if got != (hidden,):
                bad_shapes.append(f"layer {idx}: {attr}.weight shape {got} != {(hidden,)}")
    if bad_shapes:
        raise RuntimeError("Gemma 4 MoE layout mismatch:\n" + "\n".join(bad_shapes[:50]))

    # Stage 3a: no backbone parameter may be trainable.
    forbidden_markers = (
        "embed_tokens",
        "lm_head",
        "vision",
        "audio",
        ".self_attn.",
        ".mlp.",
        "input_layernorm",
        "post_attention_layernorm",
        "pre_feedforward_layernorm.",
        "post_feedforward_layernorm.",
        "per_layer_input",
    )
    trainable_names = [name for name, param in model.named_parameters() if param.requires_grad]
    bad_trainable = [
        name
        for name in trainable_names
        if not is_trainable_moe_parameter(name) or any(marker in name for marker in forbidden_markers)
    ]
    if bad_trainable:
        raise RuntimeError(
            "Dense/backbone parameters are unexpectedly trainable. First offending names:\n"
            + "\n".join(bad_trainable[:50])
        )

    # Stage 3b: element counts must match the analytic expectation.
    actual_expert_params = 0
    actual_trainable_params = 0
    for name, param in model.named_parameters():
        if "experts." in name:
            actual_expert_params += param.numel()
        if param.requires_grad:
            actual_trainable_params += param.numel()
    per_expert_elems = (2 * args.moe_intermediate_size * hidden) + (hidden * args.moe_intermediate_size)
    expected_expert_params = len(moe_layers) * args.num_experts * per_expert_elems
    expected_trainable_params = expected_moe_trainable_parameters(text_config, args)
    if actual_expert_params != expected_expert_params:
        raise RuntimeError(
            f"Expert parameter count mismatch: actual {actual_expert_params} != expected {expected_expert_params}"
        )
    if actual_trainable_params != expected_trainable_params:
        raise RuntimeError(
            f"Trainable MoE parameter count mismatch: actual {actual_trainable_params} != expected {expected_trainable_params}"
        )

    return {
        "verified_moe_layers": len(moe_layers),
        "hidden_size": hidden,
        "expected_expert_params": expected_expert_params,
        "actual_expert_params": actual_expert_params,
        "expected_trainable_params": expected_trainable_params,
        "actual_trainable_params": actual_trainable_params,
        "trainable_parameter_tensors": len(trainable_names),
    }
def build_optimizer(model: nn.Module, args: argparse.Namespace) -> torch.optim.Optimizer:
    """Create AdamW over the trainable MoE parameters, with one param group per role.

    Trainable parameters are bucketed by name substring, first match wins:

      * ``experts.`` -> "experts"
        (lr=expert_lr, weight_decay=weight_decay)
      * ``router.`` -> "routers"
        (lr=router_lr, no decay)
      * ``post_feedforward_layernorm_2.`` -> "moe_output_norm"
        (the additive branch's output gate; it starts near zero and gets its own,
        typically higher, lr=moe_output_norm_lr; no decay)
      * ``post_feedforward_layernorm_1.`` or ``pre_feedforward_layernorm_2.``
        -> "identity_norms" (lr=identity_norm_lr, no decay)

    A non-None ``args.norm_lr`` overrides both norm learning rates. Empty
    buckets produce no group. ``fused=True`` is tried first; if constructing the
    fused optimizer raises ``TypeError``, the default implementation is used.

    Raises:
        RuntimeError: for a trainable parameter matching no bucket, or when
            there are no trainable parameters at all.
    """
    buckets: dict[str, list[torch.Tensor]] = {
        "experts": [],
        "routers": [],
        "moe_output_norm": [],
        "identity_norms": [],
    }
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if "experts." in name:
            buckets["experts"].append(param)
        elif "router." in name:
            buckets["routers"].append(param)
        elif "post_feedforward_layernorm_2." in name:
            buckets["moe_output_norm"].append(param)
        elif "post_feedforward_layernorm_1." in name or "pre_feedforward_layernorm_2." in name:
            buckets["identity_norms"].append(param)
        else:
            raise RuntimeError(f"Unexpected trainable parameter: {name}")

    override = args.norm_lr
    output_gate_lr = args.moe_output_norm_lr if override is None else override
    identity_lr = args.identity_norm_lr if override is None else override

    groups = []
    if buckets["experts"]:
        groups.append({"name": "experts", "params": buckets["experts"], "lr": args.expert_lr, "weight_decay": args.weight_decay})
    if buckets["routers"]:
        groups.append({"name": "routers", "params": buckets["routers"], "lr": args.router_lr, "weight_decay": 0.0})
    if buckets["moe_output_norm"]:
        groups.append({"name": "moe_output_norm", "params": buckets["moe_output_norm"], "lr": output_gate_lr, "weight_decay": 0.0})
    if buckets["identity_norms"]:
        groups.append({"name": "identity_norms", "params": buckets["identity_norms"], "lr": identity_lr, "weight_decay": 0.0})
    if not groups:
        raise RuntimeError("No trainable MoE parameters found. Check Gemma 4 config mutation and parameter names.")

    adam_kwargs = {"betas": (args.beta1, args.beta2), "eps": args.adam_eps}
    try:
        return torch.optim.AdamW(groups, fused=True, **adam_kwargs)
    except TypeError:
        return torch.optim.AdamW(groups, **adam_kwargs)
@torch.no_grad()
def clamp_moe_controls(model: nn.Module, args: argparse.Namespace) -> None:
    """Project each MoE layer's control parameters back into their allowed ranges, in place.

    A clip value <= 0 disables that clamp.

      * ``post_feedforward_layernorm_2`` (expert-output gate): clamped to
        ``[0, moe_branch_norm_clip]``. This keeps it a bounded positive residual
        gate so generation cannot collapse into scratch-expert behaviour.
      * ``post_feedforward_layernorm_1``, with d = ``dense_branch_norm_clip``:
        ``[-d, d]`` for the identity bridge (alpha = 0 is the exact dense path),
        otherwise ``[1 - d, 1 + d]`` around the native unit weight.
      * ``pre_feedforward_layernorm_2``: ``[1 - d, 1 + d]`` with
        d = ``pre_moe_norm_clip``.
      * Router per-expert scale: ``[0, router_per_expert_scale_clip]``.
      * Router input scale: ``[max(0, 1 - d), 1 + d]`` with
        d = ``router_input_scale_clip``.
    """
    for layer in iter_moe_decoder_layers(model):
        if args.moe_branch_norm_clip > 0 and hasattr(layer, "post_feedforward_layernorm_2"):
            layer.post_feedforward_layernorm_2.weight.clamp_(min=0.0, max=args.moe_branch_norm_clip)

        dense_delta = args.dense_branch_norm_clip
        if dense_delta > 0 and hasattr(layer, "post_feedforward_layernorm_1"):
            dense_weight = layer.post_feedforward_layernorm_1.weight
            if _is_identity_bridge(layer.post_feedforward_layernorm_1):
                dense_weight.clamp_(min=-dense_delta, max=dense_delta)
            else:
                dense_weight.clamp_(min=1.0 - dense_delta, max=1.0 + dense_delta)

        pre_delta = args.pre_moe_norm_clip
        if pre_delta > 0 and hasattr(layer, "pre_feedforward_layernorm_2"):
            layer.pre_feedforward_layernorm_2.weight.clamp_(min=1.0 - pre_delta, max=1.0 + pre_delta)

        if args.router_per_expert_scale_clip > 0:
            per_expert = _router_per_expert_scale(layer)
            if per_expert is not None:
                per_expert.clamp_(min=0.0, max=args.router_per_expert_scale_clip)

        router_delta = args.router_input_scale_clip
        if router_delta > 0:
            router_scale = _router_scale(layer)
            if router_scale is not None:
                router_scale.clamp_(min=max(0.0, 1.0 - router_delta), max=1.0 + router_delta)
def moe_branch_scale_summary(model: nn.Module) -> dict[str, float]:
    """Summarise the MoE control parameters across all MoE layers, for logging.

    Five families are covered: expert-output norm, dense-side norm, expert-input
    norm, router input scale and router per-expert scale. For each,
    ``<prefix>_mean``, ``<prefix>_abs_mean`` and ``<prefix>_abs_max`` are
    reported. A family with no tensors reports 0.0 for all three keys, so the
    key set is always the same.

    When dense-side norms exist, their deviation from the identity target is
    also reported. The target is 0.0 if any layer uses the identity bridge,
    else 1.0 (native RMSNorm).

    Router scales are read through ``_router_scale`` / ``_router_per_expert_scale``,
    the same accessors the clamping and regularizer code use. A router lacking
    those tensors is skipped instead of raising ``AttributeError``.
    """
    moe_output_vals: list[torch.Tensor] = []
    dense_identity_vals: list[torch.Tensor] = []
    expert_input_vals: list[torch.Tensor] = []
    router_scale_vals: list[torch.Tensor] = []
    per_expert_scale_vals: list[torch.Tensor] = []
    bridge_mode = False
    for layer in iter_moe_decoder_layers(model):
        if hasattr(layer, "post_feedforward_layernorm_2"):
            moe_output_vals.append(layer.post_feedforward_layernorm_2.weight.detach().float())
        if hasattr(layer, "post_feedforward_layernorm_1"):
            dense_norm = layer.post_feedforward_layernorm_1
            dense_identity_vals.append(dense_norm.weight.detach().float())
            bridge_mode = bridge_mode or _is_identity_bridge(dense_norm)
        if hasattr(layer, "pre_feedforward_layernorm_2"):
            expert_input_vals.append(layer.pre_feedforward_layernorm_2.weight.detach().float())
        router_scale = _router_scale(layer)
        if router_scale is not None:
            router_scale_vals.append(router_scale.detach().float())
        per_expert = _router_per_expert_scale(layer)
        if per_expert is not None:
            per_expert_scale_vals.append(per_expert.detach().float())

    def summarize(prefix: str, vals: list[torch.Tensor]) -> dict[str, float]:
        if not vals:
            return {f"{prefix}_mean": 0.0, f"{prefix}_abs_mean": 0.0, f"{prefix}_abs_max": 0.0}
        flat = torch.cat([v.reshape(-1) for v in vals])
        return {
            f"{prefix}_mean": flat.mean().item(),
            f"{prefix}_abs_mean": flat.abs().mean().item(),
            f"{prefix}_abs_max": flat.abs().max().item(),
        }

    out: dict[str, float] = {}
    out.update(summarize("moe_output_norm", moe_output_vals))
    out.update(summarize("dense_identity_norm", dense_identity_vals))
    if dense_identity_vals:
        flat_dense = torch.cat([v.reshape(-1) for v in dense_identity_vals])
        dense_target = 0.0 if bridge_mode else 1.0
        deviation = (flat_dense - dense_target).abs()
        out["dense_identity_norm_target"] = dense_target
        out["dense_identity_norm_dev_abs_mean"] = deviation.mean().item()
        out["dense_identity_norm_dev_abs_max"] = deviation.max().item()
    out.update(summarize("expert_input_norm", expert_input_vals))
    out.update(summarize("router_scale", router_scale_vals))
    out.update(summarize("router_per_expert_scale", per_expert_scale_vals))
    return out
class RouterRecorder:
    """Capture router outputs from every MoE decoder layer via forward hooks.

    Each record is ``(first_output.float(), third_output)`` from the router's
    tuple output. Per the original comment, that is
    ``(router_probabilities, top_k_index)`` for Gemma4TextRouter. The stored
    probabilities are not detached. Setting ``recording = False`` pauses capture
    while keeping the hooks registered.
    """

    def __init__(self):
        self.records: list[tuple[torch.Tensor, torch.Tensor]] = []
        self.handles = []
        self.recording = True

    def clear(self) -> None:
        """Drop captured records in place; hooks stay attached."""
        del self.records[:]

    def attach(self, model: nn.Module) -> None:
        """Register a forward hook on ``layer.router`` for each MoE decoder layer."""

        def _on_router_output(_module, _inputs, output):
            # Gemma4TextRouter returns (router_probabilities, top_k_weights, top_k_index).
            # The hook returns None, so the router output is not modified.
            if self.recording and isinstance(output, tuple) and len(output) >= 3:
                probs, _weights, top_idx = output[:3]
                self.records.append((probs.float(), top_idx))

        self.handles.extend(
            layer.router.register_forward_hook(_on_router_output) for layer in iter_moe_decoder_layers(model)
        )

    def close(self) -> None:
        """Remove every registered hook, then forget all records."""
        for handle in self.handles:
            handle.remove()
        del self.handles[:]
        self.clear()
def router_load_balancing_loss(records: list[tuple[torch.Tensor, torch.Tensor]], num_experts: int) -> torch.Tensor:
    """Switch-Transformer-style auxiliary load-balancing loss, averaged over router calls.

    For each record ``(probs, top_idx)``:
      * ``P`` is the mean router probability per expert;
      * ``f`` is the fraction of top-k routing slots assigned to each expert;
      * the per-record loss is ``num_experts * sum(P * f)``, which equals 1.0 when
        both are uniform.

    Both tensors are first flattened over all leading dimensions: ``probs`` to
    ``(-1, num_experts)`` and ``top_idx`` to ``(-1,)``. ``[tokens, E]`` /
    ``[tokens, k]`` inputs therefore behave as before, while batched inputs and
    squeezed 1-D top-1 indices are handled correctly instead of producing a
    wrong ``f``. Indices are cast to int64, which ``F.one_hot`` requires.

    Records with empty ``probs`` are skipped. If nothing remains (or no records
    were given), a zero scalar is returned.
    """
    if not records:
        return torch.zeros((), device="cuda" if torch.cuda.is_available() else "cpu")
    losses = []
    for probs, top_idx in records:
        if probs.numel() == 0:
            continue
        mean_prob = probs.float().reshape(-1, num_experts).mean(dim=0)  # [E]
        assignments = F.one_hot(top_idx.reshape(-1).long(), num_classes=num_experts).float()
        slot_fraction = assignments.mean(dim=0)  # [E]
        losses.append(num_experts * torch.sum(mean_prob * slot_fraction))
    if not losses:
        return torch.zeros((), device=records[0][0].device)
    return torch.stack(losses).mean()
def router_entropy_metric(records: list[tuple[torch.Tensor, torch.Tensor]]) -> torch.Tensor:
    """Mean Shannon entropy (nats) of router probability rows, averaged over router calls.

    For each record, entropy is taken over the last dimension of ``probs``
    (clamped at 1e-9 before the log) and averaged over all other positions. The
    per-record values are then averaged.

    Records with empty ``probs`` are skipped, matching
    ``router_load_balancing_loss``. Otherwise an empty tensor would contribute a
    NaN mean that poisons the whole metric. Returns a zero scalar when there is
    nothing to average.
    """
    if not records:
        return torch.zeros((), device="cuda" if torch.cuda.is_available() else "cpu")
    ents = []
    for probs, _ in records:
        if probs.numel() == 0:
            continue
        p = probs.float().clamp_min(1e-9)
        ents.append(-(p * p.log()).sum(dim=-1).mean())
    if not ents:
        return torch.zeros((), device=records[0][0].device)
    return torch.stack(ents).mean()
def branch_clip_regularizer(model: nn.Module, args: argparse.Namespace) -> torch.Tensor:
    """Soft penalty keeping expert-output gates away from their hard clip.

    Each MoE layer's ``post_feedforward_layernorm_2.weight`` contributes
    ``mean(relu(|w| - 0.9 * moe_branch_norm_clip) ** 2)``, and the per-layer
    values are averaged.

    A zero scalar on the first parameter's device is returned when the
    regularizer is disabled (non-positive ``branch_clip_reg_coef`` or clip) or
    there is nothing to penalise. ``branch_clip_reg_coef`` acts only as an on/off
    switch here; any scaling by it presumably happens at the call site.
    """

    def _zero() -> torch.Tensor:
        return torch.zeros((), device=next(model.parameters()).device)

    if args.branch_clip_reg_coef <= 0 or args.moe_branch_norm_clip <= 0:
        return _zero()
    soft_limit = args.moe_branch_norm_clip * 0.90
    penalties = [
        torch.relu(layer.post_feedforward_layernorm_2.weight.float().abs() - soft_limit).pow(2).mean()
        for layer in iter_moe_decoder_layers(model)
        if hasattr(layer, "post_feedforward_layernorm_2")
    ]
    return torch.stack(penalties).mean() if penalties else _zero()
def norm_anchor_regularizer(model: nn.Module, args: argparse.Namespace) -> torch.Tensor:
    """Keep the newly inserted multiplicative controls close to their safe defaults.

    The expert-output norm is deliberately not anchored, because it is the
    learned additive-branch gate. The other controls are pulled toward identity
    with these per-layer mean-squared-deviation terms:

      * dense-side norm -> 0.0 for the identity bridge, 1.0 otherwise (weight 1.0)
      * pre-expert norm -> 1.0 (weight 0.25)
      * router input scale -> 1.0 (weight 0.25)
      * per-expert scale -> ``router_per_expert_scale_init`` (weight 0.10)

    All terms are averaged together. A zero scalar on the first parameter's
    device is returned when ``norm_anchor_reg_coef <= 0`` or there are no terms.
    """

    def _zero() -> torch.Tensor:
        return torch.zeros((), device=next(model.parameters()).device)

    def _sq_dev(tensor: torch.Tensor, target: float) -> torch.Tensor:
        return (tensor.float() - target).pow(2).mean()

    if args.norm_anchor_reg_coef <= 0:
        return _zero()
    terms: list[torch.Tensor] = []
    for layer in iter_moe_decoder_layers(model):
        if hasattr(layer, "post_feedforward_layernorm_1"):
            dense_norm = layer.post_feedforward_layernorm_1
            identity_target = 0.0 if _is_identity_bridge(dense_norm) else 1.0
            terms.append(_sq_dev(dense_norm.weight, identity_target))
        if hasattr(layer, "pre_feedforward_layernorm_2"):
            terms.append(0.25 * _sq_dev(layer.pre_feedforward_layernorm_2.weight, 1.0))
        router_scale = _router_scale(layer)
        if router_scale is not None:
            terms.append(0.25 * _sq_dev(router_scale, 1.0))
        per_expert = _router_per_expert_scale(layer)
        if per_expert is not None:
            terms.append(0.10 * _sq_dev(per_expert, args.router_per_expert_scale_init))
    return torch.stack(terms).mean() if terms else _zero()
def preflight_forward(model: nn.Module, tokenizer: Any, device: torch.device) -> dict[str, Any]:
    """Run one tiny no-grad forward pass with labels and require a finite loss.

    Encodes a fixed prompt, truncates it to 64 tokens and moves it to ``device``.
    The model is run in eval mode under bf16 autocast for ``device``'s type
    (previously hard-coded to "cuda").

    The caller's train/eval mode is restored in a ``finally``, so it also comes
    back when the forward raises or the loss check fails.

    Returns:
        ``{"preflight_seq_len": <tokens used>, "preflight_loss": <float loss>}``.

    Raises:
        RuntimeError: if the loss is NaN or infinite. A missing loss counts as NaN.
    """
    was_training = model.training
    model.eval()
    try:
        ids = tokenizer.encode(
            "Sanity check: write a one-line Python function.",
            add_special_tokens=True,
            return_tensors="pt",
        )
        ids = ids[:, :64].to(device)
        attention_mask = torch.ones_like(ids)
        autocast_device = torch.device(device).type
        with torch.no_grad(), torch.autocast(device_type=autocast_device, dtype=torch.bfloat16):
            out = model(input_ids=ids, attention_mask=attention_mask, labels=ids, return_dict=True)
        loss = out.loss.detach().float().item() if out.loss is not None else float("nan")
        if not math.isfinite(loss):
            raise RuntimeError(f"Preflight forward produced non-finite loss: {loss}")
    finally:
        if was_training:
            model.train()
    return {"preflight_seq_len": int(ids.shape[-1]), "preflight_loss": loss}
@contextlib.contextmanager
def temporarily_set_moe_enabled(model: nn.Module, enabled: bool):
    """Temporarily set ``enable_moe_block`` on every MoE decoder layer that has it.

    Layers without an ``enable_moe_block`` attribute are left untouched.
    Original values are restored on exit, including when the body raises. The
    toggling loop itself runs inside the ``try``, so an interruption part-way
    through still restores the layers already toggled.

    Yields:
        The number of layers whose flag was set. 0 means the context had no
        effect at all, which callers can use to detect a no-op toggle.
    """
    old_values: list[tuple[nn.Module, bool]] = []
    try:
        for layer in iter_moe_decoder_layers(model):
            if hasattr(layer, "enable_moe_block"):
                old_values.append((layer, bool(layer.enable_moe_block)))
                layer.enable_moe_block = enabled
        yield len(old_values)
    finally:
        for layer, old in old_values:
            layer.enable_moe_block = old
@torch.no_grad()
def dense_equivalence_check(model: nn.Module, tokenizer: Any, args: argparse.Namespace, device: torch.device) -> dict[str, Any]:
    """Compare initialised MoE-on logits against the same model with MoE disabled.

    This avoids loading a second 31B model. The same prompt is run twice:
      * with ``enable_moe_block`` switched off on the MoE decoder layers;
      * with it switched on.
    With the identity-residual dense bridge and a near-zero expert output gate,
    both passes should agree closely before training.

    Statistics use the last (up to) 8 positions: max and mean |logit diff|,
    mean KL(dense || moe), and RMSE relative to the dense logits' RMS. Autocast
    uses ``device``'s type, and the caller's train/eval mode is restored even if
    a forward raises.

    The result also reports ``dense_equivalence_toggleable_layers``, the number
    of MoE layers exposing ``enable_moe_block``. If it is 0 the toggle is a no-op,
    both passes run the same path and the comparison is vacuous. A warning is
    printed in that case instead of silently reporting a perfect match.

    Raises:
        RuntimeError: if the max |logit diff| or the mean KL exceeds its
            configured threshold.
    """
    toggleable = sum(1 for layer in iter_moe_decoder_layers(model) if hasattr(layer, "enable_moe_block"))
    if toggleable == 0:
        print(
            "[dense_equivalence] WARNING: no MoE layer exposes enable_moe_block; the MoE-on and "
            "MoE-off passes are identical, so this check cannot detect a divergence."
        )

    was_training = model.training
    model.eval()
    try:
        prompt = (
            "You are a software engineering assistant. Write a minimal Python function "
            "that returns the SHA256 hex digest of a string."
        )
        ids = tokenizer.encode(prompt, add_special_tokens=True, return_tensors="pt")[:, : args.dense_equivalence_seq_len].to(device)
        attn = torch.ones_like(ids)
        with torch.autocast(device_type=torch.device(device).type, dtype=torch.bfloat16):
            with temporarily_set_moe_enabled(model, False):
                dense_logits = model(input_ids=ids, attention_mask=attn, return_dict=True).logits.float()
            with temporarily_set_moe_enabled(model, True):
                moe_logits = model(input_ids=ids, attention_mask=attn, return_dict=True).logits.float()
    finally:
        if was_training:
            model.train()

    keep = min(8, dense_logits.shape[1])
    dense_slice = dense_logits[:, -keep:, :]
    moe_slice = moe_logits[:, -keep:, :]
    diff = (moe_slice - dense_slice).abs()
    dense_prob = torch.softmax(dense_slice, dim=-1)
    moe_log_prob = torch.log_softmax(moe_slice, dim=-1)
    dense_log_prob = torch.log_softmax(dense_slice, dim=-1)
    kl = (dense_prob * (dense_log_prob - moe_log_prob)).sum(dim=-1).mean()
    dense_rms = dense_slice.pow(2).mean().sqrt().clamp_min(1e-8)
    rel_rmse = ((moe_slice - dense_slice).pow(2).mean().sqrt() / dense_rms).detach()
    result = {
        "dense_equivalence_positions": int(keep),
        "dense_equivalence_max_abs_logit_diff": float(diff.max().detach().cpu()),
        "dense_equivalence_mean_abs_logit_diff": float(diff.mean().detach().cpu()),
        "dense_equivalence_mean_kl_dense_to_moe": float(kl.detach().cpu()),
        "dense_equivalence_rel_rmse": float(rel_rmse.cpu()),
        "dense_equivalence_toggleable_layers": int(toggleable),
        "threshold_max_abs_logit_diff": args.dense_equivalence_max_abs_logit_diff,
        "threshold_mean_kl": args.dense_equivalence_max_mean_kl,
    }
    if result["dense_equivalence_max_abs_logit_diff"] > args.dense_equivalence_max_abs_logit_diff:
        raise RuntimeError("Upcycled model differs too much from MoE-disabled dense path at initialization: " + json.dumps(result, indent=2))
    if result["dense_equivalence_mean_kl_dense_to_moe"] > args.dense_equivalence_max_mean_kl:
        raise RuntimeError("Upcycled model KL divergence from MoE-disabled dense path is too high at initialization: " + json.dumps(result, indent=2))
    return result
# Source of the trust_remote_code module that write_custom_model_code() writes
# next to saved checkpoints. Its DenseBranchIdentityRMSNormBridge and the
# dense-norm patching should stay in sync with the training-time versions in
# this file, so a reloaded checkpoint runs the same dense-side computation.
# Kept as a properly indented triple-quoted literal: the embedded module must
# compile when Transformers imports it.
CUSTOM_MODEL_CODE = '''"""Custom loader for Gemma4 additive-MoE upcycled checkpoints.

This file preserves the identity-residual dense-side MoE bridge used during
training. Load the saved model with trust_remote_code=True.
"""

import torch
from torch import nn
try:
    from transformers import Gemma4ForConditionalGeneration
except ImportError:  # pragma: no cover - compatibility with dev/nightly Transformers exports
    from transformers.models.gemma4.modeling_gemma4 import Gemma4ForConditionalGeneration


class DenseBranchIdentityRMSNormBridge(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6, init: float = 0.0):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.full((dim,), float(init)))

    def _norm(self, hidden_states: torch.Tensor) -> torch.Tensor:
        mean_squared = hidden_states.float().pow(2).mean(-1, keepdim=True) + self.eps
        return (hidden_states.float() * torch.pow(mean_squared, -0.5)).type_as(hidden_states)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        normed = self._norm(hidden_states)
        alpha = self.weight.float().view(*([1] * (hidden_states.ndim - 1)), -1)
        return (hidden_states.float() + alpha * (normed.float() - hidden_states.float())).type_as(hidden_states)


def _get_text_config(config):
    return getattr(config, "text_config", config)


def _iter_moe_decoder_layers(model):
    for module in model.modules():
        if hasattr(module, "router") and hasattr(module, "experts") and hasattr(module, "post_feedforward_layernorm_2"):
            yield module


def patch_dense_branch_norms(model):
    text_config = _get_text_config(model.config)
    mode = getattr(text_config, "dense_branch_norm_mode", "identity_residual_rmsnorm")
    if mode != "identity_residual_rmsnorm":
        return 0
    init = float(getattr(text_config, "dense_branch_norm_init", 0.0))
    count = 0
    for layer in _iter_moe_decoder_layers(model):
        old = getattr(layer, "post_feedforward_layernorm_1", None)
        if old is None or old.__class__.__name__ == "DenseBranchIdentityRMSNormBridge":
            continue
        hidden = int(old.weight.numel())
        eps = float(getattr(old, "eps", getattr(text_config, "rms_norm_eps", 1e-6)))
        bridge = DenseBranchIdentityRMSNormBridge(hidden, eps=eps, init=init)
        bridge.to(device=old.weight.device, dtype=old.weight.dtype)
        setattr(layer, "post_feedforward_layernorm_1", bridge)
        count += 1
    return count


class Gemma4AdditiveMoEForConditionalGeneration(Gemma4ForConditionalGeneration):
    def __init__(self, config):
        super().__init__(config)
        patch_dense_branch_norms(self)
'''
def write_custom_model_code(output_dir: Path, args: argparse.Namespace) -> None:
    """Make a saved checkpoint loadable through ``trust_remote_code=True``.

    Does nothing unless ``args.dense_branch_norm_mode`` is
    ``"identity_residual_rmsnorm"``. In that mode it:

    * writes ``CUSTOM_MODEL_CODE`` as ``modeling_gemma4_additive_moe_upcycled.py``
      plus an empty ``__init__.py`` into ``output_dir``;
    * stores the dense/MoE branch-norm settings in ``config.json`` (inside
      ``text_config`` when present, otherwise at the top level);
    * points ``architectures`` and four ``auto_map`` entries at the custom class.

    Args:
        output_dir: directory that already holds a ``config.json``.
        args: parsed CLI namespace; reads ``dense_branch_norm_mode``,
            ``dense_branch_norm_init`` and ``moe_branch_norm_init``.
    """
    if args.dense_branch_norm_mode != "identity_residual_rmsnorm":
        # Only the identity-residual bridge needs the custom loader module.
        return

    module_name = "modeling_gemma4_additive_moe_upcycled"
    class_name = "Gemma4AdditiveMoEForConditionalGeneration"

    (output_dir / f"{module_name}.py").write_text(CUSTOM_MODEL_CODE, encoding="utf-8")
    (output_dir / "__init__.py").write_text("", encoding="utf-8")

    config_file = output_dir / "config.json"
    config_dict = json.loads(config_file.read_text(encoding="utf-8"))

    # Same key order as sequential assignment, so the written JSON is unchanged.
    text_section = config_dict.get("text_config", config_dict)
    text_section.update(
        dense_branch_norm_mode=args.dense_branch_norm_mode,
        dense_branch_norm_init=args.dense_branch_norm_init,
        moe_branch_norm_init=args.moe_branch_norm_init,
    )

    config_dict["architectures"] = [class_name]
    auto_map = config_dict.setdefault("auto_map", {})
    remote_target = f"{module_name}.{class_name}"
    for auto_class in (
        "AutoModel",
        "AutoModelForCausalLM",
        "AutoModelForImageTextToText",
        "AutoModelForMultimodalLM",
    ):
        auto_map[auto_class] = remote_target

    config_file.write_text(json.dumps(config_dict, indent=2), encoding="utf-8")
| # ----------------------------- | |
| # Saving / smoke test | |
| # ----------------------------- | |
def save_full_model(model: nn.Module, processor: Any, args: argparse.Namespace, output_dir: Path, step: int) -> None:
    """Save an inference-ready checkpoint plus upcycling metadata to ``output_dir``.

    Writes the model (safetensors, sharded by ``args.max_shard_size``), the
    processor and its tokenizer, and the custom remote-code loader via
    ``write_custom_model_code``. It then writes
    ``additive_moe_upcycling_meta.json`` describing the run configuration.

    ``use_cache`` is forced to True on every reachable config while saving, so
    the exported config enables the KV cache. Each config's previous value is
    restored afterwards, even if saving fails.

    Args:
        model: the (training) model to export.
        processor: processor saved alongside the model; its ``tokenizer``
            attribute, when present, is saved as well.
        args: parsed CLI namespace; recorded into the metadata file.
        output_dir: destination directory, created if missing.
        step: optimizer step recorded in the metadata.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    # Persist the artifact as an inference-ready model. Training disables cache for
    # gradient checkpointing, but generated models should load with cache enabled.
    candidate_configs = [getattr(model, "config", None)]
    if getattr(model, "config", None) is not None:
        candidate_configs.append(getattr(model.config, "text_config", None))
    if hasattr(model, "language_model"):
        candidate_configs.append(getattr(model.language_model, "config", None))

    cache_states = []
    seen_config_ids = set()
    for cfg in candidate_configs:
        if cfg is None or not hasattr(cfg, "use_cache") or id(cfg) in seen_config_ids:
            continue
        # One config object can be reachable through several paths (for example
        # model.config.text_config and model.language_model.config may be the
        # same object). Recording it twice would capture the already-forced True
        # the second time and leave use_cache=True on the training model.
        seen_config_ids.add(id(cfg))
        cache_states.append((cfg, cfg.use_cache))
        cfg.use_cache = True
    try:
        model.save_pretrained(output_dir, safe_serialization=True, max_shard_size=args.max_shard_size)
        processor.save_pretrained(output_dir)
        tokenizer = getattr(processor, "tokenizer", None)
        if tokenizer is not None:
            tokenizer.save_pretrained(output_dir)
        write_custom_model_code(output_dir, args)
    finally:
        # Restore in reverse (LIFO) order so the original values always win.
        for cfg, old_value in reversed(cache_states):
            cfg.use_cache = old_value

    meta = {
        "base_model": args.base_model,
        "step": step,
        "num_experts": args.num_experts,
        "top_k_experts": args.top_k_experts,
        "moe_intermediate_size": args.moe_intermediate_size,
        "frozen_backbone": True,
        "stability_controls": {
            "moe_branch_norm_init": args.moe_branch_norm_init,
            "moe_branch_norm_clip": args.moe_branch_norm_clip,
            "dense_branch_norm_mode": args.dense_branch_norm_mode,
            "dense_branch_norm_init": args.dense_branch_norm_init,
            "dense_branch_norm_clip": args.dense_branch_norm_clip,
            "pre_moe_norm_clip": args.pre_moe_norm_clip,
            "router_input_scale_clip": args.router_input_scale_clip,
            "router_per_expert_scale_clip": args.router_per_expert_scale_clip,
            "norm_anchor_reg_coef": args.norm_anchor_reg_coef,
            "branch_clip_reg_coef": args.branch_clip_reg_coef,
        },
        "trainable_parameter_filter": [
            "router.*",
            "experts.*",
            "post_feedforward_layernorm_1.*",
            "post_feedforward_layernorm_2.*",
            "pre_feedforward_layernorm_2.*",
        ],
        "dataset_mix": [dataclasses.asdict(s) for s in build_dataset_mix(args)],
        "include_nebius_swe_agent": bool(args.include_nebius_swe_agent),
        "include_tau2_tool": bool(args.include_tau2_tool),
        "only_dataset_kinds": args.only_dataset_kinds,
        "exclude_dataset_kinds": args.exclude_dataset_kinds,
        "allow_hf_parquet_fallback": bool(args.allow_hf_parquet_fallback),
        "dataset_load_timeout_sec": int(args.dataset_load_timeout_sec),
        "trainable_param_dtype": args.trainable_param_dtype,
        "script_validation": {
            "load_fails_unless_only_moe_keys_are_missing": True,
            "asserts_all_text_layers_have_router_experts_and_additive_norms": True,
            "asserts_only_moe_parameters_are_trainable": True,
            "uses_identity_residual_dense_bridge": args.dense_branch_norm_mode == "identity_residual_rmsnorm",
            "dense_equivalence_check_is_moe_enabled_vs_disabled": True,
        },
        "saved_at_unix": time.time(),
    }
    (output_dir / "additive_moe_upcycling_meta.json").write_text(json.dumps(meta, indent=2), encoding="utf-8")
def smoke_generate(model: nn.Module, processor: Any, device: torch.device, max_new_tokens: int) -> str:
    """Greedily generate a completion for a fixed coding prompt and return it.

    The prompt is built with ``processor.apply_chat_template``. If that raises,
    it falls back to a plain ``"User: ...\\nAssistant:"`` string encoded with the
    tokenizer.

    Args:
        model: model exposing ``generate``.
        processor: processor (or bare tokenizer) used for prompting and decoding;
            its ``tokenizer`` attribute is used when present.
        device: device the prompt tensors are moved to.
        max_new_tokens: generation length cap.

    Returns:
        The decoded continuation only (prompt tokens stripped). Special tokens
        are kept.

    The model runs in eval mode under ``torch.no_grad()``. Its previous
    train/eval mode is restored afterwards, even if generation fails.
    """
    prompt_text = "Write a concise Python function named safe_divide(a, b) that returns None on division by zero."
    prompt = [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": prompt_text,
                }
            ],
        }
    ]
    tokenizer = getattr(processor, "tokenizer", processor)
    was_training = model.training
    model.eval()
    try:
        try:
            inputs = processor.apply_chat_template(
                prompt,
                tokenize=True,
                add_generation_prompt=True,
                return_dict=True,
                return_tensors="pt",
            )
        except Exception as exc:  # deliberate best-effort fallback; report why
            print(f"[smoke] apply_chat_template failed ({exc!r}); using plain-text prompt", flush=True)
            text = f"User: {prompt_text}\nAssistant:"
            ids = tokenizer.encode(text, return_tensors="pt")
            inputs = {"input_ids": ids, "attention_mask": torch.ones_like(ids)}
        # Drop non-tensor entries that some processors return alongside tensors.
        inputs = {k: v.to(device) for k, v in inputs.items() if torch.is_tensor(v)}
        with torch.no_grad():
            out = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
        input_len = inputs["input_ids"].shape[-1]
        return tokenizer.decode(out[0][input_len:], skip_special_tokens=False)
    finally:
        model.train(was_training)
| # ----------------------------- | |
| # CLI / main training loop | |
| # ----------------------------- | |
def parse_args(argv=None) -> argparse.Namespace:
    """Build the command-line parser and parse options.

    Args:
        argv: optional sequence of argument strings to parse. The default,
            ``None``, parses ``sys.argv[1:]`` exactly as before. Passing a list
            enables programmatic runs and tests.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    p = argparse.ArgumentParser(description="Upcycle google/gemma-4-31B-it into a frozen-backbone additive MoE model.")
    p.add_argument("--base_model", default="google/gemma-4-31B-it")
    p.add_argument("--output_dir", default="./gemma-4-31b-it-additive-moe-upcycled")
    p.add_argument("--hf_token", default=os.environ.get("HF_TOKEN", None))
    p.add_argument("--seed", type=int, default=1337)
    # MoE sizing: fits a single 288GB B300 while leaving room for AdamW states and activations.
    p.add_argument("--num_experts", type=int, default=8)
    p.add_argument("--top_k_experts", type=int, default=2)
    p.add_argument("--moe_intermediate_size", type=int, default=512)
    # Stabilization knobs.
    p.add_argument("--expert_init_std", type=float, default=0.02)
    p.add_argument("--router_init_std", type=float, default=0.002)
    p.add_argument("--router_per_expert_scale_init", type=float, default=1.0)
    p.add_argument("--router_per_expert_scale_clip", type=float, default=2.0)
    p.add_argument("--router_input_scale_clip", type=float, default=0.50)
    p.add_argument("--moe_branch_norm_init", type=float, default=1e-3)
    p.add_argument("--moe_branch_norm_clip", type=float, default=0.35)
    p.add_argument("--dense_branch_norm_mode", choices=["identity_residual_rmsnorm", "native_rmsnorm"], default="identity_residual_rmsnorm")
    p.add_argument("--dense_branch_norm_init", type=float, default=0.0)
    p.add_argument("--dense_branch_norm_clip", type=float, default=0.10)
    p.add_argument("--pre_moe_norm_clip", type=float, default=0.25)
    p.add_argument("--branch_clip_reg_coef", type=float, default=0.01)
    p.add_argument("--norm_anchor_reg_coef", type=float, default=0.02)
    p.add_argument("--router_aux_coef", type=float, default=0.01)
    p.add_argument("--max_init_shrink_steps", type=int, default=8)
    p.add_argument("--init_shrink_factor", type=float, default=0.5)
    p.add_argument("--trainable_param_dtype", choices=["fp32", "bf16"], default="fp32")
    # Training budget. 12k optimizer steps * 32 * 4096 ~= 1.57B training tokens.
    # On a single B300 at the measured ~950 tokens/s, a <24h run is roughly
    # 70-80M tokens, not 1.57B. Use --target_tokens for wall-clock budgeting.
    p.add_argument("--seq_len", type=int, default=4096)
    p.add_argument("--micro_batch_size", type=int, default=1)
    p.add_argument("--grad_accum_steps", type=int, default=32)
    p.add_argument("--max_steps", type=int, default=12_000)
    p.add_argument("--target_tokens", type=int, default=0, help="If >0, override max_steps = ceil(target_tokens / (seq_len * micro_batch_size * grad_accum_steps)).")
    p.add_argument("--warmup_steps", type=int, default=500)
    # argparse %-formats help strings, so a literal percent sign must be written "%%";
    # a bare "10% of" is parsed as an octal conversion and makes --help crash.
    p.add_argument("--auto_shorten_warmup", action=argparse.BooleanOptionalAction, default=True, help="If warmup_steps is too large for a shortened target-token run, cap it to about 10%% of max_steps.")
    p.add_argument("--min_lr_ratio", type=float, default=0.10)
    p.add_argument("--max_grad_norm", type=float, default=1.0)
    # Separate learning rates: scratch experts need the highest LR; router/norms should move slower.
    p.add_argument("--expert_lr", type=float, default=2.0e-4)
    p.add_argument("--router_lr", type=float, default=5.0e-5)
    p.add_argument("--moe_output_norm_lr", type=float, default=1.0e-4)
    p.add_argument("--identity_norm_lr", type=float, default=2.0e-5)
    p.add_argument("--norm_lr", type=float, default=None, help="Optional override for both MoE norm parameter groups.")
    p.add_argument("--weight_decay", type=float, default=0.01)
    p.add_argument("--beta1", type=float, default=0.9)
    p.add_argument("--beta2", type=float, default=0.95)
    p.add_argument("--adam_eps", type=float, default=1e-8)
    # Data/tokenization.
    p.add_argument("--data_mix_preset", choices=["default", "high_signal_24h"], default="default", help="Use high_signal_24h for a sub-24h single-B300 run that spends tokens on current knowledge, code, SWE-agent, and tool-use rather than broad background text.")
    p.add_argument("--max_chars_per_sample", type=int, default=160_000)
    p.add_argument("--min_chars_per_sample", type=int, default=128)
    p.add_argument("--max_tokens_per_sample", type=int, default=32_768)
    p.add_argument("--shuffle_buffer_size", type=int, default=1024)
    p.add_argument("--shuffle_scope", choices=["none", "global", "per_source", "both"], default="global", help="Streaming shuffle placement. 'global' avoids filling one shuffle buffer per source before the first batch.")
    p.add_argument("--interleave_impl", choices=["custom"], default="custom", help="Use the robust Python interleaver; retained as an explicit option for reproducible configs.")
    p.add_argument("--interleave_stopping_strategy", choices=["first_exhausted", "all_exhausted"], default="all_exhausted")
    p.add_argument("--source_probe_examples", type=int, default=1, help="Bounded per-source formatter validation examples before interleave; set 0 to disable.")
    p.add_argument("--source_probe_max_scan_rows", type=int, default=512, help="Maximum raw rows to scan per source during bounded formatter validation.")
    p.add_argument("--drop_sources_without_probe_examples", action=argparse.BooleanOptionalAction, default=True, help="Drop and renormalize sources whose bounded formatter probe yields no usable text.")
    p.add_argument("--include_nebius_swe_agent", action=argparse.BooleanOptionalAction, default=False)
    p.add_argument("--include_tau2_tool", action=argparse.BooleanOptionalAction, default=True)
    p.add_argument("--only_dataset_kinds", default="", help="Debug/testing: comma-separated dataset kind names to include, e.g. tau2_tool")
    p.add_argument("--exclude_dataset_kinds", default="", help="Debug/testing: comma-separated dataset kind names to exclude")
    p.add_argument("--allow_hf_parquet_fallback", action=argparse.BooleanOptionalAction, default=True)
    p.add_argument("--dataset_load_timeout_sec", type=int, default=300, help="Timeout for load_dataset metadata construction; 0 disables.")
    p.add_argument("--data_loader_smoke_test_only", action=argparse.BooleanOptionalAction, default=False)
    p.add_argument("--data_loader_smoke_examples", type=int, default=16)
    p.add_argument("--data_startup_heartbeat_sec", type=float, default=30.0, help="Progress heartbeat while waiting for the first streamed/packed batch; 0 disables.")
    p.add_argument("--packer_startup_log_every_examples", type=int, default=128, help="During startup, log after this many assigned text-stream rows; 0 disables row-count logs.")
    p.add_argument("--num_workers", type=int, default=0)
    # Runtime.
    p.add_argument("--attn_implementation", default="sdpa", choices=["flash_attention_2", "sdpa", "eager", ""])
    p.add_argument("--ignore_mismatched_sizes", action=argparse.BooleanOptionalAction, default=False)
    p.add_argument("--gradient_checkpointing", action=argparse.BooleanOptionalAction, default=True)
    p.add_argument("--allow_no_gradient_checkpointing_oom_risk", action=argparse.BooleanOptionalAction, default=False, help="Required to intentionally disable gradient checkpointing on a single GPU. Without it, the script refuses the known-OOM configuration.")
    p.add_argument("--wall_clock_limit_hours", type=float, default=0.0, help="If >0, stop cleanly after this many wall-clock hours and save the current full model.")
    p.add_argument("--wall_clock_save_margin_minutes", type=float, default=15.0, help="Stop this many minutes before the wall-clock limit to leave time for final save.")
    p.add_argument("--log_every", type=int, default=10)
    p.add_argument("--save_every", type=int, default=0)
    p.add_argument("--max_shard_size", default="5GB")
    p.add_argument("--preflight_forward", action=argparse.BooleanOptionalAction, default=True)
    p.add_argument("--dense_equivalence_check", action=argparse.BooleanOptionalAction, default=True)
    p.add_argument("--dense_equivalence_max_abs_logit_diff", type=float, default=2.0)
    p.add_argument("--dense_equivalence_max_mean_kl", type=float, default=0.05)
    p.add_argument("--dense_equivalence_seq_len", type=int, default=96)
    p.add_argument("--smoke_test", action=argparse.BooleanOptionalAction, default=True)
    p.add_argument("--smoke_test_tokens", type=int, default=96)
    return p.parse_args(argv)
def set_seed(seed: int) -> None:
    """Seed Python's ``random`` module and torch's CPU and all-CUDA-device generators."""
    for seed_fn in (random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)
def normalize_training_budget(args: argparse.Namespace) -> None:
    """Resolve the effective optimizer-step budget in place and print it.

    * If ``args.target_tokens > 0``, ``args.max_steps`` is replaced by
      ``ceil(target_tokens / (seq_len * micro_batch_size * grad_accum_steps))``
      (at least 1).
    * If ``args.auto_shorten_warmup`` is enabled (default True when the
      attribute is missing) and ``warmup_steps >= max_steps``, ``warmup_steps``
      is capped to ``ceil(0.10 * max_steps)`` (at least 1).

    Raises:
        ValueError: if any of ``seq_len``, ``micro_batch_size`` or
            ``grad_accum_steps`` is not positive.
    """
    seq_len = int(args.seq_len)
    micro_batch_size = int(args.micro_batch_size)
    grad_accum_steps = int(args.grad_accum_steps)
    # Validate each factor: checking only the product lets two negatives through.
    if min(seq_len, micro_batch_size, grad_accum_steps) <= 0:
        raise ValueError("seq_len, micro_batch_size, and grad_accum_steps must be positive.")
    tokens_per_optimizer_step = seq_len * micro_batch_size * grad_accum_steps

    if int(getattr(args, "target_tokens", 0) or 0) > 0:
        old_steps = int(args.max_steps)
        args.max_steps = max(1, math.ceil(int(args.target_tokens) / tokens_per_optimizer_step))
        print(
            "[budget] target_tokens override: "
            f"target_tokens={int(args.target_tokens):,} tokens_per_optimizer_step={tokens_per_optimizer_step:,} "
            f"max_steps {old_steps:,} -> {args.max_steps:,}",
            flush=True,
        )

    if getattr(args, "auto_shorten_warmup", True) and int(args.warmup_steps) >= int(args.max_steps):
        old_warmup = int(args.warmup_steps)
        args.warmup_steps = max(1, min(old_warmup, math.ceil(0.10 * int(args.max_steps))))
        print(
            f"[budget] warmup shortened for truncated run: warmup_steps {old_warmup:,} -> {args.warmup_steps:,}",
            flush=True,
        )

    total_tokens = tokens_per_optimizer_step * int(args.max_steps)
    print(
        "[budget] effective training budget: "
        f"seq_len={args.seq_len} micro_batch_size={args.micro_batch_size} "
        f"grad_accum_steps={args.grad_accum_steps} max_steps={args.max_steps:,} "
        f"total_tokens={total_tokens:,} warmup_steps={args.warmup_steps:,} "
        f"data_mix_preset={getattr(args, 'data_mix_preset', 'default')}",
        flush=True,
    )
def main() -> None:
    """Run the full additive-MoE upcycling pipeline end to end.

    In order: parse and normalize CLI args; load the processor/tokenizer;
    optionally run a CPU-only dataset-loader smoke test and return; require CUDA;
    mutate the config and load the dense checkpoint into the MoE architecture;
    freeze the dense backbone and stabilize the fresh MoE parameters; run the
    optional dense-equivalence and preflight checks; train with gradient
    accumulation; then save the final model and optionally a smoke generation.
    """
    args = parse_args()
    # Must be set before CUDA allocation for best effect; also export it in the shell.
    os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
    # May rewrite args.max_steps / args.warmup_steps in place (--target_tokens).
    normalize_training_budget(args)
    set_seed(args.seed)
    os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.set_float32_matmul_precision("high")
    print("[init] loading processor/tokenizer")
    processor = AutoProcessor.from_pretrained(args.base_model, token=args.hf_token or None)
    tokenizer = getattr(processor, "tokenizer", None)
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(args.base_model, token=args.hf_token or None)
    # Packing/padding needs a pad id; reuse EOS when the tokenizer has none.
    if getattr(tokenizer, "pad_token_id", None) is None:
        tokenizer.pad_token = tokenizer.eos_token
    if args.data_loader_smoke_test_only:
        # Corpus-only check: prints previews of the first usable rows and returns
        # before any model weights are loaded (works without CUDA).
        print("[data] running dataset-loader smoke test only; model weights will not be loaded")
        smoke_stream = load_streaming_mix(args, processor if hasattr(processor, "apply_chat_template") else tokenizer)
        print("[data] smoke test: pulling examples from lazy mixed stream", flush=True)
        seen = 0
        for row in smoke_stream:
            text = row.get("text") if isinstance(row, dict) else None
            if not text:
                continue
            preview = re.sub(r"\s+", " ", text[:300]).strip()
            print(json.dumps({"sample": seen, "chars": len(text), "preview": preview}, ensure_ascii=False))
            seen += 1
            if seen >= args.data_loader_smoke_examples:
                break
        if seen <= 0:
            raise RuntimeError("Dataset smoke test produced no usable text samples.")
        print(f"[data] dataset-loader smoke test complete: {seen} usable sample(s)")
        return
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for model training. Use --data_loader_smoke_test_only to test only the corpus loader on CPU.")
    device = torch.device("cuda:0")
    # Guard against the known single-GPU OOM configuration unless explicitly allowed.
    if (
        not args.gradient_checkpointing
        and torch.cuda.device_count() == 1
        and not args.allow_no_gradient_checkpointing_oom_risk
    ):
        raise RuntimeError(
            "Refusing to run a single-GPU no-gradient-checkpointing configuration. "
            "Your B300 OOM log showed the non-checkpointed backward pass using ~254 GiB "
            "and failing on an additional 4 GiB allocation. Keep --gradient_checkpointing, "
            "or pass --allow_no_gradient_checkpointing_oom_risk if you intentionally want to test it."
        )
    print("[init] mutating Gemma 4 config for native additive MoE")
    config = mutate_config_for_additive_moe(args)
    text_config = get_text_config(config)
    print(
        json.dumps(
            {
                "enable_moe_block": text_config.enable_moe_block,
                "num_experts": text_config.num_experts,
                "top_k_experts": text_config.top_k_experts,
                "moe_intermediate_size": text_config.moe_intermediate_size,
                "dense_branch_norm_mode": getattr(text_config, "dense_branch_norm_mode", None),
                "hidden_size": text_config.hidden_size,
                "num_hidden_layers": text_config.num_hidden_layers,
            },
            indent=2,
        )
    )
    print("[init] loading dense checkpoint into modified architecture")
    model = load_upcycled_model(args, config)
    model.to(device)
    assert_gemma4_additive_moe_contract(model, args)
    patched_bridges = patch_dense_branch_norms_for_upcycling(model, args)
    if patched_bridges:
        print(f"[model] Replaced {patched_bridges} dense-side MoE RMSNorms with identity-preserving bridges.")
    if args.gradient_checkpointing:
        print("[init] enabling gradient checkpointing")
        # The KV cache is not used while training with checkpointing;
        # save_full_model temporarily re-enables it for exported configs.
        if hasattr(model, "config"):
            model.config.use_cache = False
        if hasattr(model, "language_model") and hasattr(model.language_model, "config"):
            model.language_model.config.use_cache = False
        # The frozen embeddings produce no grad, so checkpointed segments need
        # inputs that require grad for backward to reach the MoE parameters.
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        try:
            model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
        except TypeError:
            # Older APIs without gradient_checkpointing_kwargs.
            model.gradient_checkpointing_enable()
    total_params, trainable_params = freeze_dense_unfreeze_moe(model)
    assert_only_moe_parameters_trainable(model)
    trainable_dtype_counts = cast_trainable_parameters(model, args)
    print("[init] stabilizing fresh MoE branch")
    moe_layers = stabilize_new_moe_parameters(model, args)
    if moe_layers <= 0:
        raise RuntimeError("No MoE decoder layers found after config mutation.")
    # Cross-check the trainable count against the analytic expectation (1% or
    # 1024-parameter tolerance) to catch upstream naming/implementation changes.
    expected_trainable = expected_moe_trainable_parameters(text_config, args)
    tolerance = max(1024, int(0.01 * expected_trainable))
    if abs(trainable_params - expected_trainable) > tolerance:
        raise RuntimeError(
            f"Trainable parameter count mismatch: got {trainable_params:,}, expected about {expected_trainable:,}. "
            "This usually means the Gemma 4 parameter names or MoE implementation changed."
        )
    layout_info = verify_moe_parameter_layout(model, args)
    print(
        json.dumps(
            {
                "moe_layers": moe_layers,
                "total_params": total_params,
                "trainable_params": trainable_params,
                "expected_trainable_params": expected_trainable,
                "trainable_params_billion": round(trainable_params / 1e9, 4),
                "frozen_params_billion": round((total_params - trainable_params) / 1e9, 4),
                "trainable_param_dtype_counts": trainable_dtype_counts,
                **layout_info,
                **moe_branch_scale_summary(model),
            },
            indent=2,
        )
    )
    if args.dense_equivalence_check:
        print("[preflight] checking initialized upcycled logits against the dense base")
        print(json.dumps(dense_equivalence_check(model, tokenizer, args, device), indent=2))
    if args.preflight_forward:
        print("[preflight] running one text-only causal-LM forward pass")
        print(json.dumps(preflight_forward(model, tokenizer, device), indent=2))
    gc.collect()
    torch.cuda.empty_cache()
    optimizer = build_optimizer(model, args)

    def lr_lambda(step: int) -> float:
        # Linear warmup (step 0 still gets 1/warmup_steps), then cosine decay
        # from 1.0 down to min_lr_ratio at max_steps; held there afterwards.
        if step < args.warmup_steps:
            return max(float(step), 1.0) / max(float(args.warmup_steps), 1.0)
        progress = min(1.0, float(step - args.warmup_steps) / max(float(args.max_steps - args.warmup_steps), 1.0))
        cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
        return args.min_lr_ratio + (1.0 - args.min_lr_ratio) * cosine

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
    print("[data] building streamed mixed corpus", flush=True)
    text_stream = load_streaming_mix(args, processor if hasattr(processor, "apply_chat_template") else tokenizer)
    print("[data] wrapping lazy mixed stream in token packer", flush=True)
    packed = PackedTokenDataset(
        text_stream=text_stream,
        tokenizer=tokenizer,
        seq_len=args.seq_len,
        max_tokens_per_sample=args.max_tokens_per_sample,
        seed=args.seed,
        startup_log_every_examples=args.packer_startup_log_every_examples,
        startup_heartbeat_sec=args.data_startup_heartbeat_sec,
    )
    print("[data] constructing DataLoader", flush=True)
    loader = DataLoader(
        packed,
        batch_size=args.micro_batch_size,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True,
    )
    print(
        "[data] DataLoader constructed; first batch will trigger lazy interleave, "
        "streaming shuffle-buffer fill, filtering, tokenization, and packing",
        flush=True,
    )
    # Hooks that capture router outputs for the aux/entropy terms; recording is
    # toggled per micro-step below.
    recorder = RouterRecorder()
    recorder.attach(model)
    model.train()
    optimizer.zero_grad(set_to_none=True)
    global_step = 0
    micro_step = 0
    tokens_seen = 0
    # Running sums over micro-steps since the last log line.
    running_ce = 0.0
    running_aux = 0.0
    running_reg = 0.0
    running_norm_reg = 0.0
    running_entropy = 0.0
    # NOTE(review): start_time is taken before the first batch is fetched, so the
    # logged tokens_per_sec includes data-pipeline startup time.
    start_time = time.time()
    wall_clock_stop_after_sec = 0.0
    if float(args.wall_clock_limit_hours or 0.0) > 0.0:
        # Stop early enough to leave the save margin for the final checkpoint write.
        wall_clock_stop_after_sec = max(1.0, float(args.wall_clock_limit_hours) * 3600.0 - float(args.wall_clock_save_margin_minutes) * 60.0)
        print(
            f"[budget] wall-clock limit enabled: requested={args.wall_clock_limit_hours:.2f}h "
            f"save_margin={args.wall_clock_save_margin_minutes:.1f}min; "
            f"will stop training after about {wall_clock_stop_after_sec/3600.0:.2f}h and save final model",
            flush=True,
        )
    pbar = tqdm(total=args.max_steps, desc="optimizer steps", dynamic_ncols=True)
    print("[data] creating DataLoader iterator", flush=True)
    data_iter = iter(loader)
    print("[data] DataLoader iterator created; fetching first packed microbatch now", flush=True)
    first_batch_logged = False
    first_batch_fetch_start = time.monotonic()
    while global_step < args.max_steps:
        try:
            if not first_batch_logged:
                # The first fetch can be slow; emit periodic heartbeats meanwhile.
                with WaitHeartbeat(
                    "[data] still waiting for first packed microbatch "
                    "(normal causes: streaming shuffle fill, filtering, tokenizer startup, packing)",
                    args.data_startup_heartbeat_sec,
                ):
                    batch = next(data_iter)
                first_batch_logged = True
                print(
                    "[data] first packed microbatch fetched in "
                    f"{time.monotonic() - first_batch_fetch_start:.1f}s; "
                    f"shape={tuple(batch['input_ids'].shape)}",
                    flush=True,
                )
            else:
                batch = next(data_iter)
        except StopIteration:
            # NOTE(review): if the rebuilt iterator is also empty, this next()
            # raises StopIteration out of main.
            print("[data] DataLoader exhausted; rebuilding iterator", flush=True)
            data_iter = iter(loader)
            batch = next(data_iter)
        batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
        tokens_seen += int(batch["input_ids"].numel())
        recorder.clear()
        recorder.recording = True
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            outputs = model(**batch, return_dict=True)
        ce_loss = outputs.loss
        if ce_loss is None:
            # Fallback next-token CE when the model did not compute a loss:
            # shift by one and ignore -100 labels.
            logits = outputs.logits
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = batch["labels"][..., 1:].contiguous()
            ce_loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
            )
        aux = router_load_balancing_loss(recorder.records, args.num_experts)
        entropy = router_entropy_metric(recorder.records)
        # Avoid capturing duplicate graphs from gradient-checkpointing recompute.
        recorder.recording = False
        reg = branch_clip_regularizer(model, args)
        norm_reg = norm_anchor_regularizer(model, args)
        loss = (
            ce_loss
            + args.router_aux_coef * aux
            + args.branch_clip_reg_coef * reg
            + args.norm_anchor_reg_coef * norm_reg
        )
        # Scale so the accumulated gradient is the mean over the accumulation window.
        loss = loss / args.grad_accum_steps
        loss.backward()
        # Clear any checkpoint-recompute hook records immediately to avoid holding graphs.
        recorder.clear()
        micro_step += 1
        running_ce += float(ce_loss.detach().cpu())
        running_aux += float(aux.detach().cpu())
        running_reg += float(reg.detach().cpu())
        running_norm_reg += float(norm_reg.detach().cpu())
        running_entropy += float(entropy.detach().cpu())
        if micro_step % args.grad_accum_steps == 0:
            # Optimizer-step boundary: clip, step, advance LR, reset grads, then
            # re-apply the MoE control clamps to the updated parameters.
            if args.max_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_([p for p in model.parameters() if p.requires_grad], args.max_grad_norm)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad(set_to_none=True)
            clamp_moe_controls(model, args)
            global_step += 1
            pbar.update(1)
            # NOTE(review): --log_every 0 would raise ZeroDivisionError here;
            # parse_args does not reject it.
            if global_step % args.log_every == 0:
                elapsed = max(time.time() - start_time, 1e-6)
                lr_by_group = {f"lr_{g.get('name', f'group_{i}')}": g["lr"] for i, g in enumerate(optimizer.param_groups)}
                mem = torch.cuda.max_memory_allocated() / (1024**3)
                # Sums are reset at every log, so dividing by log_every * grad_accum_steps
                # gives per-micro-step means over this logging window.
                log = {
                    "step": global_step,
                    "ce_loss": running_ce / args.log_every / args.grad_accum_steps,
                    "router_aux": running_aux / args.log_every / args.grad_accum_steps,
                    "branch_clip_reg": running_reg / args.log_every / args.grad_accum_steps,
                    "norm_anchor_reg": running_norm_reg / args.log_every / args.grad_accum_steps,
                    "router_entropy": running_entropy / args.log_every / args.grad_accum_steps,
                    "tokens_seen": tokens_seen,
                    "tokens_per_sec": tokens_seen / elapsed,
                    "cuda_max_allocated_gib": mem,
                    **lr_by_group,
                    **moe_branch_scale_summary(model),
                }
                print(json.dumps(log, sort_keys=True))
                running_ce = 0.0
                running_aux = 0.0
                running_reg = 0.0
                running_norm_reg = 0.0
                running_entropy = 0.0
                torch.cuda.reset_peak_memory_stats()
            if args.save_every > 0 and global_step % args.save_every == 0:
                ckpt_dir = Path(args.output_dir) / f"checkpoint-{global_step}"
                print(f"[save] {ckpt_dir}")
                save_full_model(model, processor, args, ckpt_dir, global_step)
                gc.collect()
                torch.cuda.empty_cache()
            # Checked only at optimizer-step boundaries, so no partially
            # accumulated gradients are dropped when stopping.
            if wall_clock_stop_after_sec > 0.0 and (time.time() - start_time) >= wall_clock_stop_after_sec:
                print(
                    f"[budget] wall-clock training budget reached at step={global_step}; "
                    "stopping cleanly and saving final model",
                    flush=True,
                )
                break
    pbar.close()
    final_dir = Path(args.output_dir)
    print(f"[save] final model -> {final_dir}")
    save_full_model(model, processor, args, final_dir, global_step)
    if args.smoke_test:
        print("[smoke] generating deterministic test completion")
        out = smoke_generate(model, processor, device, args.smoke_test_tokens)
        (final_dir / "smoke_generation.txt").write_text(out, encoding="utf-8")
        print(out)
    # NOTE(review): not reached if an exception escapes above (no try/finally).
    recorder.close()
    print("[done]")


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment