Skip to content

Instantly share code, notes, and snippets.

@AmineDiro
Last active March 29, 2026 07:07
Show Gist options
  • Select an option

  • Save AmineDiro/7bb14e0df24322d8ad4988050c0a24de to your computer and use it in GitHub Desktop.

Select an option

Save AmineDiro/7bb14e0df24322d8ad4988050c0a24de to your computer and use it in GitHub Desktop.
minimal test for reward
"""
Minimal sanity-check for AsyncGRPOTrainer: the "Immediate EOS" test.
The model is rewarded with R(y) = -len(completion_tokens). The optimal policy
is to emit <EOS> immediately (reward = -1). Within a handful of steps the
average completion length should drop and reward_mean should climb toward -1.
Start the vLLM server:
CUDA_VISIBLE_DEVICES=0 VLLM_SERVER_DEV_MODE=1 vllm serve Qwen/Qwen3-0.6B \
--weight-transfer-config '{"backend":"nccl"}' \
--logprobs-mode processed_logprobs \
--max-model-len 512
Run training:
CUDA_VISIBLE_DEVICES=1 python examples/scripts/async_grpo_immediate_eos.py
NOTE: depends on transformers>=5.2.0 and vllm>=0.17.1
"""
import logging
import os
from datasets import Dataset
from trl.experimental.async_grpo import AsyncGRPOConfig, AsyncGRPOTrainer
# Configure the root logger. The level name is taken from the LOG_LEVEL
# environment variable (case-insensitive); when the variable is unset or names
# something that is not an attribute of the logging module, INFO is used.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=getattr(logging, os.getenv("LOG_LEVEL", "INFO").upper(), logging.INFO),
)

# The "trl" logger is always set to DEBUG, independently of the root level.
logging.getLogger("trl").setLevel(logging.DEBUG)
def negative_length_reward(completion_ids: list[list[int]], **kwargs) -> list[float]:
    """Score each completion with the negative of its token count.

    Shorter completions score higher, so the best policy is to stop as early
    as possible: a completion consisting of a single EOS token scores -1.

    Args:
        completion_ids: One list of token ids per completion.
        **kwargs: Any additional keyword arguments; accepted and ignored.

    Returns:
        One reward per completion, in input order, equal to ``-len(ids)``.
    """
    # Convert explicitly so the values match the declared ``list[float]``
    # return type instead of being plain ints.
    return [float(-len(ids)) for ids in completion_ids]
def format_sample(sample: dict[str, str]) -> dict[str, list[dict[str, str]]]:
    """Wrap a raw text row into a single-turn conversational prompt.

    Args:
        sample: A dataset row; its ``"text"`` value becomes the user message.

    Returns:
        A dict with one ``"prompt"`` key holding a one-element message list:
        ``[{"role": "user", "content": <text>}]``.

    Raises:
        KeyError: If ``sample`` has no ``"text"`` key.
    """
    return {"prompt": [{"role": "user", "content": sample["text"]}]}
def main() -> None:
    """Run the "Immediate EOS" sanity check with AsyncGRPOTrainer.

    Builds a 100-row dataset of throwaway prompts, converts each row to a
    conversational prompt with ``format_sample``, and trains
    ``Qwen/Qwen3-0.6B`` using ``negative_length_reward`` as the only reward
    function. Results are written under ``/tmp/results_immediate_eos`` and
    metrics are reported to Weights & Biases.
    """
    # A trivial dataset of short prompts — content doesn't matter, the reward ignores it.
    dataset = Dataset.from_dict(
        {"text": ["Hello", "Test prompt", "What is 1+1?", "Say something", "Hi there"] * 20}
    )
    # Replace the raw "text" column with the "prompt" column built by format_sample.
    dataset = dataset.map(format_sample, remove_columns=dataset.column_names)

    # NOTE(review): the server command in the module docstring uses
    # --max-model-len 512; prompt tokens plus 512 completion tokens would
    # exceed that limit — confirm how generation length is capped.
    config = AsyncGRPOConfig(
        output_dir="/tmp/results_immediate_eos",
        per_device_train_batch_size=4,
        max_completion_length=512,
        num_generations=8,
        max_steps=400,
        logging_steps=1,  # log every step so the reward trend is visible early
        report_to="wandb",
        log_completions=True,
        num_completions_to_print=3,
    )

    trainer = AsyncGRPOTrainer(
        model="Qwen/Qwen3-0.6B",
        args=config,
        train_dataset=dataset,
        reward_funcs=negative_length_reward,
    )
    trainer.train()
# Run the sanity check only when this file is executed as a script.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment