Last active
May 29, 2026 08:59
-
-
Save AmineDiro/5e324f609989d3f0437ea13e69ea489f to your computer and use it in GitHub Desktop.
Async grpo minimal
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env -S uv run | |
| # /// script | |
| # requires-python = ">=3.10, <3.13" | |
| # dependencies = [ | |
| # "aiohttp==3.13.3", | |
| # "datasets==4.6.1", | |
| # "hf-transfer==0.1.9", | |
| # "requests==2.32.5", | |
| # "torch==2.10.0", | |
| # "transformers==4.57.6", | |
| # "uvicorn==0.41.0", | |
| # "uvloop==0.22.1", | |
| # "vllm==0.17.0", | |
| # ] | |
| # /// | |
| import asyncio | |
| import os | |
| import queue | |
| import subprocess | |
| import sys | |
| import threading | |
| import time | |
| from dataclasses import dataclass | |
| from typing import Tuple, TypedDict | |
| import aiohttp | |
| import requests | |
| import torch | |
| import torch.nn.functional as F | |
| from datasets import load_dataset | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator | |
| from vllm.distributed.weight_transfer.nccl_engine import ( | |
| NCCLTrainerSendWeightsArgs, | |
| NCCLWeightTransferEngine, | |
| ) | |
| from vllm.utils.network_utils import get_ip, get_open_port | |
# Bounded hand-off between the rollout producer (AsyncGRPODataLoader's
# background event-loop thread) and the training loop in main().
# The producer enqueues with put_nowait() and stops admitting new prompt
# groups while the queue is full; the trainer dequeues with get(timeout=...).
# TODO: depth should be dynamically computed using global_batch_size, admission_control, max_inflight
ROLLOUT_QUEUE: queue.Queue = queue.Queue(maxsize=1024)
| @dataclass | |
| class TrainingConfig: | |
| model_name_or_path: str = "Qwen/Qwen3-4B" | |
| num_steps: int = 100 | |
| learning_rate: float = 1e-5 | |
| global_batch_size: int = 2 | |
| gradient_accumulation_steps: int = 4 | |
| train_device: str = "cuda:0" | |
| # GRPO specific | |
| group_size: int = 4 | |
| max_inflight: int = 32 | |
| max_staleness: int = 3 | |
| # vllm inference config | |
| vllm_server_url: str = "http://localhost:8000/v1/completions" | |
| vllm_port: int = 8000 | |
| temperature: float = 0.8 | |
| max_tokens: int = 1024 | |
@dataclass(slots=True)
class RolloutSample:
    """One scored completion produced by the rollout worker.

    Created in ``AsyncGRPODataLoader._process_completed_group`` and collated
    into tensors by ``build_batch``.
    """

    # TODO: Add some notion of turns, trajectory ...
    # Chat-style prompt, e.g. [{"role": "user", "content": ...}].
    prompt: list[dict[str, str]]
    # Tokenized prompt (chat template applied when the tokenizer has one).
    prompt_ids: list[int]
    # Chat-style completion, e.g. [{"role": "assistant", "content": ...}].
    completion: list[dict[str, str]]
    # Token ids returned by vLLM for the completion.
    completion_ids: list[int]
    # Advantage normalized within the prompt group: (r - mean) / (std + 1e-4).
    advantage: float
    # Per-token log-probs reported by vLLM for completion_ids; build_batch
    # places them position-by-position after the prompt, so it assumes one
    # entry per completion token.
    old_log_probs: list[float]
    # Trainer model version recorded when the prompt group was submitted.
    model_version: int
    # TODO: add metadata
    # metrics : dict[str,float]
    ## generation_time_s: float
class Batch(TypedDict):
    """Padded, device-resident tensors produced by ``build_batch``.

    The 2-D tensors are right-padded to the longest prompt+completion in
    the batch and share shape ``(batch, max_len)``.
    """

    # Prompt token ids followed by completion token ids, padded with pad_token_id.
    prompt_completion_ids: torch.Tensor
    # 1 for real tokens, 0 for padding.
    attention_mask: torch.Tensor
    # 1 only at completion-token positions; 0 at prompt and padding positions.
    completion_mask: torch.Tensor
    # One float32 advantage per sample, shape (batch,).
    advantages: torch.Tensor
    # Generation-time log-prob per position; 0.0 at prompt and padding positions.
    old_log_probs: torch.Tensor
    # Model version each sample was generated with, shape (batch,), dtype long.
    model_versions: torch.Tensor
def run_vllm_server(model_name: str, port: int, gpu_id: int) -> subprocess.Popen:
    """Launch a vLLM OpenAI-compatible API server as a child process.

    The server is pinned to one GPU through ``CUDA_VISIBLE_DEVICES`` and
    started with the NCCL weight-transfer backend, so the trainer can push
    updated weights into it. stdout and stderr are both written to
    ``vllm_server.log``.

    Args:
        model_name: HF model id or local path passed to ``--model``.
        port: TCP port the HTTP server listens on.
        gpu_id: Physical GPU index exposed to the server process.

    Returns:
        The ``subprocess.Popen`` handle. The caller is responsible for
        terminating the process.
    """
    env = os.environ.copy()
    env["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    env["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
    env["VLLM_LOGGING_LEVEL"] = "WARNING"
    # Presumably needed to expose the dev endpoints used by the trainer
    # (/pause, /resume, /update_weights, ...) — TODO confirm against vLLM docs.
    env["VLLM_SERVER_DEV_MODE"] = "1"
    # TODO: should probably have more control using `AsyncLLMEngine`
    cmd = [
        sys.executable,
        "-m",
        "vllm.entrypoints.openai.api_server",
        "--model",
        model_name,
        "--port",
        str(port),
        "--host",
        "0.0.0.0",
        "--gpu-memory-utilization",
        "0.9",
        "--max-model-len",
        "4096",
        "--tensor-parallel-size",
        "1",
        "--weight-transfer-config",
        '{"backend":"nccl"}',
    ]
    print(f"[Server] Launching vLLM server on GPU {gpu_id}, port {port}")
    print(f"[Server] Command: {' '.join(cmd)}")
    print("[Server] Logs are being written to vllm_server.log")
    # The child gets its own duplicate of the log descriptor, so the parent's
    # copy can be closed as soon as Popen returns. The `with` also closes it
    # if Popen itself raises, instead of leaking the handle.
    with open("vllm_server.log", "w") as log_file:
        process = subprocess.Popen(
            cmd,
            env=env,
            stdout=log_file,
            stderr=subprocess.STDOUT,
        )
    # Kept for backward compatibility with callers that look for this attribute.
    process._log_file = log_file  # type: ignore
    return process
class AsyncGRPODataLoader:
    """Background producer of GRPO rollouts.

    Runs an asyncio event loop on a daemon thread. The loop keeps up to
    ``config.max_inflight`` completion requests in flight against the vLLM
    server. Requests are issued in groups of ``config.group_size`` that share
    a prompt. When every member of a group has returned, the loader:

    - computes rewards for the successful completions,
    - normalizes them into advantages within the group,
    - pushes each successful completion onto ``ROLLOUT_QUEUE`` as a
      ``RolloutSample``.
    """

    def __init__(self, config: TrainingConfig, tokenizer: AutoTokenizer, dataset):
        self.config = config
        self.tokenizer = tokenizer
        self.dataset = dataset
        self.dataset_iter = iter(dataset)
        # TODO: should be event
        # Polled at the top of each generation-loop iteration; False ends the loop.
        self.running = False
        self.thread = None
        self.loop = None
        # Stamped on each newly submitted prompt group. Written by the trainer
        # via update_model_version() and read by the generation thread.
        self.model_version = 0

    def update_model_version(self, version: int):
        """Record the trainer's current weights version for new rollouts."""
        self.model_version = version

    def start(self):
        """Start the generation loop on a daemon thread."""
        self.running = True
        self.thread = threading.Thread(target=self._run_event_loop, daemon=True)
        self.thread.start()

    def stop(self, timeout: float | None = None):
        """Ask the generation loop to exit and wait up to ``timeout`` for it.

        The loop notices the request at its next iteration and then cancels
        any requests still in flight.
        """
        self.running = False
        if self.thread is not None:
            self.thread.join(timeout)

    def _run_event_loop(self):
        # The thread owns a private event loop, which is closed when the
        # generation loop returns or raises.
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)
        try:
            self.loop.run_until_complete(self._generation_loop())
        finally:
            self.loop.close()

    def _get_next_prompt_data(self) -> Tuple[str, list[int]]:
        """Return the next prompt text and its token ids, cycling the dataset.

        The text is taken from the ``"problem"`` field, falling back to
        ``"prompt"``. If the tokenizer has a chat template, the text is
        wrapped as a single user turn with the generation prompt appended
        before encoding.

        Raises:
            ValueError: if the dataset is empty.
        """
        try:
            sample = next(self.dataset_iter)
        except StopIteration:
            # End of an epoch: start over from the beginning.
            self.dataset_iter = iter(self.dataset)
            try:
                sample = next(self.dataset_iter)
            except StopIteration:
                # A StopIteration escaping into the coroutine would surface as
                # an opaque RuntimeError; report the real cause instead.
                raise ValueError("AsyncGRPODataLoader: dataset is empty") from None
        prompt_text = sample.get("problem", "")
        if not prompt_text and "prompt" in sample:
            prompt_text = sample["prompt"]
        if getattr(self.tokenizer, "chat_template", None):
            messages = [{"role": "user", "content": prompt_text}]
            prompt_str = self.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        else:
            prompt_str = prompt_text
        prompt_ids = self.tokenizer.encode(prompt_str, add_special_tokens=False)
        return prompt_text, prompt_ids

    async def _generate_with_group_id(
        self, session, group_id: int, prompt_ids: list[int]
    ):
        """Run one generation and tag its result with the prompt group id."""
        result = await self._generate_one(session, prompt_ids)
        return group_id, result

    @staticmethod
    def _group_advantages(rewards: list[float]) -> list[float]:
        """Normalize rewards within a group as ``(r - mean) / (std + 1e-4)``.

        Returns all zeros when the group has fewer than two members or the
        rewards are (near-)constant, because there is no relative signal. The
        unbiased std of a single value is NaN and would otherwise poison the loss.
        """
        if len(rewards) < 2:
            return [0.0] * len(rewards)
        rewards_tensor = torch.tensor(rewards, dtype=torch.float32)
        std_r = rewards_tensor.std()
        if std_r.item() < 1e-6:
            return [0.0] * len(rewards)
        advantages = (rewards_tensor - rewards_tensor.mean()) / (std_r + 1e-4)
        return advantages.tolist()

    def _process_completed_group(self, group_data: dict):
        """Score a finished prompt group and enqueue its successful rollouts.

        Failed generations (empty ``completion_ids``) are removed before the
        group statistics are computed, so a transport error does not shift
        the baseline of the completions that succeeded. Samples that do not
        fit into ``ROLLOUT_QUEUE`` are dropped with a log line.
        """
        prompt_text = group_data["prompt_text"]
        prompt_ids = group_data["prompt_ids"]
        version_at_generation = group_data["version"]
        # Each result is (completion_text, completion_ids, completion_logprobs).
        results = [r for r in group_data["results"] if r[1]]
        if not results:
            print("[Rollout] All generations in group failed; skipping group.")
            return
        rewards = [compute_reward(prompt_text, text) for text, _, _ in results]
        advantages = self._group_advantages(rewards)
        for (completion_text, completion_ids, completion_logprobs), reward, advantage in zip(
            results, rewards, advantages
        ):
            rollout = RolloutSample(
                prompt=[{"role": "user", "content": prompt_text}],
                prompt_ids=prompt_ids,
                completion=[{"role": "assistant", "content": completion_text}],
                completion_ids=completion_ids,
                advantage=advantage,
                old_log_probs=completion_logprobs,
                model_version=version_at_generation,
            )
            try:
                ROLLOUT_QUEUE.put_nowait(rollout)
            except queue.Full:
                print(
                    f"[Rollout] Queue full; dropping sample (version={version_at_generation})"
                )
                continue
            print(
                f"[Rollout] prompt_len={len(prompt_ids)}, "
                f"completion_len={len(completion_ids)}, "
                f"total_len={len(prompt_ids) + len(completion_ids)}, "
                f"reward={reward:.4f}, "
                f"advantage={advantage:.4f}, "
                f"version={version_at_generation}"
            )

    async def _generation_loop(self):
        """Keep the vLLM server busy with grouped generation requests.

        A new prompt group is admitted while the rollout queue has room and
        the whole group fits under ``config.max_inflight``. Finished requests
        are collected as they complete, and a group is scored once all
        ``group_size`` requests have returned. Requests still in flight are
        cancelled when ``self.running`` becomes False or the loop raises.

        Raises:
            ValueError: if ``group_size`` exceeds ``max_inflight``. Otherwise
                no group could ever be admitted and the loop would idle forever.
        """
        if self.config.group_size > self.config.max_inflight:
            raise ValueError(
                f"group_size ({self.config.group_size}) must not exceed "
                f"max_inflight ({self.config.max_inflight})"
            )
        pending_tasks: set[asyncio.Task] = set()
        groups: dict[int, dict] = {}
        group_id_counter = 0
        async with aiohttp.ClientSession() as session:
            try:
                while self.running:
                    while (
                        not ROLLOUT_QUEUE.full()
                        and len(pending_tasks) + self.config.group_size
                        <= self.config.max_inflight
                    ):
                        prompt_text, prompt_ids = self._get_next_prompt_data()
                        group_id = group_id_counter
                        group_id_counter += 1
                        groups[group_id] = {
                            "prompt_text": prompt_text,
                            "prompt_ids": prompt_ids,
                            "version": self.model_version,
                            "results": [],
                        }
                        for _ in range(self.config.group_size):
                            task = asyncio.create_task(
                                self._generate_with_group_id(session, group_id, prompt_ids)
                            )
                            pending_tasks.add(task)
                    if not pending_tasks:
                        # Nothing to wait on (the queue is full): back off
                        # instead of spinning.
                        await asyncio.sleep(0.1)
                        continue
                    done, pending_tasks = await asyncio.wait(
                        pending_tasks, return_when=asyncio.FIRST_COMPLETED
                    )
                    for task in done:
                        group_id, result = task.result()
                        group = groups[group_id]
                        group["results"].append(result)
                        if len(group["results"]) == self.config.group_size:
                            self._process_completed_group(groups.pop(group_id))
            finally:
                # Cancel outstanding requests before the HTTP session closes.
                for task in pending_tasks:
                    task.cancel()
                if pending_tasks:
                    await asyncio.gather(*pending_tasks, return_exceptions=True)

    async def _generate_one(
        self, session, prompt_ids: list[int]
    ) -> Tuple[str, list[int], list[float]]:
        """Request a single completion for ``prompt_ids`` from vLLM.

        The request asks for the generated token ids and the log-prob of each
        sampled token. The following are logged and reported as
        ``("", [], [])``, which the caller treats as a failed generation:

        - HTTP errors and transport failures,
        - malformed responses,
        - log-probs that do not line up one-to-one with the token ids.
        """
        payload = {
            "model": self.config.model_name_or_path,
            "prompt": prompt_ids,
            "max_tokens": self.config.max_tokens,
            "temperature": self.config.temperature,
            "n": 1,
            "return_token_ids": True,
            "logprobs": 1,
        }
        try:
            t0 = time.time()
            async with session.post(self.config.vllm_server_url, json=payload) as resp:
                if resp.status != 200:
                    error_text = await resp.text()
                    print(
                        f"[Generation] vLLM error status={resp.status}: {error_text[:200]}"
                    )
                    return "", [], []
                data = await resp.json()
            choice = data["choices"][0]
            text = choice["text"]
            completion_ids = choice["token_ids"]
            completion_logprobs = choice["logprobs"]["token_logprobs"]
            # build_batch pairs log-probs with completion tokens position by
            # position; anything misaligned would corrupt the importance ratios.
            if len(completion_logprobs) != len(completion_ids) or any(
                lp is None for lp in completion_logprobs
            ):
                print(
                    f"[Generation] Discarding response: log-probs do not align "
                    f"with {len(completion_ids)} token ids"
                )
                return "", [], []
            elapsed = time.time() - t0
            print(
                f"[Generation] vLLM response in {elapsed:.2f}s, "
                f"output_chars={len(text)}"
            )
            return text, completion_ids, completion_logprobs
        except Exception as e:
            print(f"[Generation] vLLM request failed: {e}")
            return "", [], []
def compute_reward(prompt: str, completion: str) -> float:
    """Placeholder reward: the completion's character count divided by 100.

    ``prompt`` is accepted so real reward functions can share the same
    signature, but it is ignored here.
    """
    # Dummy reward: length
    n_chars = len(completion)
    return n_chars / 100.0
def build_batch(
    batch: list[RolloutSample], device: torch.device, pad_token_id: int
) -> Batch:
    """Collate rollout samples into right-padded training tensors.

    Each row is ``prompt_ids + completion_ids``, padded with ``pad_token_id``
    to the longest row in the batch.

    - ``completion_mask`` marks completion tokens only.
    - ``old_log_probs`` holds the generation-time log-prob of each completion
      token at that token's position, and 0.0 elsewhere. Those other
      positions are masked out by the loss.

    Args:
        batch: Non-empty list of rollout samples.
        device: Device the returned tensors are moved to.
        pad_token_id: Token id used to pad ``prompt_completion_ids``.

    Raises:
        ValueError: if ``batch`` is empty, or if a sample's ``old_log_probs``
            does not have exactly one entry per completion token (which would
            misalign the importance ratios).
    """
    if not batch:
        raise ValueError("build_batch: received an empty batch")
    for idx, sample in enumerate(batch):
        if len(sample.old_log_probs) != len(sample.completion_ids):
            raise ValueError(
                f"build_batch: sample {idx} has {len(sample.old_log_probs)} old "
                f"log-probs for {len(sample.completion_ids)} completion tokens"
            )

    batch_size = len(batch)
    max_len = max(len(s.prompt_ids) + len(s.completion_ids) for s in batch)

    # Preallocate padded CPU buffers, fill each row with slice assignments,
    # then move each tensor to the target device once.
    ids = torch.full((batch_size, max_len), pad_token_id, dtype=torch.long)
    attention_mask = torch.zeros((batch_size, max_len), dtype=torch.long)
    completion_mask = torch.zeros((batch_size, max_len), dtype=torch.long)
    old_log_probs = torch.zeros((batch_size, max_len), dtype=torch.float32)

    for row, sample in enumerate(batch):
        prompt_len = len(sample.prompt_ids)
        total_len = prompt_len + len(sample.completion_ids)
        ids[row, :total_len] = torch.tensor(
            sample.prompt_ids + sample.completion_ids, dtype=torch.long
        )
        attention_mask[row, :total_len] = 1
        completion_mask[row, prompt_len:total_len] = 1
        old_log_probs[row, prompt_len:total_len] = torch.tensor(
            sample.old_log_probs, dtype=torch.float32
        )

    advantages = torch.tensor(
        [float(s.advantage) for s in batch], dtype=torch.float32
    )
    model_versions = torch.tensor(
        [s.model_version for s in batch], dtype=torch.long
    )
    return {
        "prompt_completion_ids": ids.to(device),
        "attention_mask": attention_mask.to(device),
        "completion_mask": completion_mask.to(device),
        "advantages": advantages.to(device),
        "old_log_probs": old_log_probs.to(device),
        "model_versions": model_versions.to(device),
    }
def compute_loss(
    batch: Batch, model: torch.nn.Module, clip_eps: float | None = None
) -> torch.Tensor:
    """Importance-weighted policy-gradient loss for a GRPO batch.

    For each completion token, the ratio between the current policy and the
    generation-time (vLLM) policy is ``exp(logp_new - logp_old)``. The
    per-token objective ``ratio * advantage`` is computed as follows:

    1. Average it over each sample's completion tokens.
    2. Average the per-sample values over the batch.
    3. **Negate** the result, so that minimizing the returned loss increases
       the probability of tokens with positive advantage.

    Args:
        batch: Tensors produced by ``build_batch``.
        model: Causal LM whose forward returns an object with ``.logits``.
        clip_eps: If set, use the PPO-style clipped surrogate
            ``min(r * A, clip(r, 1 - eps, 1 + eps) * A)``. ``None`` keeps the
            unclipped objective.

    Returns:
        Scalar loss tensor that carries gradients w.r.t. ``model``.
    """
    input_ids = batch["prompt_completion_ids"]
    batch_size, seq_len = input_ids.shape
    completion_tokens = batch["completion_mask"].sum().item()
    print(
        f"[Loss] Forward pass: batch_size={batch_size}, seq_len={seq_len}, "
        f"completion_tokens={int(completion_tokens)}, "
        f"advantages={batch['advantages'].tolist()}"
    )
    t0 = time.time()
    outputs = model(
        input_ids=input_ids,
        attention_mask=batch["attention_mask"],
        use_cache=False,
    )
    fwd_time = time.time() - t0
    print(f"[Loss] Forward pass took {fwd_time:.2f}s")

    # Logits at position t predict token t + 1.
    shift_logits = outputs.logits[:, :-1, :]
    shift_ids = input_ids[:, 1:]
    shift_mask = batch["completion_mask"][:, 1:].float()

    # Current policy log-probs (with gradients, from training model)
    token_logprobs = -F.cross_entropy(
        shift_logits.transpose(1, 2), shift_ids, reduction="none"
    )
    # Old policy log-probs (frozen, from vLLM generation policy).
    # old_log_probs[:, t] is the log-prob of token t, so shift by one to
    # align with shift_ids.
    old_token_logps = batch["old_log_probs"][:, 1:]
    ratio = torch.exp(token_logprobs - old_token_logps)
    adv = batch["advantages"].unsqueeze(1)

    surrogate = ratio * adv
    if clip_eps is not None:
        clipped = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * adv
        surrogate = torch.minimum(surrogate, clipped)

    tokens_per_sample = shift_mask.sum(-1).clamp(min=1.0)
    per_sample_objective = (surrogate * shift_mask).sum(-1) / tokens_per_sample
    # The optimizer minimizes, but the objective is to be maximized.
    per_sample_loss = -per_sample_objective
    loss = per_sample_loss.mean()

    # Report ratio statistics over completion tokens only. Prompt and padding
    # positions hold placeholder old log-probs and would skew them.
    valid_ratio = ratio.detach()[shift_mask > 0]
    ratio_mean = valid_ratio.mean().item() if valid_ratio.numel() > 0 else float("nan")
    ratio_std = valid_ratio.std().item() if valid_ratio.numel() > 1 else 0.0
    print(
        f"[Loss] per_sample_loss={per_sample_loss.detach().tolist()}, "
        f"mean_loss={loss.item():.6f}, "
        f"ratio_mean={ratio_mean:.4f}, ratio_std={ratio_std:.4f}"
    )
    return loss
def get_rollout_samples(
    global_batch_size: int,
    current_model_version: int,
    max_staleness: int,
    wait_timeout: float = 60,
    rollout_queue: queue.Queue | None = None,
) -> list[RolloutSample]:
    """Collect up to ``global_batch_size`` sufficiently fresh rollouts.

    Samples generated more than ``max_staleness`` model versions ago are
    dropped (admission control). Once ``wait_timeout`` seconds have elapsed,
    the function stops blocking: it only takes what is already queued and
    then returns. The result may therefore be shorter than requested, or
    empty.

    Args:
        global_batch_size: Maximum number of samples to return.
        current_model_version: The trainer's current weights version.
        max_staleness: Largest allowed ``current - sample.model_version``.
        wait_timeout: Overall time budget in seconds, measured from the call.
        rollout_queue: Queue to read from. Defaults to ``ROLLOUT_QUEUE``.
    """
    source = ROLLOUT_QUEUE if rollout_queue is None else rollout_queue
    batch = []
    # The deadline covers the whole call. Checking it on every iteration
    # means a steady stream of stale samples cannot keep us waiting forever.
    deadline = time.monotonic() + wait_timeout
    while len(batch) < global_batch_size:
        remaining = deadline - time.monotonic()
        try:
            if remaining > 0:
                sample = source.get(timeout=remaining)
            else:
                # Past the deadline: still take anything already queued.
                sample = source.get_nowait()
        except queue.Empty:
            break
        staleness = current_model_version - sample.model_version
        if staleness > max_staleness:
            print(
                f"[Admission Control] Dropping stale sample "
                f"(sample_version={sample.model_version}, "
                f"current={current_model_version}, "
                f"staleness={staleness} > {max_staleness})"
            )
            continue
        batch.append(sample)
    return batch
def wait_vllm_server(server_process: subprocess.Popen, vllm_port: int):
    """Block until the vLLM server answers ``/health`` with HTTP 200.

    The endpoint is polled every 5 seconds for up to 300 seconds.

    Raises:
        RuntimeError: if the server process exits while we are waiting.
        TimeoutError: if the server is not healthy within the time budget.
    """
    health_url = f"http://localhost:{vllm_port}/health"
    started = time.time()
    while time.time() - started < 300:
        try:
            healthy = requests.get(health_url, timeout=2).status_code == 200
        except Exception:
            # Connection errors are expected while the server is still booting.
            healthy = False
        if healthy:
            print("vLLM server is ready!")
            return
        # Stop early if the child process has already exited.
        if server_process.poll() is not None:
            raise RuntimeError(
                f"vLLM server process died unexpectedly (exit code {server_process.returncode})."
            )
        time.sleep(5)
    raise TimeoutError("vLLM server failed to start within timeout.")
def get_total_grad(model):
    """Return the global L2 norm of all parameter gradients as a Python float.

    Parameters without a gradient are skipped, and the result is 0.0 when no
    parameter has one. Per-parameter norms are reduced on device, so there
    is a single host synchronization instead of one ``.item()`` per tensor.
    """
    norms = [p.grad.detach().norm(2) for p in model.parameters() if p.grad is not None]
    if not norms:
        return 0.0
    # Gather onto one device in float32 so the final reduction is well defined
    # even if parameters live on several devices or use low-precision dtypes.
    target = norms[0].device
    stacked = torch.stack([n.to(device=target, dtype=torch.float32) for n in norms])
    return torch.linalg.vector_norm(stacked, ord=2).item()
def init_weight_transfer(
    server_process: subprocess.Popen, vllm_port: int, device: str
) -> PyNcclCommunicator:
    """Create the NCCL group used to push trainer weights into vLLM.

    Rank layout:

    - The trainer is rank 0.
    - The vLLM workers join with ``rank_offset=1``.
    - The world size is therefore the server-reported world size plus one.

    The server-side ``/init_weight_transfer_engine`` request is issued from
    a helper thread, so it runs concurrently with
    ``NCCLWeightTransferEngine.trainer_init`` on this thread. Presumably
    both sides must reach the NCCL rendezvous before either returns.

    If anything fails, the vLLM server process is terminated and the
    exception is re-raised.

    Raises:
        RuntimeError: if the server reports that its side of the
            initialization failed.
    """
    print("Setting up vLLM weight transfer engine...")
    torch.cuda.set_device(device)
    base_url = f"http://localhost:{vllm_port}"
    try:
        master_address = get_ip()
        master_port = get_open_port()
        ws_resp = requests.get(f"{base_url}/get_world_size", timeout=30)
        ws_resp.raise_for_status()
        world_size = ws_resp.json()["world_size"] + 1  # +1 for trainer rank 0 always.
        init_success = True

        def start_init_weight():
            nonlocal init_success
            try:
                init_resp = requests.post(
                    f"{base_url}/init_weight_transfer_engine",
                    json={
                        "init_info": dict(
                            master_address=master_address,
                            master_port=master_port,
                            rank_offset=1,
                            world_size=world_size,
                        )
                    },
                )
                init_resp.raise_for_status()
            except Exception as e:
                print(f"[Sync] Failed to initialize weight transfer engine on vLLM server: {e}")
                init_success = False

        # NOTE: API call in separate thread
        t = threading.Thread(target=start_init_weight, daemon=True)
        t.start()
        model_update_group = NCCLWeightTransferEngine.trainer_init(
            dict(
                master_address=master_address,
                master_port=master_port,
                world_size=world_size,
            ),
        )
        # Wait for the server's answer; without this, a server-side failure
        # would go unnoticed and the first weight sync would fail obscurely.
        t.join()
        if not init_success:
            raise RuntimeError(
                "vLLM server failed to initialize the weight transfer engine"
            )
        print("Weight transfer engine initialized successfully.")
        return model_update_group
    except Exception as e:
        print(f"Failed to initialize weight transfer engine: {e}")
        server_process.terminate()
        raise
def send_trainer_weight(
    model_update_group: PyNcclCommunicator, model: torch.nn.Module, vllm_port: int
):
    """Push the trainer's current parameters into the running vLLM server.

    Steps:

    1. Pause generation (``/pause?mode=wait``).
    2. Announce the parameter names, dtypes and shapes via
       ``/update_weights`` from a helper thread.
    3. Stream the tensors over the NCCL group.
    4. Call ``/resume``.

    Resume runs in a ``finally`` block, so a failed transfer does not leave
    the server paused indefinitely.

    Returns:
        True if the server accepted ``/update_weights``, otherwise False.
        Exceptions raised by the NCCL send propagate after resume has been
        attempted.
    """
    base_url = f"http://localhost:{vllm_port}"
    pause_resp = requests.post(f"{base_url}/pause", params={"mode": "wait"})
    if pause_resp.status_code != 200:
        print(f"[Sync] Warning: pause failed: {pause_resp.text}")
    try:
        names = []
        dtype_names = []
        shapes = []
        for name, p in model.named_parameters():
            names.append(name)
            # e.g. torch.bfloat16 -> "bfloat16"
            dtype_names.append(str(p.dtype).split(".")[-1])
            shapes.append(list(p.shape))
        sync_success = True

        def trigger_update_weights():
            nonlocal sync_success
            try:
                resp = requests.post(
                    f"{base_url}/update_weights",
                    json={
                        "update_info": dict(
                            names=names,
                            dtype_names=dtype_names,
                            shapes=shapes,
                            packed=True,
                        )
                    },
                )
                resp.raise_for_status()
            except Exception as e:
                print(f"[Sync] Failed to update weights on vLLM server: {e}")
                sync_success = False

        # NOTE: vllm could deadlock when it received `/update_weights` request.
        # The request is therefore sent from a helper thread while this thread
        # streams the tensors. The thread is a daemon so a stuck request
        # cannot block interpreter exit.
        t = threading.Thread(target=trigger_update_weights, daemon=True)
        t.start()
        trainer_args = NCCLTrainerSendWeightsArgs(
            group=model_update_group,
            packed=True,
        )
        NCCLWeightTransferEngine.trainer_send_weights(
            iterator=model.named_parameters(),
            trainer_args=trainer_args,
        )
        t.join()
    finally:
        # Always try to resume, even if the transfer raised. A failing resume
        # is only logged here so it cannot mask the original exception.
        try:
            resume_resp = requests.post(f"{base_url}/resume")
        except requests.RequestException as e:
            print(f"[Sync] Warning: resume request failed: {e}")
        else:
            if resume_resp.status_code != 200:
                print(f"[Sync] Warning: resume failed: {resume_resp.text}")
    return sync_success
def main():
    """Run async GRPO: vLLM generates on GPU 1 while this process trains.

    The trainer model lives on ``config.train_device``. Each step:

    1. Pull ``config.global_batch_size`` sufficiently fresh rollouts.
    2. Take one optimizer step.
    3. Push the new weights to vLLM over NCCL; on success, increment the
       model version.

    The vLLM server is terminated on every exit path, including startup
    failures.
    """
    config = TrainingConfig()
    print("Starting vLLM server process...")
    server_process = run_vllm_server(
        config.model_name_or_path, config.vllm_port, gpu_id=1
    )
    dataloader = None
    try:
        # These run inside the try so the finally below also tears the server
        # down if it never becomes healthy or NCCL setup fails.
        wait_vllm_server(server_process, config.vllm_port)
        model_update_group = init_weight_transfer(
            server_process, config.vllm_port, config.train_device
        )
        dataset = load_dataset("open-r1/OpenR1-Math-220k", split="train[:100]")
        tokenizer = AutoTokenizer.from_pretrained(config.model_name_or_path)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        dataloader = AsyncGRPODataLoader(config, tokenizer, dataset)
        dataloader.start()
        print(f"Loading training model on {config.train_device}...")
        model = AutoModelForCausalLM.from_pretrained(config.model_name_or_path)
        # The trainable copy lives on train_device; vLLM holds its own copy on GPU 1.
        model = model.to(device=config.train_device)  # type: ignore
        model.gradient_checkpointing_enable()
        optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
        print("Starting training loop...")
        model.train()
        current_model_version = 0
        for step in range(config.num_steps):
            t_fetch = time.time()
            print(
                f"\n[Fetch] Step {step + 1}: Waiting for {config.global_batch_size} rollout samples (queue size={ROLLOUT_QUEUE.qsize()})..."
            )
            samples = get_rollout_samples(
                config.global_batch_size,
                current_model_version=current_model_version,
                max_staleness=config.max_staleness,
            )
            fetch_time = time.time() - t_fetch
            if not samples:
                print("[Fetch] No samples received after timeout. Stopping.")
                break
            print(
                f"[Fetch] Got {len(samples)} samples in {fetch_time:.2f}s | "
                f"avg completion_len={sum(len(s.completion_ids) for s in samples) / len(samples):.0f} | "
                f"queue remaining={ROLLOUT_QUEUE.qsize()}"
            )
            batch = build_batch(
                samples, torch.device(config.train_device), tokenizer.pad_token_id
            )
            print(f"\n{'=' * 60}")
            print(
                f"[Train] Step {step + 1}/{config.num_steps} | Model Version: {current_model_version}"
            )
            print(f"[Train] Batch shape: {batch['prompt_completion_ids'].shape}")
            print(f"{'=' * 60}")
            loss = compute_loss(batch, model)
            optimizer.zero_grad()
            print("[Backward] Starting backward pass...")
            t_bwd = time.time()
            loss.backward()
            bwd_time = time.time() - t_bwd
            print(f"[Backward] Backward pass took {bwd_time:.2f}s")
            # Log gradient norms before clipping
            total_norm = get_total_grad(model)
            print(f"[Backward] Grad norm (pre-clip): {total_norm:.4f}")
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            print(
                f"[Train] Step {step + 1} DONE | "
                f"Loss: {loss.item():.6f} | "
                f"Grad norm (pre-clip): {total_norm:.4f} | "
                f"Backward: {bwd_time:.2f}s"
            )
            # Sync weights with vllm using NCCL
            print(
                f"[Sync] Syncing weights to vLLM engine (upgrading to version {current_model_version + 1})..."
            )
            sync_success = send_trainer_weight(
                model_update_group, model, config.vllm_port
            )
            if sync_success:
                print("[Sync] Weights synchronized successfully.")
                current_model_version += 1
                dataloader.update_model_version(current_model_version)
            else:
                print(
                    f"[Sync] Weight sync failed; model version stays at {current_model_version}."
                )
    except KeyboardInterrupt:
        print("\nReceived keyboard interrupt...")
    except Exception as e:
        print(f"\nError: {e}")
        import traceback

        traceback.print_exc()
    finally:
        if dataloader is not None:
            # The generation loop polls this flag. Its thread is a daemon, so
            # shutdown does not wait for in-flight requests.
            dataloader.running = False
        print("Terminating vLLM server...")
        if server_process.poll() is None:
            server_process.terminate()
            try:
                server_process.wait(timeout=10)
            except subprocess.TimeoutExpired:
                server_process.kill()
                server_process.wait()
        # Release the parent's handle on the server log, if one is attached.
        log_file = getattr(server_process, "_log_file", None)
        if log_file is not None:
            log_file.close()
# Script entry point: run this file directly (e.g. through its `uv run`
# shebang); importing the module does not start training.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment