Created
March 19, 2026 18:42
-
-
Save AmineDiro/48785832de3d7f9c4e5a9a508399f43f to your computer and use it in GitHub Desktop.
sync weights ok
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
| HOW TO RUN | |
| ----------- | |
| Step 1 – Start a vLLM server with data-parallel support (set --data-parallel-size to the number of DP shards, | |
| and adjust --tensor-parallel-size / --gpu-memory-utilization as needed): | |
| CUDA_VISIBLE_DEVICES=2,3,4,5 VLLM_SERVER_DEV_MODE=1 vllm serve Qwen/Qwen3-4B \ | |
| --data-parallel-size 4 \ | |
| --tensor-parallel-size 1 \ | |
| --weight-transfer-config '{"backend":"nccl"}' \ | |
| --port 8000 \ | |
| --dtype bfloat16 | |
| Step 2 – Configure accelerate for FSDP2 (or DDP for a simpler test): | |
| accelerate config # choose FSDP with the FSDP2 option, or multi-GPU / single-GPU DDP | |
| Step 3 – Launch this script (match --num_processes to the number of training GPUs): | |
| VLLM_SERVER_URL=http://localhost:8000 \ | |
| accelerate launch --num_processes <num_training_gpus> debug_weight_sync_effective.py | |
| Optional env-vars: | |
| VLLM_SERVER_URL – base URL of the vLLM server (default: http://localhost:8000) | |
| MODEL_NAME – HF model id to load (default: Qwen/Qwen3-4B) | |
| N_SYNCS – number of weight-sync rounds (default: 20) | |
| NOISE_SCALE – std of Gaussian noise added (default: 0.5) | |
| ----------- | |
| What this script tests | |
| ---------------------- | |
| Unlike debug_weight_sync.py (which only verifies that the sync machinery runs without errors), | |
| this script verifies that vLLM's model weights are EFFECTIVELY updated after a sync. | |
| For each round it: | |
| 1. Syncs the current model weights to vLLM. | |
| 2. Queries vLLM for logprobs on several fixed prompts (baseline). | |
| 3. Picks a random linear layer and adds Gaussian noise to its weights. | |
| 4. Syncs the modified weights to vLLM. | |
| 5. Queries vLLM for logprobs on the same prompts again. | |
| 6. Compares the two sets of logprobs — they MUST differ, proving the sync was effective. | |
| """ | |
| import asyncio | |
| import os | |
| import random | |
| import sys | |
| import threading | |
| import time | |
| import requests | |
| import torch | |
| from accelerate import Accelerator | |
| from torch.distributed._tensor import DTensor | |
| from transformers import AutoModelForCausalLM | |
| try: | |
| from vllm.distributed.weight_transfer.nccl_engine import ( | |
| NCCLTrainerSendWeightsArgs, | |
| NCCLWeightTransferEngine, | |
| ) | |
| from vllm.utils.network_utils import get_ip, get_open_port | |
| except ImportError as exc: | |
| raise ImportError("vLLM is required for this debug script. Install it with `pip install trl[vllm]`.") from exc | |
# ---------------------------------------------------------------------------
# Configuration (environment variables, falling back to the defaults below)
# ---------------------------------------------------------------------------
# Base URL of the vLLM server; a trailing slash is stripped so paths can be appended.
VLLM_SERVER_URL: str = os.getenv("VLLM_SERVER_URL", "http://localhost:8000").rstrip("/")
# Hugging Face model id loaded by the trainer and named in completion requests.
MODEL_NAME: str = os.getenv("MODEL_NAME", "Qwen/Qwen3-4B")
# Number of perturb-and-sync rounds to run.
N_SYNCS: int = int(os.getenv("N_SYNCS", "20"))
# Standard deviation of the Gaussian noise injected into one layer per round.
NOISE_SCALE: float = float(os.getenv("NOISE_SCALE", "0.5"))

# Prompts whose next-token logprobs are compared before/after each sync.
# Several unrelated prompts (fact, physics, code, story) give a more robust signal.
TEST_PROMPTS = [
    "The capital of France is",
    "In quantum mechanics, the uncertainty principle states that",
    "def fibonacci(n):\n if n <= 1:\n return",
    "Once upon a time, in a land far away, there lived a",
]
| # --------------------------------------------------------------------------- | |
| # Minimal replica of AsyncRolloutWorker (weight-sync surface only) | |
| # --------------------------------------------------------------------------- | |
| class DebugRolloutWorker: | |
| """ | |
| Minimal replica of AsyncRolloutWorker that exercises only the weight-sync | |
| path: _init_weight_transfer, pause/resume, and send_weights. | |
| The public interface and internal variable names are kept identical to the | |
| production class so that diffs are easy to read. | |
| """ | |
| def __init__( | |
| self, | |
| vllm_server_url: str, | |
| weight_names: list, | |
| weight_dtype_names: list, | |
| weight_shapes: list, | |
| rank: int, | |
| ) -> None: | |
| self.vllm_server_url = vllm_server_url | |
| self.rank = rank | |
| # Mirrors AsyncRolloutWorker.__init__ exactly | |
| self._weight_update_info = { | |
| "names": weight_names, | |
| "dtype_names": weight_dtype_names, | |
| "shapes": weight_shapes, | |
| "packed": True, | |
| "is_checkpoint_format": True, | |
| } | |
| self.model_update_group = None | |
| # Asyncio / threading state (mirrors _run / run) | |
| self._loop: asyncio.AbstractEventLoop | None = None | |
| self._stop_event: asyncio.Event | None = None | |
| # Used by main() to know when async init has completed | |
| self._ready_event: threading.Event = threading.Event() | |
| # ------------------------------------------------------------------ | |
| # Thread + asyncio bootstrap (mirrors start / _run / run) | |
| # ------------------------------------------------------------------ | |
| def start(self) -> None: | |
| thread = threading.Thread(target=self._run, daemon=True) | |
| thread.start() | |
| def _run(self) -> None: | |
| loop = asyncio.new_event_loop() | |
| asyncio.set_event_loop(loop) | |
| self._loop = loop | |
| self._stop_event = asyncio.Event() | |
| try: | |
| loop.run_until_complete(self.run(stop_event=self._stop_event)) | |
| except Exception as e: | |
| print(f"[rank {self.rank}] Worker thread failed: {e}", file=sys.stderr) | |
| raise | |
| finally: | |
| loop.close() | |
| async def run(self, stop_event: asyncio.Event | None = None) -> None: | |
| """Mirrors AsyncRolloutWorker.run — calls _init_weight_transfer via asyncio.to_thread.""" | |
| if stop_event is None: | |
| stop_event = asyncio.Event() | |
| # In production this is wrapped in `async with aiohttp.ClientSession()`. | |
| # We skip the session here because there are no generate/score loops. | |
| await asyncio.to_thread(self._init_weight_transfer) | |
| print(f"[rank {self.rank}] Worker: _init_weight_transfer complete.") | |
| # Signal main() that the weight-transfer engine is ready, then idle | |
| # until stop() is called (mirrors the generate/score gather in production). | |
| self._ready_event.set() | |
| await stop_event.wait() | |
| # ------------------------------------------------------------------ | |
| # _init_weight_transfer (mirrors AsyncRolloutWorker._init_weight_transfer exactly) | |
| # ------------------------------------------------------------------ | |
| def _init_weight_transfer(self) -> None: | |
| print(f"[rank {self.rank}] Init weight sync group with vLLM") | |
| response = requests.get(f"{self.vllm_server_url}/get_world_size") | |
| inference_world_size = response.json()["world_size"] | |
| world_size = inference_world_size + 1 | |
| master_address = get_ip() | |
| master_port = get_open_port() | |
| init_info = { | |
| "master_address": master_address, | |
| "master_port": master_port, | |
| "rank_offset": 1, | |
| "world_size": world_size, | |
| } | |
| t_init = threading.Thread( | |
| target=requests.post, | |
| args=(f"{self.vllm_server_url}/init_weight_transfer_engine",), | |
| kwargs={"json": {"init_info": init_info}, "timeout": 120}, | |
| ) | |
| t_init.start() | |
| self.model_update_group = NCCLWeightTransferEngine.trainer_init( | |
| { | |
| "master_address": master_address, | |
| "master_port": master_port, | |
| "world_size": world_size, | |
| } | |
| ) | |
| t_init.join() | |
| print(f"[rank {self.rank}] Init weight sync group with vLLM — done") | |
| # ------------------------------------------------------------------ | |
| # pause / resume / send_weights (mirrors AsyncRolloutWorker exactly) | |
| # ------------------------------------------------------------------ | |
| def pause(self) -> None: | |
| requests.post(f"{self.vllm_server_url}/pause", params={"mode": "wait"}) | |
| def resume(self) -> None: | |
| requests.post(f"{self.vllm_server_url}/resume") | |
| def send_weights(self, iterator) -> None: | |
| if self.model_update_group is None: | |
| return | |
| t_update = threading.Thread( | |
| target=requests.post, | |
| args=(f"{self.vllm_server_url}/update_weights",), | |
| kwargs={"json": {"update_info": self._weight_update_info}, "timeout": 1800}, | |
| ) | |
| t_update.start() | |
| NCCLWeightTransferEngine.trainer_send_weights( | |
| iterator=iterator, | |
| trainer_args=NCCLTrainerSendWeightsArgs(group=self.model_update_group, packed=True), | |
| ) | |
| t_update.join() | |
| def stop(self) -> None: | |
| if self._loop and self._loop.is_running(): | |
| try: | |
| self._loop.call_soon_threadsafe(self._stop_event.set) | |
| except RuntimeError: | |
| pass | |
| # --------------------------------------------------------------------------- | |
| # Helpers that mirror AsyncGRPOTrainer methods | |
| # --------------------------------------------------------------------------- | |
| def _streaming_iter(model): | |
| """ | |
| Mirrors AsyncGRPOTrainer._streaming_iter exactly. | |
| Iterate parameters one at a time. For FSDP2 (DTensor), full_tensor() | |
| all-gathers just this parameter across FSDP ranks, then frees it once | |
| the generator advances — avoiding materialising the full model in memory. | |
| """ | |
| for name, param in model.named_parameters(): | |
| name = name.removeprefix("module.") # DDP/FSDP1 wrapping | |
| full = param.full_tensor() if isinstance(param, DTensor) else param.detach() | |
| yield name, full | |
def _sync_weight(accelerator, model, rollout_worker, rank):
    """
    Mirrors AsyncGRPOTrainer._sync_weight, with print() replacing
    logger.info() and an explicit rank prefix on every line.

    Sequence:
      1. The main process pauses vLLM.
      2. Barrier across all ranks.
      3. The main process streams every parameter to vLLM. The other ranks
         iterate the same generator so FSDP2 ``full_tensor()`` collectives
         stay matched.
      4. Barrier across all ranks.
      5. The main process resumes vLLM.

    If anything in steps 2-4 raises on the main process, vLLM is resumed
    (best effort) before the exception propagates, so the server is not left
    paused.

    Args:
        accelerator: the accelerate ``Accelerator`` (uses ``is_main_process``
            and ``wait_for_everyone``).
        model: the prepared model whose parameters are streamed.
        rollout_worker: the worker on the main process (None on other ranks).
        rank: process index, used only as a log prefix.
    """
    # Same truthiness test the sync has always used to decide who talks to vLLM.
    is_sender = bool(accelerator.is_main_process and rollout_worker)

    t0 = time.time()
    print(f"[rank {rank}] Weight sync: pausing vLLM...")
    if is_sender:
        rollout_worker.pause()
    t_pause = time.time()
    print(f"[rank {rank}] Weight sync: pause took {t_pause - t0:.1f}s, waiting for all ranks...")
    try:
        accelerator.wait_for_everyone()
        t_barrier = time.time()
        print(f"[rank {rank}] Weight sync: transferring weights... (barrier took {t_barrier - t_pause:.1f}s)")
        if is_sender:
            rollout_worker.send_weights(_streaming_iter(model))
        else:
            # Non-rank-0 processes must still participate in full_tensor() collectives for FSDP2.
            for _ in _streaming_iter(model):
                pass
        t_transfer = time.time()
        accelerator.wait_for_everyone()
    except BaseException:
        if is_sender:
            # Best effort: leave the server serving rather than stuck paused.
            try:
                rollout_worker.resume()
            except Exception as resume_exc:
                print(
                    f"[rank {rank}] Weight sync: failed to resume vLLM after error: {resume_exc}",
                    file=sys.stderr,
                )
        raise
    print(f"[rank {rank}] Weight sync: resuming vLLM... (transfer took {t_transfer - t_barrier:.1f}s)")
    if is_sender:
        rollout_worker.resume()
    print(f"[rank {rank}] Weight sync: done. Total {time.time() - t0:.1f}s")
# ---------------------------------------------------------------------------
# vLLM logprob querying
# ---------------------------------------------------------------------------
def get_vllm_logprobs(prompt: str) -> dict:
    """
    Query vLLM's OpenAI-compatible /v1/completions endpoint for the next-token
    logprob distribution on a fixed prompt.

    One greedy token is generated (``temperature=0``, ``max_tokens=1``), and
    the top-5 logprobs at that position are requested.

    Returns a dict with:
      - "token": the greedy-decoded token string ("" if none was returned)
      - "logprob": the logprob of the greedy token (None if none was returned)
      - "top_logprobs": dict of {token_str: logprob} for up to 5 candidates
        ({} if none were returned)

    Raises:
        requests.HTTPError: if the server replies with an error status.
        RuntimeError: if the completion carries no logprobs payload at all.
    """
    response = requests.post(
        f"{VLLM_SERVER_URL}/v1/completions",
        json={
            "model": MODEL_NAME,
            "prompt": prompt,
            "max_tokens": 1,
            "temperature": 0,  # greedy
            "n": 1,
            "logprobs": 5,  # return top-5 logprobs
        },
        timeout=120,
    )
    response.raise_for_status()
    choice = response.json()["choices"][0]
    logprobs_data = choice.get("logprobs")
    if not logprobs_data:
        raise RuntimeError(f"vLLM returned no logprobs for prompt {prompt!r}: {choice}")
    # OpenAI-compatible format: logprobs.tokens, logprobs.token_logprobs, logprobs.top_logprobs
    # Missing or null entries are treated as empty rather than crashing.
    tokens = logprobs_data.get("tokens") or []
    token_logprobs = logprobs_data.get("token_logprobs") or []
    top_logprobs = logprobs_data.get("top_logprobs") or []
    return {
        "token": tokens[0] if tokens else "",
        "logprob": token_logprobs[0] if token_logprobs else None,
        "top_logprobs": (top_logprobs[0] or {}) if top_logprobs else {},  # dict of {token_str: logprob}
    }
def get_all_logprobs(prompts: list[str]) -> list[dict]:
    """Return one ``get_vllm_logprobs`` result per prompt, queried sequentially in prompt order."""
    return [get_vllm_logprobs(text) for text in prompts]
def print_logprobs(label: str, prompts: list[str], results: list[dict]) -> None:
    """
    Pretty-print logprob results.

    For each (prompt, result) pair, prints the first 60 characters of the
    prompt (newlines escaped), the greedy token with its logprob, and the
    three highest-logprob candidates. A missing greedy logprob (None) is
    printed as "n/a" instead of raising.

    Args:
        label: heading printed above the block.
        prompts: prompts in the same order as ``results``.
        results: dicts with "token", "logprob" and "top_logprobs" keys, as
            built by get_vllm_logprobs.
    """
    print(f"\n --- {label} ---")
    for prompt, result in zip(prompts, results):
        snippet = prompt[:60].replace("\n", "\\n")
        top3 = sorted(result["top_logprobs"].items(), key=lambda x: x[1], reverse=True)[:3]
        top3_str = " ".join(f"{tok!r}={lp:.4f}" for tok, lp in top3)
        logprob = result["logprob"]
        logp_str = "n/a" if logprob is None else f"{logprob:.4f}"
        print(f" prompt='{snippet}'")
        print(f" greedy={result['token']!r} (logp={logp_str}) top3: {top3_str}")
def compare_logprobs(before: list[dict], after: list[dict], prompts: list[str], atol: float = 1e-6) -> bool:
    """
    Compare two sets of logprob results. Returns True if they differ (sync was effective).
    Prints a detailed comparison.

    Each element of ``before``/``after`` is a dict with "token" and "logprob"
    keys (as built by get_vllm_logprobs). A prompt counts as changed when any
    of the following holds:
      * its greedy token differs;
      * both logprobs are present and differ by more than ``atol``;
      * the logprob is present on one side only.

    Args:
        before: results for the prompts before the weight change.
        after: results for the same prompts after the weight change.
        prompts: the prompts, used for display. Only as many entries as the
            shortest of the three lists are compared.
        atol: absolute logprob difference above which a prompt counts as changed.

    Returns:
        True if at least one prompt changed.
    """

    def fmt(lp):
        return "n/a" if lp is None else f"{lp:.4f}"

    n_changed = 0
    print("\n --- Comparison ---")
    for prompt, b, a in zip(prompts, before, after):
        snippet = prompt[:60].replace("\n", "\\n")
        before_lp = b["logprob"]
        after_lp = a["logprob"]
        if before_lp is not None and after_lp is not None:
            diff = abs(after_lp - before_lp)
            lp_changed = diff > atol
        else:
            diff = 0.0
            # A logprob appearing or disappearing is itself a change.
            lp_changed = (before_lp is None) != (after_lp is None)
        # A flipped greedy token is a change even if both logprobs are missing.
        changed = lp_changed or b["token"] != a["token"]
        status = "CHANGED" if changed else "SAME"
        print(
            f" [{status}] prompt='{snippet}'"
            f" before: {b['token']!r} (logp={fmt(before_lp)})"
            f" after: {a['token']!r} (logp={fmt(after_lp)})"
            f" |diff|={diff:.6f}"
        )
        if changed:
            n_changed += 1
    print(f"\n Result: {n_changed}/{len(prompts)} prompts show changed logprobs.")
    return n_changed > 0
# ---------------------------------------------------------------------------
# Weight perturbation
# ---------------------------------------------------------------------------
def perturb_random_layer(model, noise_scale: float, rank: int) -> str:
    """
    Pick a random linear layer weight from the model and add Gaussian noise in-place.
    Returns the name of the perturbed parameter (with any "module." prefix removed).

    Candidates are all 2-D parameters whose name does not contain "embed".
    The choice is made from the middle half of the candidates; if that slice
    is empty (fewer than two candidates), the full candidate list is used.
    Selection uses the global ``random`` module, so callers that seed it
    identically on every rank get the same layer everywhere.

    For FSDP2 (DTensor), we modify the local shard. The noise will be
    different on each rank, but that's fine: the full_tensor() in
    _streaming_iter all-gathers the (now noisy) shards before sending them
    to vLLM.

    Args:
        model: the (possibly wrapped) model to modify in place.
        noise_scale: standard deviation of the added Gaussian noise.
        rank: process index, used only as a log prefix.

    Raises:
        RuntimeError: if the model has no suitable 2-D, non-embedding parameter.
    """
    # Collect candidate parameters: only 2D weights (linear layers), skip embeddings
    candidates = []
    for name, param in model.named_parameters():
        clean_name = name.removeprefix("module.")
        if param.ndim == 2 and "embed" not in clean_name.lower():
            candidates.append((clean_name, param))
    if not candidates:
        raise RuntimeError("No suitable linear layers found for perturbation!")
    # Prefer the middle layers (more interesting than first/last). With a single
    # candidate the slice is [0:0], so fall back to the whole list.
    mid_start = len(candidates) // 4
    mid_end = 3 * len(candidates) // 4
    pool = candidates[mid_start:mid_end] or candidates
    chosen_name, chosen_param = random.choice(pool)
    # Add noise in-place
    with torch.no_grad():
        if isinstance(chosen_param, DTensor):
            # For FSDP2: modify the local shard directly
            local_tensor = chosen_param._local_tensor
            noise = torch.randn_like(local_tensor) * noise_scale
            local_tensor.add_(noise)
            print(
                f"[rank {rank}] Perturbed '{chosen_name}' (DTensor local shard "
                f"{list(local_tensor.shape)}, noise_scale={noise_scale})"
            )
        else:
            noise = torch.randn_like(chosen_param) * noise_scale
            chosen_param.add_(noise)
            print(f"[rank {rank}] Perturbed '{chosen_name}' ({list(chosen_param.shape)}, noise_scale={noise_scale})")
    return chosen_name
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main() -> int:
    """
    Run the weight-sync effectiveness test on this training rank.

    Steps:
      1. Load MODEL_NAME and wrap it with ``accelerator.prepare``.
      2. On the main process, start the vLLM weight-transfer worker.
      3. Run N_SYNCS rounds of: sync -> baseline logprobs -> perturb one
         layer -> sync -> new logprobs -> compare. Logprob queries and the
         comparison run on the main process only.

    Returns:
        Process exit code: 1 on the main process if any round failed,
        otherwise 0. Non-main ranks always return 0.

    Raises:
        RuntimeError: on the main process, if the rollout worker recorded a
            failure while initialising the weight-transfer engine.
    """
    accelerator = Accelerator()
    rank = accelerator.process_index
    print(f"[rank {rank}] Loading model {MODEL_NAME} ...")
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, device_map=None, torch_dtype=torch.bfloat16)
    model = accelerator.prepare(model)
    print(f"[rank {rank}] Model loaded and prepared.")

    # ------------------------------------------------------------------
    # Build weight metadata on rank 0 (mirrors AsyncGRPOTrainer.__init__)
    # ------------------------------------------------------------------
    rollout_worker = None
    if accelerator.is_main_process:
        # Collect weight metadata once — names/dtypes/shapes are fixed for the lifetime of training.
        # DTensor.shape returns the global shape without triggering any all-gather.
        weight_names, weight_dtype_names, weight_shapes = [], [], []
        for name, param in model.named_parameters():
            name = name.removeprefix("module.")  # DDP/FSDP1 wrapping
            weight_names.append(name)
            weight_dtype_names.append(str(param.dtype).split(".")[-1])
            weight_shapes.append(list(param.shape))
        rollout_worker = DebugRolloutWorker(
            vllm_server_url=VLLM_SERVER_URL,
            weight_names=weight_names,
            weight_dtype_names=weight_dtype_names,
            weight_shapes=weight_shapes,
            rank=rank,
        )
        print(f"[rank {rank}] Starting rollout worker thread (mirrors AsyncRolloutWorker.start)...")
        rollout_worker.start()
        # Block until _init_weight_transfer has completed inside the daemon thread
        print(f"[rank {rank}] Waiting for _init_weight_transfer to complete...")
        rollout_worker._ready_event.wait()
        # If the worker recorded a start-up failure in `_init_error`, fail fast
        # instead of running rounds that cannot pass. getattr tolerates a worker
        # that does not define the attribute.
        init_error = getattr(rollout_worker, "_init_error", None)
        if init_error is not None:
            raise RuntimeError(
                "Weight transfer engine initialisation failed; see the worker thread output above."
            ) from init_error
        print(f"[rank {rank}] Weight transfer engine initialised.")

    # All ranks synchronise before entering the sync loop
    accelerator.wait_for_everyone()

    # ------------------------------------------------------------------
    # Test loop: for each round, verify that weight changes propagate
    # ------------------------------------------------------------------
    n_passed = 0
    n_failed = 0
    for sync_idx in range(1, N_SYNCS + 1):
        print(f"\n{'=' * 70}")
        print(f"[rank {rank}] ===== Round {sync_idx}/{N_SYNCS} =====")
        print(f"{'=' * 70}")

        # Step 1: Sync current weights to vLLM (baseline)
        print(f"\n[rank {rank}] Step 1: Syncing current weights to vLLM (baseline)...")
        _sync_weight(accelerator, model, rollout_worker, rank)

        # Step 2: Get baseline logprobs from vLLM (rank 0 only)
        before_logprobs = None
        if accelerator.is_main_process:
            print(f"\n[rank {rank}] Step 2: Querying vLLM for baseline logprobs...")
            before_logprobs = get_all_logprobs(TEST_PROMPTS)
            print_logprobs("BEFORE perturbation", TEST_PROMPTS, before_logprobs)
        accelerator.wait_for_everyone()

        # Step 3: Perturb a random layer on all ranks
        print(f"\n[rank {rank}] Step 3: Perturbing a random layer (noise_scale={NOISE_SCALE})...")
        # Use the same random seed across ranks so they pick the same layer
        random.seed(42 + sync_idx)
        perturbed_name = perturb_random_layer(model, NOISE_SCALE, rank)
        accelerator.wait_for_everyone()

        # Step 4: Sync modified weights to vLLM
        print(f"\n[rank {rank}] Step 4: Syncing MODIFIED weights to vLLM...")
        _sync_weight(accelerator, model, rollout_worker, rank)

        # Step 5: Get new logprobs from vLLM (rank 0 only)
        if accelerator.is_main_process:
            print(f"\n[rank {rank}] Step 5: Querying vLLM for post-perturbation logprobs...")
            after_logprobs = get_all_logprobs(TEST_PROMPTS)
            print_logprobs("AFTER perturbation", TEST_PROMPTS, after_logprobs)

            # Step 6: Compare
            print(f"\n[rank {rank}] Step 6: Comparing logprobs...")
            any_changed = compare_logprobs(before_logprobs, after_logprobs, TEST_PROMPTS)
            if any_changed:
                print(
                    f"\n >>> PASS: Round {sync_idx} — logprobs changed after weight sync. "
                    f"Layer '{perturbed_name}' perturbation was effective."
                )
                n_passed += 1
            else:
                print(
                    f"\n >>> FAIL: Round {sync_idx} — logprobs did NOT change after weight sync! "
                    f"Layer '{perturbed_name}' perturbation was NOT reflected in vLLM.",
                    file=sys.stderr,
                )
                n_failed += 1
        accelerator.wait_for_everyone()

    # ------------------------------------------------------------------
    # Summary
    # ------------------------------------------------------------------
    if accelerator.is_main_process:
        print(f"\n{'=' * 70}")
        print(f" SUMMARY: {n_passed} passed, {n_failed} failed out of {N_SYNCS} rounds")
        if n_failed == 0:
            print(" All rounds PASSED — weight syncs are effective!")
        else:
            print(f" WARNING: {n_failed} round(s) FAILED — weights may not be syncing correctly!")
        print(f"{'=' * 70}")

    # ------------------------------------------------------------------
    # Teardown
    # ------------------------------------------------------------------
    accelerator.wait_for_everyone()
    if accelerator.is_main_process and rollout_worker:
        print(f"[rank {rank}] Stopping rollout worker...")
        rollout_worker.stop()
    print(f"[rank {rank}] Done.")
    # Non-zero exit on the main process lets CI / `accelerate launch` detect failed rounds.
    return 1 if (accelerator.is_main_process and n_failed) else 0


if __name__ == "__main__":
    sys.exit(main())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment