Skip to content

Instantly share code, notes, and snippets.

@AmineDiro
Created March 19, 2026 18:42
Show Gist options
  • Select an option

  • Save AmineDiro/48785832de3d7f9c4e5a9a508399f43f to your computer and use it in GitHub Desktop.

Select an option

Save AmineDiro/48785832de3d7f9c4e5a9a508399f43f to your computer and use it in GitHub Desktop.
sync weights ok
"""
HOW TO RUN
-----------
Step 1 – Start a vLLM server with data-parallel support (set --data-parallel-size to the number of
DP shards, and adjust --tensor-parallel-size / --gpu-memory-utilization as needed):
CUDA_VISIBLE_DEVICES=2,3,4,5 VLLM_SERVER_DEV_MODE=1 vllm serve Qwen/Qwen3-4B \
--data-parallel-size 4 \
--tensor-parallel-size 1 \
--weight-transfer-config '{"backend":"nccl"}' \
--port 8000 \
--dtype bfloat16
Step 2 – Configure accelerate for FSDP2 (or DDP for a simpler test):
accelerate config # choose FSDP with the FSDP2 option, or multi-GPU / single-GPU DDP
Step 3 – Launch this script (match --num_processes to the number of training GPUs):
VLLM_SERVER_URL=http://localhost:8000 \
accelerate launch --num_processes <num_training_gpus> debug_weight_sync_effective.py
Optional env-vars:
VLLM_SERVER_URL – base URL of the vLLM server (default: http://localhost:8000)
MODEL_NAME – HF model id to load (default: Qwen/Qwen3-4B)
N_SYNCS – number of weight-sync rounds (default: 20)
NOISE_SCALE – std of the Gaussian noise added to the perturbed layer (default: 0.5)
-----------
What this script tests
----------------------
Unlike debug_weight_sync.py (which only verifies that the sync machinery runs without errors),
this script verifies that vLLM's model weights are EFFECTIVELY updated after a sync.
For each round it:
1. Syncs the current model weights to vLLM.
2. Queries vLLM for logprobs on several fixed prompts (baseline).
3. Picks a random linear layer and adds Gaussian noise to its weights.
4. Syncs the modified weights to vLLM.
5. Queries vLLM for logprobs on the same prompts again.
6. Compares the two sets of logprobs — they MUST differ, proving the sync was effective.
"""
import asyncio
import os
import random
import sys
import threading
import time
import requests
import torch
from accelerate import Accelerator
from torch.distributed._tensor import DTensor
from transformers import AutoModelForCausalLM
try:
from vllm.distributed.weight_transfer.nccl_engine import (
NCCLTrainerSendWeightsArgs,
NCCLWeightTransferEngine,
)
from vllm.utils.network_utils import get_ip, get_open_port
except ImportError as exc:
raise ImportError("vLLM is required for this debug script. Install it with `pip install trl[vllm]`.") from exc
# ---------------------------------------------------------------------------
# Settings — each value can be overridden through an environment variable
# ---------------------------------------------------------------------------
# Base URL of the vLLM server; a trailing "/" is stripped so paths can be appended.
VLLM_SERVER_URL: str = os.getenv("VLLM_SERVER_URL", "http://localhost:8000").rstrip("/")
# Hugging Face model id loaded by the trainer (also sent as "model" in completion requests).
MODEL_NAME: str = os.getenv("MODEL_NAME", "Qwen/Qwen3-4B")
# Number of perturb-and-verify rounds.
N_SYNCS: int = int(os.getenv("N_SYNCS", "20"))
# Standard deviation of the Gaussian noise added in each round.
NOISE_SCALE: float = float(os.getenv("NOISE_SCALE", "0.5"))

# Prompts whose next-token logprobs are compared before/after each perturbation;
# several are used so that one insensitive prompt cannot hide a real change.
TEST_PROMPTS = [
    "The capital of France is",
    "In quantum mechanics, the uncertainty principle states that",
    "def fibonacci(n):\n if n <= 1:\n return",
    "Once upon a time, in a land far away, there lived a",
]
# ---------------------------------------------------------------------------
# Minimal replica of AsyncRolloutWorker (weight-sync surface only)
# ---------------------------------------------------------------------------
class DebugRolloutWorker:
"""
Minimal replica of AsyncRolloutWorker that exercises only the weight-sync
path: _init_weight_transfer, pause/resume, and send_weights.
The public interface and internal variable names are kept identical to the
production class so that diffs are easy to read.
"""
def __init__(
self,
vllm_server_url: str,
weight_names: list,
weight_dtype_names: list,
weight_shapes: list,
rank: int,
) -> None:
self.vllm_server_url = vllm_server_url
self.rank = rank
# Mirrors AsyncRolloutWorker.__init__ exactly
self._weight_update_info = {
"names": weight_names,
"dtype_names": weight_dtype_names,
"shapes": weight_shapes,
"packed": True,
"is_checkpoint_format": True,
}
self.model_update_group = None
# Asyncio / threading state (mirrors _run / run)
self._loop: asyncio.AbstractEventLoop | None = None
self._stop_event: asyncio.Event | None = None
# Used by main() to know when async init has completed
self._ready_event: threading.Event = threading.Event()
# ------------------------------------------------------------------
# Thread + asyncio bootstrap (mirrors start / _run / run)
# ------------------------------------------------------------------
def start(self) -> None:
thread = threading.Thread(target=self._run, daemon=True)
thread.start()
def _run(self) -> None:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
self._loop = loop
self._stop_event = asyncio.Event()
try:
loop.run_until_complete(self.run(stop_event=self._stop_event))
except Exception as e:
print(f"[rank {self.rank}] Worker thread failed: {e}", file=sys.stderr)
raise
finally:
loop.close()
async def run(self, stop_event: asyncio.Event | None = None) -> None:
"""Mirrors AsyncRolloutWorker.run — calls _init_weight_transfer via asyncio.to_thread."""
if stop_event is None:
stop_event = asyncio.Event()
# In production this is wrapped in `async with aiohttp.ClientSession()`.
# We skip the session here because there are no generate/score loops.
await asyncio.to_thread(self._init_weight_transfer)
print(f"[rank {self.rank}] Worker: _init_weight_transfer complete.")
# Signal main() that the weight-transfer engine is ready, then idle
# until stop() is called (mirrors the generate/score gather in production).
self._ready_event.set()
await stop_event.wait()
# ------------------------------------------------------------------
# _init_weight_transfer (mirrors AsyncRolloutWorker._init_weight_transfer exactly)
# ------------------------------------------------------------------
def _init_weight_transfer(self) -> None:
print(f"[rank {self.rank}] Init weight sync group with vLLM")
response = requests.get(f"{self.vllm_server_url}/get_world_size")
inference_world_size = response.json()["world_size"]
world_size = inference_world_size + 1
master_address = get_ip()
master_port = get_open_port()
init_info = {
"master_address": master_address,
"master_port": master_port,
"rank_offset": 1,
"world_size": world_size,
}
t_init = threading.Thread(
target=requests.post,
args=(f"{self.vllm_server_url}/init_weight_transfer_engine",),
kwargs={"json": {"init_info": init_info}, "timeout": 120},
)
t_init.start()
self.model_update_group = NCCLWeightTransferEngine.trainer_init(
{
"master_address": master_address,
"master_port": master_port,
"world_size": world_size,
}
)
t_init.join()
print(f"[rank {self.rank}] Init weight sync group with vLLM — done")
# ------------------------------------------------------------------
# pause / resume / send_weights (mirrors AsyncRolloutWorker exactly)
# ------------------------------------------------------------------
def pause(self) -> None:
requests.post(f"{self.vllm_server_url}/pause", params={"mode": "wait"})
def resume(self) -> None:
requests.post(f"{self.vllm_server_url}/resume")
def send_weights(self, iterator) -> None:
if self.model_update_group is None:
return
t_update = threading.Thread(
target=requests.post,
args=(f"{self.vllm_server_url}/update_weights",),
kwargs={"json": {"update_info": self._weight_update_info}, "timeout": 1800},
)
t_update.start()
NCCLWeightTransferEngine.trainer_send_weights(
iterator=iterator,
trainer_args=NCCLTrainerSendWeightsArgs(group=self.model_update_group, packed=True),
)
t_update.join()
def stop(self) -> None:
if self._loop and self._loop.is_running():
try:
self._loop.call_soon_threadsafe(self._stop_event.set)
except RuntimeError:
pass
# ---------------------------------------------------------------------------
# Helpers that mirror AsyncGRPOTrainer methods
# ---------------------------------------------------------------------------
def _streaming_iter(model):
    """
    Yield ``(name, tensor)`` for each parameter of *model*, one at a time.

    Mirrors AsyncGRPOTrainer._streaming_iter. A leading ``module.`` (added by
    DDP/FSDP1 wrapping) is stripped from each name. DTensor parameters (FSDP2)
    are materialised with ``full_tensor()``, which all-gathers only that
    parameter, so every rank has to drive this generator. Each gathered tensor
    is only referenced until the consumer advances, which avoids holding the
    full model in memory at once. Plain tensors are yielded detached.
    """
    for raw_name, tensor in model.named_parameters():
        if isinstance(tensor, DTensor):
            materialised = tensor.full_tensor()
        else:
            materialised = tensor.detach()
        yield raw_name.removeprefix("module."), materialised
def _sync_weight(accelerator, model, rollout_worker, rank):
    """
    Push the trainer's current weights to vLLM.

    Mirrors AsyncGRPOTrainer._sync_weight, with print() replacing
    logger.info() and an explicit rank prefix on every line. Sequence:
    pause vLLM (main process only) -> barrier -> stream weights (the main
    process sends; other ranks only take part in the FSDP2 all-gathers) ->
    barrier -> resume vLLM.

    resume() is issued from a ``finally`` block, so a barrier or transfer
    failure on the main process does not leave the vLLM server paused. The
    original exception still propagates.
    """
    # Only the main process that owns a rollout worker talks to vLLM.
    is_sender = bool(accelerator.is_main_process and rollout_worker)
    t0 = time.time()
    print(f"[rank {rank}] Weight sync: pausing vLLM...")
    if is_sender:
        rollout_worker.pause()
    t_pause = time.time()
    print(f"[rank {rank}] Weight sync: pause took {t_pause - t0:.1f}s, waiting for all ranks...")
    try:
        accelerator.wait_for_everyone()
        t_barrier = time.time()
        print(f"[rank {rank}] Weight sync: transferring weights... (barrier took {t_barrier - t_pause:.1f}s)")
        if is_sender:
            rollout_worker.send_weights(_streaming_iter(model))
        else:
            # Non-rank-0 processes must still participate in full_tensor() collectives for FSDP2.
            for _ in _streaming_iter(model):
                pass
        t_transfer = time.time()
        accelerator.wait_for_everyone()
        print(f"[rank {rank}] Weight sync: resuming vLLM... (transfer took {t_transfer - t_barrier:.1f}s)")
    finally:
        if is_sender:
            rollout_worker.resume()
    print(f"[rank {rank}] Weight sync: done. Total {time.time() - t0:.1f}s")
# ---------------------------------------------------------------------------
# vLLM logprob querying
# ---------------------------------------------------------------------------
def get_vllm_logprobs(prompt: str) -> dict:
    """
    Query vLLM's OpenAI-compatible /v1/completions endpoint for the next-token
    logprob distribution on a fixed prompt (one greedily decoded token).

    Returns a dict with:
      - "token": the greedy-decoded token string ("" if none was returned)
      - "logprob": the logprob of the greedy token (None if none was returned)
      - "top_logprobs": dict mapping token string -> logprob for the top-5
        candidates ({} if none were returned)

    Raises requests.HTTPError on a non-2xx response.
    """
    response = requests.post(
        f"{VLLM_SERVER_URL}/v1/completions",
        json={
            "model": MODEL_NAME,
            "prompt": prompt,
            "max_tokens": 1,
            "temperature": 0,  # greedy
            "n": 1,
            "logprobs": 5,  # return top-5 logprobs
        },
        timeout=120,
    )
    response.raise_for_status()
    choice = response.json()["choices"][0]
    # OpenAI-compatible format: logprobs.tokens, logprobs.token_logprobs,
    # logprobs.top_logprobs. The "logprobs" object or individual entries may
    # be null in this format, so fall back to empty values rather than crash.
    logprobs_data = choice.get("logprobs") or {}
    tokens = logprobs_data.get("tokens") or []
    token_logprobs = logprobs_data.get("token_logprobs") or []
    top_logprobs_list = logprobs_data.get("top_logprobs") or []
    return {
        "token": tokens[0] if tokens else "",
        "logprob": token_logprobs[0] if token_logprobs else None,
        "top_logprobs": (top_logprobs_list[0] if top_logprobs_list else None) or {},
    }
def get_all_logprobs(prompts: list[str]) -> list[dict]:
    """Return ``get_vllm_logprobs(p)`` for every prompt ``p``, in input order (one request each)."""
    return [get_vllm_logprobs(prompt) for prompt in prompts]
def print_logprobs(label: str, prompts: list[str], results: list[dict]) -> None:
    """
    Pretty-print logprob results, two lines per prompt.

    Each result is a dict shaped like get_vllm_logprobs()' return value:
    "token", "logprob" (may be None) and "top_logprobs" (dict token -> logprob,
    tolerated as None). Prompts are truncated to 60 characters with newlines
    escaped. A missing logprob is printed as "n/a" instead of raising.
    """

    def _fmt(lp):
        return "n/a" if lp is None else f"{lp:.4f}"

    print(f"\n --- {label} ---")
    for prompt, result in zip(prompts, results):
        snippet = prompt[:60].replace("\n", "\\n")
        top_logprobs = result["top_logprobs"] or {}
        top3 = sorted(top_logprobs.items(), key=lambda x: x[1], reverse=True)[:3]
        top3_str = " ".join(f"{repr(tok)}={lp:.4f}" for tok, lp in top3)
        print(f" prompt='{snippet}'")
        print(f" greedy={repr(result['token'])} (logp={_fmt(result['logprob'])}) top3: {top3_str}")
def compare_logprobs(before: list[dict], after: list[dict], prompts: list[str], atol: float = 1e-6) -> bool:
    """
    Compare two sets of logprob results. Returns True if they differ (sync was effective).

    A prompt counts as changed when its greedy token differs, or when both
    greedy logprobs are present and differ by more than *atol*. A missing
    (None) logprob never counts as a change on its own, and is printed as
    "n/a". Prints a per-prompt comparison and a summary line.
    """

    def _fmt(lp):
        return "n/a" if lp is None else f"{lp:.4f}"

    n_changed = 0
    print("\n --- Comparison ---")
    for prompt, b, a in zip(prompts, before, after):
        snippet = prompt[:60].replace("\n", "\\n")
        before_lp = b["logprob"]
        after_lp = a["logprob"]
        if before_lp is not None and after_lp is not None:
            diff = abs(after_lp - before_lp)
        else:
            diff = 0.0
        # A different greedy token is direct evidence the model changed.
        changed = diff > atol or b["token"] != a["token"]
        status = "CHANGED" if changed else "SAME"
        print(
            f" [{status}] prompt='{snippet}'"
            f" before: {repr(b['token'])} (logp={_fmt(before_lp)})"
            f" after: {repr(a['token'])} (logp={_fmt(after_lp)})"
            f" |diff|={diff:.6f}"
        )
        if changed:
            n_changed += 1
    print(f"\n Result: {n_changed}/{len(prompts)} prompts show changed logprobs.")
    return n_changed > 0
# ---------------------------------------------------------------------------
# Weight perturbation
# ---------------------------------------------------------------------------
def perturb_random_layer(model, noise_scale: float, rank: int) -> str:
    """
    Pick a random linear layer weight from the model and add Gaussian noise in-place.
    Returns the name of the perturbed parameter (leading "module." removed).

    Candidates are all 2-D parameters whose name does not contain "embed".
    The choice is made from the middle half of the candidate list, falling
    back to all candidates when that slice is empty. It uses the module-level
    ``random`` RNG, so ranks that seed it identically pick the same parameter.
    The noise itself (std = *noise_scale*) comes from torch's RNG.

    For FSDP2 (DTensor) only the local shard is modified, so the noise differs
    between ranks. The full_tensor() all-gather performed during the sync then
    assembles the noisy shards.
    NOTE(review): that reasoning assumes a sharded placement; a replicated
    placement would end up with inconsistent replicas — confirm if relevant.

    Raises RuntimeError if no eligible parameter exists.
    """
    # Collect candidate parameters: only 2D weights (linear layers), skip embeddings
    candidates = []
    for name, param in model.named_parameters():
        clean_name = name.removeprefix("module.")
        if param.ndim == 2 and "embed" not in clean_name.lower():
            candidates.append((clean_name, param))
    if not candidates:
        raise RuntimeError("No suitable linear layers found for perturbation!")
    # Prefer the middle layers (more interesting than first/last). With very few
    # candidates the middle slice can be empty (one candidate -> candidates[0:0]),
    # so fall back to the full list instead of letting random.choice raise.
    middle = candidates[len(candidates) // 4 : 3 * len(candidates) // 4] or candidates
    chosen_name, chosen_param = random.choice(middle)
    # Add noise in-place
    with torch.no_grad():
        if isinstance(chosen_param, DTensor):
            # For FSDP2: modify the local shard directly
            local_tensor = chosen_param._local_tensor
            noise = torch.randn_like(local_tensor) * noise_scale
            local_tensor.add_(noise)
            print(
                f"[rank {rank}] Perturbed '{chosen_name}' (DTensor local shard "
                f"{list(local_tensor.shape)}, noise_scale={noise_scale})"
            )
        else:
            noise = torch.randn_like(chosen_param) * noise_scale
            chosen_param.add_(noise)
            print(f"[rank {rank}] Perturbed '{chosen_name}' ({list(chosen_param.shape)}, noise_scale={noise_scale})")
    return chosen_name
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main() -> int:
    """
    Entry point: verify that weight syncs to vLLM actually take effect.

    Every rank loads and prepares the model; only the main process owns the
    DebugRolloutWorker, queries vLLM and judges each round. Each of the
    N_SYNCS rounds syncs the current weights, records baseline logprobs,
    perturbs one layer on all ranks, syncs again and checks that the logprobs
    moved.

    Returns the process exit status: 1 on the main process if any round
    failed, 0 otherwise (non-main ranks always return 0). This lets whatever
    launched the script detect a failing run.
    """
    accelerator = Accelerator()
    rank = accelerator.process_index
    print(f"[rank {rank}] Loading model {MODEL_NAME} ...")
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, device_map=None, torch_dtype=torch.bfloat16)
    model = accelerator.prepare(model)
    print(f"[rank {rank}] Model loaded and prepared.")
    # ------------------------------------------------------------------
    # Build weight metadata on rank 0 (mirrors AsyncGRPOTrainer.__init__)
    # ------------------------------------------------------------------
    rollout_worker = None
    if accelerator.is_main_process:
        # Collect weight metadata once — names/dtypes/shapes are fixed for the lifetime of training.
        # DTensor.shape returns the global shape without triggering any all-gather.
        weight_names, weight_dtype_names, weight_shapes = [], [], []
        for name, param in model.named_parameters():
            name = name.removeprefix("module.")  # DDP/FSDP1 wrapping
            weight_names.append(name)
            weight_dtype_names.append(str(param.dtype).split(".")[-1])
            weight_shapes.append(list(param.shape))
        rollout_worker = DebugRolloutWorker(
            vllm_server_url=VLLM_SERVER_URL,
            weight_names=weight_names,
            weight_dtype_names=weight_dtype_names,
            weight_shapes=weight_shapes,
            rank=rank,
        )
        print(f"[rank {rank}] Starting rollout worker thread (mirrors AsyncRolloutWorker.start)...")
        rollout_worker.start()
        # Block until _init_weight_transfer has completed inside the daemon thread
        print(f"[rank {rank}] Waiting for _init_weight_transfer to complete...")
        rollout_worker._ready_event.wait()
        print(f"[rank {rank}] Weight transfer engine initialised.")
    # All ranks synchronise before entering the sync loop
    accelerator.wait_for_everyone()
    # ------------------------------------------------------------------
    # Test loop: for each round, verify that weight changes propagate
    # ------------------------------------------------------------------
    n_passed = 0
    n_failed = 0
    for sync_idx in range(1, N_SYNCS + 1):
        print(f"\n{'=' * 70}")
        print(f"[rank {rank}] ===== Round {sync_idx}/{N_SYNCS} =====")
        print(f"{'=' * 70}")
        # Step 1: Sync current weights to vLLM (baseline)
        print(f"\n[rank {rank}] Step 1: Syncing current weights to vLLM (baseline)...")
        _sync_weight(accelerator, model, rollout_worker, rank)
        # Step 2: Get baseline logprobs from vLLM (rank 0 only)
        before_logprobs = None
        if accelerator.is_main_process:
            print(f"\n[rank {rank}] Step 2: Querying vLLM for baseline logprobs...")
            before_logprobs = get_all_logprobs(TEST_PROMPTS)
            print_logprobs("BEFORE perturbation", TEST_PROMPTS, before_logprobs)
        accelerator.wait_for_everyone()
        # Step 3: Perturb a random layer on all ranks
        print(f"\n[rank {rank}] Step 3: Perturbing a random layer (noise_scale={NOISE_SCALE})...")
        # Use the same random seed across ranks so they pick the same layer
        random.seed(42 + sync_idx)
        perturbed_name = perturb_random_layer(model, NOISE_SCALE, rank)
        accelerator.wait_for_everyone()
        # Step 4: Sync modified weights to vLLM
        print(f"\n[rank {rank}] Step 4: Syncing MODIFIED weights to vLLM...")
        _sync_weight(accelerator, model, rollout_worker, rank)
        # Step 5: Get new logprobs from vLLM (rank 0 only)
        if accelerator.is_main_process:
            print(f"\n[rank {rank}] Step 5: Querying vLLM for post-perturbation logprobs...")
            after_logprobs = get_all_logprobs(TEST_PROMPTS)
            print_logprobs("AFTER perturbation", TEST_PROMPTS, after_logprobs)
            # Step 6: Compare
            print(f"\n[rank {rank}] Step 6: Comparing logprobs...")
            any_changed = compare_logprobs(before_logprobs, after_logprobs, TEST_PROMPTS)
            if any_changed:
                print(
                    f"\n >>> PASS: Round {sync_idx} — logprobs changed after weight sync. "
                    f"Layer '{perturbed_name}' perturbation was effective."
                )
                n_passed += 1
            else:
                print(
                    f"\n >>> FAIL: Round {sync_idx} — logprobs did NOT change after weight sync! "
                    f"Layer '{perturbed_name}' perturbation was NOT reflected in vLLM.",
                    file=sys.stderr,
                )
                n_failed += 1
        accelerator.wait_for_everyone()
    # ------------------------------------------------------------------
    # Summary
    # ------------------------------------------------------------------
    if accelerator.is_main_process:
        print(f"\n{'=' * 70}")
        print(f" SUMMARY: {n_passed} passed, {n_failed} failed out of {N_SYNCS} rounds")
        if n_failed == 0:
            print(" All rounds PASSED — weight syncs are effective!")
        else:
            print(f" WARNING: {n_failed} round(s) FAILED — weights may not be syncing correctly!")
        print(f"{'=' * 70}")
    # ------------------------------------------------------------------
    # Teardown
    # ------------------------------------------------------------------
    accelerator.wait_for_everyone()
    if accelerator.is_main_process and rollout_worker:
        print(f"[rank {rank}] Stopping rollout worker...")
        rollout_worker.stop()
    print(f"[rank {rank}] Done.")
    # n_failed is only ever incremented on the main process.
    return 1 if n_failed else 0


if __name__ == "__main__":
    sys.exit(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment