Skip to content

Instantly share code, notes, and snippets.

@AmineDiro
Last active May 29, 2026 08:59
Show Gist options
  • Select an option

  • Save AmineDiro/5e324f609989d3f0437ea13e69ea489f to your computer and use it in GitHub Desktop.

Select an option

Save AmineDiro/5e324f609989d3f0437ea13e69ea489f to your computer and use it in GitHub Desktop.
Async GRPO minimal
#!/usr/bin/env -S uv run
# /// script
# requires-python = ">=3.10, <3.13"
# dependencies = [
# "aiohttp==3.13.3",
# "datasets==4.6.1",
# "hf-transfer==0.1.9",
# "requests==2.32.5",
# "torch==2.10.0",
# "transformers==4.57.6",
# "uvicorn==0.41.0",
# "uvloop==0.22.1",
# "vllm==0.17.0",
# ]
# ///
import asyncio
import os
import queue
import subprocess
import sys
import threading
import time
from dataclasses import dataclass
from typing import Tuple, TypedDict
import aiohttp
import requests
import torch
import torch.nn.functional as F
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.weight_transfer.nccl_engine import (
NCCLTrainerSendWeightsArgs,
NCCLWeightTransferEngine,
)
from vllm.utils.network_utils import get_ip, get_open_port
# Bounded hand-off buffer between the background rollout generator (producer)
# and the training loop (consumer). While it is full the generator schedules no
# new groups, and completed samples that do not fit are discarded.
# TODO: derive the depth from global_batch_size, admission control and max_inflight.
ROLLOUT_QUEUE: "queue.Queue[RolloutSample]" = queue.Queue(maxsize=1024)
@dataclass
class TrainingConfig:
    """Hyper-parameters and endpoints for the async GRPO trainer.

    ``__post_init__`` rejects group / in-flight settings under which the
    rollout generator could never schedule a single request (it would
    otherwise loop forever without producing samples).
    """

    model_name_or_path: str = "Qwen/Qwen3-4B"
    num_steps: int = 100
    learning_rate: float = 1e-5
    # Rollout samples collected per optimizer step.
    global_batch_size: int = 2
    # NOTE(review): not read by the training loop in this file - confirm intended use.
    gradient_accumulation_steps: int = 4
    train_device: str = "cuda:0"
    # GRPO specific
    # Completions sampled per prompt; advantages are normalized within a group.
    group_size: int = 4
    # Upper bound on concurrent completion requests sent to vLLM.
    max_inflight: int = 32
    # Samples generated more than this many policy versions ago are dropped.
    max_staleness: int = 3
    # vllm inference config
    # NOTE: vllm_server_url and vllm_port are set independently; keep them in sync.
    vllm_server_url: str = "http://localhost:8000/v1/completions"
    vllm_port: int = 8000
    temperature: float = 0.8
    max_tokens: int = 1024

    def __post_init__(self):
        """Validate settings the rollout generator depends on.

        Raises:
            ValueError: if ``group_size < 1`` or ``max_inflight < group_size``.
        """
        if self.group_size < 1:
            raise ValueError(f"group_size must be >= 1, got {self.group_size}")
        if self.max_inflight < self.group_size:
            # Groups are scheduled atomically, so a budget smaller than one
            # group means nothing is ever generated.
            raise ValueError(
                f"max_inflight ({self.max_inflight}) must be >= group_size "
                f"({self.group_size}); otherwise no rollout group can be scheduled"
            )
@dataclass(slots=True)
class RolloutSample:
# TODO: Add some notion of turns, trajectory ...
prompt: list[dict[str, str]]
prompt_ids: list[int]
completion: list[dict[str, str]]
completion_ids: list[int]
advantage: float
old_log_probs: list[float]
model_version: int
# TODO: add metadata
# metrics : dict[str,float]
## generation_time_s: float
# Collated, device-resident tensors for one optimizer step (see build_batch).
# One row per sample; the 2-D tensors share a single right-padded length.
#   prompt_completion_ids: prompt + completion token ids, padded with pad_token_id
#   attention_mask:        1 on real tokens, 0 on padding
#   completion_mask:       1 only on completion tokens
#   advantages:            one scalar advantage per sample
#   old_log_probs:         generation-policy log-probs (0.0 on prompt/padding)
#   model_versions:        policy version each sample was generated with
Batch = TypedDict(
    "Batch",
    {
        "prompt_completion_ids": torch.Tensor,
        "attention_mask": torch.Tensor,
        "completion_mask": torch.Tensor,
        "advantages": torch.Tensor,
        "old_log_probs": torch.Tensor,
        "model_versions": torch.Tensor,
    },
)
def run_vllm_server(
    model_name: str, port: int, gpu_id: int, host: str = "0.0.0.0"
) -> subprocess.Popen:
    """Launch an OpenAI-compatible vLLM server as a child process.

    The server is pinned to ``gpu_id`` through ``CUDA_VISIBLE_DEVICES``, runs
    with ``VLLM_SERVER_DEV_MODE=1`` and the NCCL weight-transfer backend
    enabled. Its stdout/stderr go to ``vllm_server.log``; the open handle is
    stored on ``process._log_file`` so the caller can close it once the
    process has exited.

    Args:
        model_name: Hugging Face model id or local path to serve.
        port: HTTP port to listen on.
        gpu_id: Physical GPU index exposed to the server.
        host: Interface to bind. Defaults to all interfaces.
            NOTE(review): dev mode presumably enables the pause/resume/
            update_weights endpoints this script uses; binding 0.0.0.0 exposes
            them to the network. "127.0.0.1" suffices if all clients are local.

    Returns:
        The ``Popen`` handle of the server process.
    """
    env = os.environ.copy()
    env["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    env["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
    env["VLLM_LOGGING_LEVEL"] = "WARNING"
    env["VLLM_SERVER_DEV_MODE"] = "1"
    # TODO: should probably have more control using `AsyncLLMEngine`
    cmd = [
        sys.executable,
        "-m",
        "vllm.entrypoints.openai.api_server",
        "--model",
        model_name,
        "--port",
        str(port),
        "--host",
        host,
        "--gpu-memory-utilization",
        "0.9",
        "--max-model-len",
        "4096",
        "--tensor-parallel-size",
        "1",
        "--weight-transfer-config",
        '{"backend":"nccl"}',
    ]
    log_file = open("vllm_server.log", "w")
    print(f"[Server] Launching vLLM server on GPU {gpu_id}, port {port}")
    print(f"[Server] Command: {' '.join(cmd)}")
    print("[Server] Logs are being written to vllm_server.log")
    try:
        process = subprocess.Popen(
            cmd,
            env=env,
            stdout=log_file,
            stderr=log_file,
        )
    except BaseException:
        # The child never started, so nobody else will close the log handle.
        log_file.close()
        raise
    process._log_file = log_file  # type: ignore
    return process
class AsyncGRPODataLoader:
    """Generate GRPO rollouts in the background and push them to ``ROLLOUT_QUEUE``.

    A daemon thread runs a private asyncio event loop that repeatedly:

    1. draws a prompt from ``dataset`` (cycling through it forever),
    2. schedules ``config.group_size`` completion requests for that prompt
       against the vLLM completions endpoint,
    3. once the whole group has returned, scores the completions with
       ``compute_reward``, normalizes the rewards within the group into
       advantages, and enqueues one ``RolloutSample`` per successful completion.

    At most ``config.max_inflight`` requests are outstanding at a time, and no
    new groups are scheduled while ``ROLLOUT_QUEUE`` is full. Each group is
    stamped with ``model_version`` as of scheduling time so the trainer can
    drop stale samples.
    """

    def __init__(self, config: TrainingConfig, tokenizer: AutoTokenizer, dataset):
        self.config = config
        self.tokenizer = tokenizer
        self.dataset = dataset
        self.dataset_iter = iter(dataset)
        # Polled by the generation loop; cleared by stop().
        # TODO: should be a threading.Event.
        self.running = False
        self.thread = None
        self.loop = None
        # Version stamped on newly scheduled groups. Written from the trainer
        # thread via update_model_version() and read by the generation thread.
        self.model_version = 0

    def update_model_version(self, version: int):
        """Record the trainer's current policy version for groups scheduled from now on."""
        self.model_version = version

    def start(self):
        """Start the background generation thread (a daemon, so it never blocks exit)."""
        self.running = True
        self.thread = threading.Thread(target=self._run_event_loop, daemon=True)
        self.thread.start()

    def stop(self, timeout: float | None = None):
        """Ask the generation loop to exit and wait up to ``timeout`` seconds for it.

        The loop notices the request the next time an in-flight request
        completes (or its back-off sleep ends); remaining requests are then
        cancelled. ``timeout=None`` waits indefinitely.
        """
        self.running = False
        if self.thread is not None:
            self.thread.join(timeout)

    def _run_event_loop(self):
        """Thread target: run ``_generation_loop`` on a dedicated event loop."""
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)
        try:
            self.loop.run_until_complete(self._generation_loop())
        finally:
            self.loop.close()

    def _get_next_prompt_data(self) -> Tuple[str, list[int]]:
        """Return ``(prompt_text, prompt_ids)`` for the next dataset row.

        The prompt is read from the row's "problem" field, falling back to
        "prompt". When the tokenizer has a chat template, the text is wrapped
        as a single user turn with the generation prompt appended.
        """
        try:
            sample = next(self.dataset_iter)
        except StopIteration:
            # Reached the end of the dataset: start another pass.
            self.dataset_iter = iter(self.dataset)
            sample = next(self.dataset_iter)
        prompt_text = sample.get("problem", "")
        if not prompt_text and "prompt" in sample:
            prompt_text = sample["prompt"]
        if self.tokenizer.chat_template:
            messages = [{"role": "user", "content": prompt_text}]
            prompt_str = self.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        else:
            prompt_str = prompt_text
        prompt_ids = self.tokenizer.encode(prompt_str, add_special_tokens=False)
        return prompt_text, prompt_ids

    async def _generate_with_group_id(
        self, session, group_id: int, prompt_ids: list[int]
    ):
        """Run one completion and tag the result with ``group_id`` for routing."""
        result = await self._generate_one(session, prompt_ids)
        return group_id, result

    def _process_completed_group(self, group_data: dict):
        """Score a finished group, compute advantages and enqueue its samples.

        Failed requests (empty ``completion_ids``) are excluded from both the
        queue and the reward statistics, so an infrastructure error does not
        act as a zero-reward completion and shift the group baseline.
        Advantages are ``(r - mean) / (std + 1e-4)`` over the successful
        completions. They are all zero when fewer than two completions
        succeeded or the rewards are (near-)identical, because such a group
        carries no relative signal. Samples that do not fit in a full
        ``ROLLOUT_QUEUE`` are dropped and counted.
        """
        prompt_text = group_data["prompt_text"]
        prompt_ids = group_data["prompt_ids"]
        version_at_generation = group_data["version"]
        valid_results = [r for r in group_data["results"] if r[1]]
        if not valid_results:
            return

        rewards = [compute_reward(prompt_text, text) for text, _, _ in valid_results]
        rewards_tensor = torch.tensor(rewards, dtype=torch.float32)
        # torch.std is unbiased (n - 1), i.e. NaN for a single reward.
        std_r = rewards_tensor.std() if len(rewards) > 1 else None
        if std_r is None or std_r.item() < 1e-6:
            advantages = torch.zeros_like(rewards_tensor)
        else:
            advantages = (rewards_tensor - rewards_tensor.mean()) / (std_r + 1e-4)

        dropped = 0
        for (completion_text, completion_ids, completion_logprobs), reward, advantage in zip(
            valid_results, rewards, advantages.tolist()
        ):
            rollout = RolloutSample(
                prompt=[{"role": "user", "content": prompt_text}],
                prompt_ids=prompt_ids,
                completion=[{"role": "assistant", "content": completion_text}],
                completion_ids=completion_ids,
                advantage=advantage,
                old_log_probs=completion_logprobs,
                model_version=version_at_generation,
            )
            try:
                ROLLOUT_QUEUE.put_nowait(rollout)
            except queue.Full:
                dropped += 1
                continue
            print(
                f"[Rollout] prompt_len={len(prompt_ids)}, "
                f"completion_len={len(completion_ids)}, "
                f"total_len={len(prompt_ids) + len(completion_ids)}, "
                f"reward={reward:.4f}, "
                f"advantage={advantage:.4f}, "
                f"version={version_at_generation}"
            )
        if dropped:
            print(
                f"[Rollout] ROLLOUT_QUEUE full, dropped {dropped} sample(s) "
                f"(version={version_at_generation})"
            )

    async def _generation_loop(self):
        """Keep vLLM busy with grouped requests until ``self.running`` is cleared.

        Requests still in flight when the loop exits are cancelled before the
        HTTP session closes.
        """
        pending_tasks = set()
        groups = {}
        group_id_counter = 0
        async with aiohttp.ClientSession() as session:
            try:
                while self.running:
                    # Admit whole groups while the queue has room and the
                    # in-flight budget can take another group.
                    while (
                        not ROLLOUT_QUEUE.full()
                        and len(pending_tasks) + self.config.group_size
                        <= self.config.max_inflight
                    ):
                        prompt_text, prompt_ids = self._get_next_prompt_data()
                        group_id = group_id_counter
                        group_id_counter += 1
                        groups[group_id] = {
                            "prompt_text": prompt_text,
                            "prompt_ids": prompt_ids,
                            "version": self.model_version,
                            "results": [],
                        }
                        for _ in range(self.config.group_size):
                            pending_tasks.add(
                                asyncio.create_task(
                                    self._generate_with_group_id(
                                        session, group_id, prompt_ids
                                    )
                                )
                            )
                    if not pending_tasks:
                        # Nothing in flight (queue full, or max_inflight <
                        # group_size): back off instead of busy-spinning.
                        await asyncio.sleep(0.1)
                        continue
                    done, pending_tasks = await asyncio.wait(
                        pending_tasks, return_when=asyncio.FIRST_COMPLETED
                    )
                    for task in done:
                        group_id, result = task.result()
                        group = groups[group_id]
                        group["results"].append(result)
                        if len(group["results"]) == self.config.group_size:
                            self._process_completed_group(groups.pop(group_id))
            finally:
                for task in pending_tasks:
                    task.cancel()
                if pending_tasks:
                    await asyncio.gather(*pending_tasks, return_exceptions=True)

    async def _generate_one(
        self, session, prompt_ids: list[int]
    ) -> Tuple[str, list[int], list[float]]:
        """Request one sampled completion for ``prompt_ids`` from vLLM.

        Returns:
            ``(text, completion_token_ids, per_token_logprobs)``. On a non-200
            response or any exception, returns ``("", [], [])``, which callers
            treat as a failed generation.
        """
        payload = {
            "model": self.config.model_name_or_path,
            "prompt": prompt_ids,
            "max_tokens": self.config.max_tokens,
            "temperature": self.config.temperature,
            "n": 1,
            "return_token_ids": True,
            "logprobs": 1,
        }
        try:
            t0 = time.perf_counter()
            async with session.post(self.config.vllm_server_url, json=payload) as resp:
                if resp.status == 200:
                    data = await resp.json()
                    choice = data["choices"][0]
                    text = choice["text"]
                    completion_ids = choice["token_ids"]
                    completion_logprobs = choice["logprobs"]["token_logprobs"]
                    elapsed = time.perf_counter() - t0
                    print(
                        f"[Generation] vLLM response in {elapsed:.2f}s, "
                        f"output_chars={len(text)}"
                    )
                    return text, completion_ids, completion_logprobs
                error_text = await resp.text()
                print(
                    f"[Generation] vLLM error status={resp.status}: {error_text[:200]}"
                )
                return "", [], []
        except Exception as e:
            # Best-effort: a failed request becomes an empty (skipped) result.
            # CancelledError is a BaseException and still propagates.
            print(f"[Generation] vLLM request failed: {e}")
            return "", [], []
def compute_reward(prompt: str, completion: str) -> float:
    """Placeholder reward: completion length in characters, divided by 100.

    ``prompt`` is ignored. Swap in a task-specific verifier for real training.
    """
    num_chars = len(completion)
    return num_chars / 100.0
def build_batch(
    batch: "list[RolloutSample]", device: torch.device, pad_token_id: int
) -> "Batch":
    """Collate rollout samples into right-padded training tensors on ``device``.

    Each row is ``prompt_ids + completion_ids`` right-padded with
    ``pad_token_id`` to the longest row. ``completion_mask`` is 1 only on
    completion tokens. ``old_log_probs`` holds the generation-policy log-prob
    of each completion token and 0.0 on prompt and padding positions, which
    the loss masks out.

    Args:
        batch: Samples to collate (at least one).
        device: Destination device for every returned tensor.
        pad_token_id: Token id used to pad ``prompt_completion_ids``.

    Returns:
        A ``Batch`` dict. Id/mask tensors are int64 and log-prob/advantage
        tensors are float32.

    Raises:
        ValueError: if ``batch`` is empty, or a sample's ``old_log_probs``
            does not have exactly one entry per completion token (which would
            misalign the importance ratios).
    """
    if not batch:
        raise ValueError("build_batch() requires at least one rollout sample")

    pad = torch.nn.utils.rnn.pad_sequence
    id_rows, comp_rows, lp_rows = [], [], []
    for sample in batch:
        n_prompt = len(sample.prompt_ids)
        n_comp = len(sample.completion_ids)
        if len(sample.old_log_probs) != n_comp:
            raise ValueError(
                f"old_log_probs has {len(sample.old_log_probs)} entries but the "
                f"completion has {n_comp} tokens"
            )
        id_rows.append(
            torch.tensor(sample.prompt_ids + sample.completion_ids, dtype=torch.long)
        )
        comp_rows.append(torch.tensor([0] * n_prompt + [1] * n_comp, dtype=torch.long))
        lp_rows.append(
            torch.tensor(
                [0.0] * n_prompt + list(sample.old_log_probs), dtype=torch.float32
            )
        )
    att_rows = [torch.ones(len(row), dtype=torch.long) for row in id_rows]

    # Pad on CPU, then move each tensor to the device in a single transfer.
    return {
        "prompt_completion_ids": pad(
            id_rows, batch_first=True, padding_value=pad_token_id
        ).to(device),
        "attention_mask": pad(att_rows, batch_first=True, padding_value=0).to(device),
        "completion_mask": pad(comp_rows, batch_first=True, padding_value=0).to(device),
        "advantages": torch.tensor(
            [float(s.advantage) for s in batch], dtype=torch.float32, device=device
        ),
        "old_log_probs": pad(lp_rows, batch_first=True, padding_value=0.0).to(device),
        "model_versions": torch.tensor(
            [s.model_version for s in batch], dtype=torch.long, device=device
        ),
    }
def compute_loss(
    batch: "Batch", model: torch.nn.Module, clip_eps: float = 0.2
) -> torch.Tensor:
    """Clipped GRPO surrogate loss for one batch (a quantity to *minimize*).

    For every completion token t of a sample with advantage A:

        ratio_t = exp(logp_theta(t) - logp_old(t))
        obj_t   = min(ratio_t * A, clip(ratio_t, 1 - clip_eps, 1 + clip_eps) * A)

    The loss is ``-obj`` averaged over each sample's completion tokens, then
    averaged over samples. ``logp_old`` comes from the generation policy
    (``batch["old_log_probs"]``). Because rollouts may be several versions
    stale, the ratio is clipped so a single update cannot move far from the
    policy that generated the data.

    Args:
        batch: Collated tensors (see ``build_batch``).
        model: Causal LM called as ``model(input_ids=..., attention_mask=...,
            use_cache=False)`` and returning an object with ``.logits``.
        clip_eps: Half-width of the ratio clipping interval.

    Returns:
        Scalar loss tensor attached to the model's autograd graph.
    """
    input_ids = batch["prompt_completion_ids"]
    batch_size, seq_len = input_ids.shape
    completion_tokens = batch["completion_mask"].sum().item()
    print(
        f"[Loss] Forward pass: batch_size={batch_size}, seq_len={seq_len}, "
        f"completion_tokens={int(completion_tokens)}, "
        f"advantages={batch['advantages'].tolist()}"
    )
    t0 = time.time()
    outputs = model(
        input_ids=input_ids,
        attention_mask=batch["attention_mask"],
        use_cache=False,
    )
    fwd_time = time.time() - t0
    print(f"[Loss] Forward pass took {fwd_time:.2f}s")

    # Logits at position t score token t+1: drop the last logit and the first target.
    shift_logits = outputs.logits[:, :-1, :]
    shift_ids = input_ids[:, 1:]
    # float32 so per-row token counts stay exact whatever the logits dtype is.
    shift_mask = batch["completion_mask"][:, 1:].float()

    # Current policy log-probs (with gradients, from training model)
    token_logprobs = -F.cross_entropy(
        shift_logits.transpose(1, 2), shift_ids, reduction="none"
    )
    # old_log_probs[:, j] is the log-prob of token j, so shift by one to
    # line up with shift_ids (frozen, from the vLLM generation policy).
    old_token_logps = batch["old_log_probs"][:, 1:]
    ratio = torch.exp(token_logprobs - old_token_logps)
    adv = batch["advantages"].unsqueeze(1)
    surrogate = torch.minimum(
        ratio * adv,
        torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * adv,
    )
    # Negated: the optimizer minimizes, while the surrogate is to be maximized.
    per_sample_loss = -(surrogate * shift_mask).sum(-1) / shift_mask.sum(-1).clamp(
        min=1.0
    )
    loss = per_sample_loss.mean()

    # Ratio statistics over completion tokens only; prompt/pad positions are meaningless.
    completion_ratio = ratio.detach()[shift_mask.bool()]
    n_ratio = completion_ratio.numel()
    ratio_mean = completion_ratio.mean().item() if n_ratio > 0 else float("nan")
    ratio_std = completion_ratio.std().item() if n_ratio > 1 else 0.0
    print(
        f"[Loss] per_sample_loss={per_sample_loss.detach().tolist()}, "
        f"mean_loss={loss.item():.6f}, "
        f"ratio_mean={ratio_mean:.4f}, ratio_std={ratio_std:.4f}"
    )
    return loss
def get_rollout_samples(
    global_batch_size: int,
    current_model_version: int,
    max_staleness: int,
    wait_timeout: float = 60,
) -> list[RolloutSample]:
    """Pull up to ``global_batch_size`` sufficiently fresh rollouts from ``ROLLOUT_QUEUE``.

    Admission control: a sample produced at version ``v`` is discarded (and
    logged) when ``current_model_version - v > max_staleness``. The queue is
    polled in 1-second slices. If a poll comes back empty after more than
    ``wait_timeout`` seconds have passed since the call began, collection
    stops and the possibly short, or empty, list is returned.
    """
    accepted = []
    started_at = time.time()
    while len(accepted) < global_batch_size:
        try:
            candidate = ROLLOUT_QUEUE.get(timeout=1.0)
        except queue.Empty:
            if time.time() - started_at > wait_timeout:
                break
            continue
        staleness = current_model_version - candidate.model_version
        if staleness <= max_staleness:
            accepted.append(candidate)
            continue
        print(
            f"[Admission Control] Dropping stale sample "
            f"(sample_version={candidate.model_version}, "
            f"current={current_model_version}, "
            f"staleness={staleness} > {max_staleness})"
        )
    return accepted
def wait_vllm_server(
    server_process: subprocess.Popen,
    vllm_port: int,
    max_wait: float = 300,
    poll_interval: float = 5,
):
    """Block until the local vLLM server's ``/health`` endpoint returns 200.

    Args:
        server_process: Server handle, checked after every failed probe.
        vllm_port: HTTP port of the server on localhost.
        max_wait: Seconds to keep polling before giving up (model download
            and loading can take longer than the default for big models).
        poll_interval: Seconds to sleep between probes.

    Raises:
        RuntimeError: if the server process exits before becoming healthy.
        TimeoutError: if it is not healthy within ``max_wait`` seconds.
    """
    url = f"http://localhost:{vllm_port}/health"
    start_wait = time.time()
    while time.time() - start_wait < max_wait:
        try:
            if requests.get(url, timeout=2).status_code == 200:
                print("vLLM server is ready!")
                return
        except requests.RequestException:
            # Not accepting connections yet (refused / timed out): keep polling.
            pass
        # Check if the subprocess has exited
        if server_process.poll() is not None:
            raise RuntimeError(
                f"vLLM server process died unexpectedly (exit code {server_process.returncode})."
            )
        time.sleep(poll_interval)
    raise TimeoutError("vLLM server failed to start within timeout.")
def get_total_grad(model):
    """Return the global L2 norm of all parameter gradients as a Python float.

    Parameters without a gradient are skipped, and 0.0 is returned when none
    has one. Per-parameter norms are computed in float32 and reduced
    on-device, so there is a single host synchronization (``.item()``)
    instead of one per parameter.
    """
    per_param = [
        torch.linalg.vector_norm(p.grad.detach(), 2, dtype=torch.float32)
        for p in model.parameters()
        if p.grad is not None
    ]
    if not per_param:
        return 0.0
    # Gather on one device in case parameters are spread across devices.
    device = per_param[0].device
    stacked = torch.stack([norm.to(device) for norm in per_param])
    return torch.linalg.vector_norm(stacked, 2).item()
def init_weight_transfer(
    server_process: subprocess.Popen, vllm_port: int, device: str
) -> PyNcclCommunicator:
    """Create the NCCL group used to push trainer weights into vLLM.

    The trainer is always rank 0, and the vLLM workers are placed after it
    (``rank_offset=1``, ``world_size`` = server world size + 1). The server-side
    ``/init_weight_transfer_engine`` request runs in a helper thread while
    this thread calls ``trainer_init``, presumably because both sides must
    rendezvous concurrently. The helper is then joined and its outcome checked.

    Args:
        server_process: vLLM server handle; terminated on failure.
        vllm_port: HTTP port of the vLLM server on localhost.
        device: CUDA device the trainer uses (made current before NCCL init).

    Returns:
        The trainer-side communicator returned by ``trainer_init``.

    Raises:
        RuntimeError: if the server-side init request failed.
        Any exception from the HTTP calls or ``trainer_init`` is re-raised
        after terminating ``server_process``.
    """
    print("Setting up vLLM weight transfer engine...")
    torch.cuda.set_device(device)
    base_url = f"http://localhost:{vllm_port}"
    try:
        master_address = get_ip()
        master_port = get_open_port()
        ws_resp = requests.get(f"{base_url}/get_world_size")
        ws_resp.raise_for_status()
        world_size = ws_resp.json()["world_size"] + 1  # +1 for trainer rank 0 always.
        init_errors: list[Exception] = []

        def start_init_weight():
            try:
                init_resp = requests.post(
                    f"{base_url}/init_weight_transfer_engine",
                    json={
                        "init_info": dict(
                            master_address=master_address,
                            master_port=master_port,
                            rank_offset=1,
                            world_size=world_size,
                        )
                    },
                )
                init_resp.raise_for_status()
            except Exception as e:
                print(f"[Sync] Failed to initialize weight transfer on vLLM server: {e}")
                init_errors.append(e)

        # NOTE: API call in separate thread; daemon so a hung request cannot block exit.
        t = threading.Thread(target=start_init_weight, daemon=True)
        t.start()
        model_update_group = NCCLWeightTransferEngine.trainer_init(
            dict(
                master_address=master_address,
                master_port=master_port,
                world_size=world_size,
            ),
        )
        t.join()
        if init_errors:
            raise RuntimeError(
                "vLLM server failed to initialize the weight transfer engine"
            ) from init_errors[0]
        print("Weight transfer engine initialized successfully.")
        return model_update_group
    except Exception as e:
        print(f"Failed to initialize weight transfer engine: {e}")
        server_process.terminate()
        raise
def send_trainer_weight(
    model_update_group: PyNcclCommunicator, model: torch.nn.Module, vllm_port: int
) -> bool:
    """Push the trainer's current parameters into the vLLM server over NCCL.

    Steps:

    1. Pause generation (``/pause`` with ``mode=wait``).
    2. POST the parameter metadata to ``/update_weights`` from a helper thread,
       while this thread streams the tensors with ``trainer_send_weights``.
    3. Resume generation.

    ``/resume`` is attempted even if the NCCL send raises, so a failed sync
    never leaves the server paused.

    Returns:
        True if the ``/update_weights`` request succeeded, False otherwise.
    """
    base_url = f"http://localhost:{vllm_port}"
    pause_resp = requests.post(
        f"{base_url}/pause",
        params={"mode": "wait"},
    )
    if pause_resp.status_code != 200:
        print(f"[Sync] Warning: pause failed: {pause_resp.text}")

    params = list(model.named_parameters())
    update_info = dict(
        names=[name for name, _ in params],
        dtype_names=[str(p.dtype).split(".")[-1] for _, p in params],
        shapes=[list(p.shape) for _, p in params],
        packed=True,
    )
    sync_success = True

    def trigger_update_weights():
        nonlocal sync_success
        try:
            resp = requests.post(
                f"{base_url}/update_weights",
                json={"update_info": update_info},
            )
            resp.raise_for_status()
        except Exception as e:
            print(f"[Sync] Failed to update weights on vLLM server: {e}")
            sync_success = False

    # NOTE: vllm could deadlock if `/update_weights` were called synchronously here,
    # since (presumably) its receive waits for this thread's NCCL send.
    # Daemon so a request stuck after a failed send cannot block interpreter exit.
    t = threading.Thread(target=trigger_update_weights, daemon=True)
    t.start()
    try:
        NCCLWeightTransferEngine.trainer_send_weights(
            iterator=model.named_parameters(),
            trainer_args=NCCLTrainerSendWeightsArgs(
                group=model_update_group,
                packed=True,
            ),
        )
        t.join()
    finally:
        try:
            resume_resp = requests.post(f"{base_url}/resume")
            if resume_resp.status_code != 200:
                print(f"[Sync] Warning: resume failed: {resume_resp.text}")
        except requests.RequestException as e:
            # Don't let a resume failure mask an exception from the send.
            print(f"[Sync] Warning: resume failed: {e}")
    return sync_success
def _shutdown_vllm_server(server_process: subprocess.Popen):
    """Stop the vLLM child process (terminate, then kill after 10 s) and close its log file."""
    print("Terminating vLLM server...")
    if server_process.poll() is None:
        server_process.terminate()
        try:
            server_process.wait(timeout=10)
        except subprocess.TimeoutExpired:
            server_process.kill()
            server_process.wait()
    log_file = getattr(server_process, "_log_file", None)
    if log_file is not None:
        log_file.close()


def _run_training(config: TrainingConfig, model_update_group: PyNcclCommunicator):
    """Load data and the trainable model, then alternate train steps and weight syncs.

    Runs ``config.num_steps`` steps, stopping early if no rollout samples
    arrive within the fetch timeout. The rollout thread is told to stop on exit.
    """
    dataloader = None
    try:
        dataset = load_dataset("open-r1/OpenR1-Math-220k", split="train[:100]")
        tokenizer = AutoTokenizer.from_pretrained(config.model_name_or_path)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        dataloader = AsyncGRPODataLoader(config, tokenizer, dataset)
        dataloader.start()
        print(f"Loading training model on {config.train_device}...")
        model = AutoModelForCausalLM.from_pretrained(config.model_name_or_path)
        # NOTE: the trainer's copy of the policy lives on config.train_device;
        # vLLM serves its own copy on a separate GPU.
        model = model.to(device=config.train_device)  # type: ignore
        model.gradient_checkpointing_enable()
        optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
        print("Starting training loop...")
        model.train()
        current_model_version = 0
        for step in range(config.num_steps):
            t_fetch = time.time()
            print(
                f"\n[Fetch] Step {step + 1}: Waiting for {config.global_batch_size} rollout samples (queue size={ROLLOUT_QUEUE.qsize()})..."
            )
            samples = get_rollout_samples(
                config.global_batch_size,
                current_model_version=current_model_version,
                max_staleness=config.max_staleness,
            )
            fetch_time = time.time() - t_fetch
            if not samples:
                print("[Fetch] No samples received after timeout. Stopping.")
                break
            print(
                f"[Fetch] Got {len(samples)} samples in {fetch_time:.2f}s | "
                f"avg completion_len={sum(len(s.completion_ids) for s in samples) / len(samples):.0f} | "
                f"queue remaining={ROLLOUT_QUEUE.qsize()}"
            )
            batch = build_batch(
                samples, torch.device(config.train_device), tokenizer.pad_token_id
            )
            print(f"\n{'=' * 60}")
            print(
                f"[Train] Step {step + 1}/{config.num_steps} | Model Version: {current_model_version}"
            )
            print(f"[Train] Batch shape: {batch['prompt_completion_ids'].shape}")
            print(f"{'=' * 60}")
            loss = compute_loss(batch, model)
            optimizer.zero_grad()
            print("[Backward] Starting backward pass...")
            t_bwd = time.time()
            loss.backward()
            bwd_time = time.time() - t_bwd
            print(f"[Backward] Backward pass took {bwd_time:.2f}s")
            # Log gradient norms before clipping
            total_norm = get_total_grad(model)
            print(f"[Backward] Grad norm (pre-clip): {total_norm:.4f}")
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            print(
                f"[Train] Step {step + 1} DONE | "
                f"Loss: {loss.item():.6f} | "
                f"Grad norm (pre-clip): {total_norm:.4f} | "
                f"Backward: {bwd_time:.2f}s"
            )
            # Sync weights with vllm using NCCL
            print(
                f"[Sync] Syncing weights to vLLM engine (upgrading to version {current_model_version + 1})..."
            )
            sync_success = send_trainer_weight(
                model_update_group, model, config.vllm_port
            )
            if sync_success:
                print("[Sync] Weights synchronized successfully.")
                current_model_version += 1
                dataloader.update_model_version(current_model_version)
    finally:
        if dataloader is not None:
            # Ask the (daemon) rollout thread to stop scheduling new requests.
            dataloader.running = False


def main():
    """Start vLLM, connect the NCCL weight-sync channel and run async GRPO training.

    The vLLM server is launched on GPU 1, while the trainer uses
    ``config.train_device``. The server is shut down and its log file closed
    on every exit path, including failures or Ctrl-C during server startup and
    weight-transfer setup. Startup errors propagate to the caller. Errors
    during training are printed with a traceback.
    """
    config = TrainingConfig()
    print("Starting vLLM server process...")
    server_process = run_vllm_server(
        config.model_name_or_path, config.vllm_port, gpu_id=1
    )
    try:
        wait_vllm_server(server_process, config.vllm_port)
        model_update_group = init_weight_transfer(
            server_process, config.vllm_port, config.train_device
        )
        try:
            _run_training(config, model_update_group)
        except KeyboardInterrupt:
            print("\nReceived keyboard interrupt...")
        except Exception as e:
            print(f"\nError: {e}")
            import traceback

            traceback.print_exc()
    finally:
        _shutdown_vllm_server(server_process)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment