Last active
March 29, 2026 07:07
-
-
Save AmineDiro/7bb14e0df24322d8ad4988050c0a24de to your computer and use it in GitHub Desktop.
minimal test for reward
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
| Minimal sanity-check for AsyncGRPOTrainer: the "Immediate EOS" test. | |
| The model is rewarded with R(y) = -len(completion_tokens). The optimal policy | |
| is to emit <EOS> immediately (reward = -1). Within a handful of steps the | |
| average completion length should drop and reward_mean should climb toward -1. | |
| Start the vLLM server: | |
| CUDA_VISIBLE_DEVICES=0 VLLM_SERVER_DEV_MODE=1 vllm serve Qwen/Qwen3-0.6B \ | |
| --weight-transfer-config '{"backend":"nccl"}' \ | |
| --logprobs-mode processed_logprobs \ | |
| --max-model-len 512 | |
| Run training: | |
| CUDA_VISIBLE_DEVICES=1 python examples/scripts/async_grpo_immediate_eos.py | |
| NOTE: depends on transformers>=5.2.0 and vllm>=0.17.1 | |
| """ | |
| import logging | |
| import os | |
| from datasets import Dataset | |
| from trl.experimental.async_grpo import AsyncGRPOConfig, AsyncGRPOTrainer | |
# The root log level is read from the LOG_LEVEL environment variable
# (case-insensitive, default "INFO"); a name that is not an attribute of the
# logging module falls back to logging.INFO.
_requested_level = os.environ.get("LOG_LEVEL", "INFO").upper()
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=getattr(logging, _requested_level, logging.INFO),
)

# The "trl" logger is set to DEBUG independently of LOG_LEVEL.
logging.getLogger("trl").setLevel(logging.DEBUG)
def negative_length_reward(completion_ids: list[list[int]], **kwargs) -> list[float]:
    """Score each completion by the negative of its token count.

    Shorter completions score higher, so a completion made of a single token
    (e.g. an immediate EOS) gets the best non-empty score of -1.0.

    Args:
        completion_ids: One list of token ids per completion.
        **kwargs: Extra keyword arguments are accepted and ignored.

    Returns:
        One float per completion, equal to ``-len(ids)``; an empty input
        yields an empty list.
    """
    # Convert explicitly so the values match the declared ``list[float]``
    # return type instead of leaking plain ints.
    return [float(-len(ids)) for ids in completion_ids]
def format_sample(sample):
    """Wrap the sample's ``"text"`` field as a single user message under ``"prompt"``."""
    user_turn = {"role": "user", "content": sample["text"]}
    return {"prompt": [user_turn]}
def main() -> None:
    """Build the toy prompt dataset, configure AsyncGRPOTrainer and run training."""
    # Five short prompts repeated 20 times (100 rows). The content doesn't
    # matter: the reward only looks at the completion length.
    base_prompts = ["Hello", "Test prompt", "What is 1+1?", "Say something", "Hi there"]
    raw_dataset = Dataset.from_dict({"text": base_prompts * 20})
    # Replace the raw "text" column with the chat-style "prompt" column.
    train_data = raw_dataset.map(format_sample, remove_columns=raw_dataset.column_names)

    training_args = AsyncGRPOConfig(
        output_dir="/tmp/results_immediate_eos",
        per_device_train_batch_size=4,
        max_completion_length=512,
        num_generations=8,
        max_steps=400,
        logging_steps=1,
        report_to="wandb",
        log_completions=True,
        num_completions_to_print=3,
    )

    AsyncGRPOTrainer(
        model="Qwen/Qwen3-0.6B",
        args=training_args,
        train_dataset=train_data,
        reward_funcs=negative_length_reward,
    ).train()
# Run the training entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment