ReCall rollout with LangGraph's ToolNode as the tool-execution engine; upstream reference: https://github.com/Agent-RL/ReCall/tree/3d976d26ade4950bc491335bb80da1659424b3cb
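
# Overview: during rollout the policy emits <tool_call>{"name": ..., "arguments": {...}}</tool_call>
# blocks; these are parsed, executed as a batch by LangGraph's ToolNode, and the results are
# appended back to the prompt as <tool_response>...</tool_response> turns before generation
# resumes, for up to config.max_turns rounds. Tokens coming from tool results get 0 in loss_mask
# so the trainer can exclude them from the loss.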
import asyncio
import json
import os
import re
import typing
import uuid
from contextlib import contextmanager
from typing import Any

import numpy as np
import torch
import torch.distributed
from langchain_core.messages import AIMessage, ToolCall
from langchain_core.tools import BaseTool
from langgraph.prebuilt import ToolNode
from omegaconf import DictConfig
from tensordict import TensorDict
from verl import DataProto
from verl.third_party.vllm import vllm_version
from verl.utils.torch_functional import pad_2d_list_to_length, pad_sequence_to_length
from verl.workers.rollout.base import BaseRollout
from vllm import LLM, SamplingParams
from vllm.distributed import parallel_state as vllm_ps

T = typing.TypeVar("T")


def _pre_process_inputs(pad_token_id, prompt_token_ids: torch.Tensor) -> list[int]:
    non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0]
    token_ids = prompt_token_ids[non_pad_index:].tolist()
    return token_ids
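

# Worked example (illustrative): with pad_token_id = 0, the left padding is stripped and
# only the real prompt tokens are kept.
#
#   >>> _pre_process_inputs(0, torch.tensor([0, 0, 0, 5, 7, 9]))
#   [5, 7, 9]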


def _repeat_interleave(value: torch.Tensor | np.ndarray, repeats: int) -> torch.Tensor | list[Any]:
    if isinstance(value, torch.Tensor):
        return value.repeat_interleave(repeats, dim=0)
    else:
        return np.repeat(value, repeats, axis=0)


def get_response_mask(response_id: torch.Tensor, eos_token: int | list[int] = 2, dtype=torch.int64):
    """
    end of sentence token can be int or list: 1 or [1, 2]
    e.g.
    response_id = torch.tensor([[20, 10, 34, 1, 0, 0, 0],
                                [78, 0, 76, 2, 1, 0, 0],
                                [23, 98, 1, 0, 0, 0, 0],
                                [33, 3, 98, 45, 1, 0, 0]])
    # eos_token=1
    response_mask: tensor([[1, 1, 1, 1, 0, 0, 0],
                           [1, 1, 1, 1, 1, 0, 0],
                           [1, 1, 1, 0, 0, 0, 0],
                           [1, 1, 1, 1, 1, 0, 0]])
    # eos_token=[1, 2]
    response_mask: tensor([[1, 1, 1, 1, 0, 0, 0],
                           [1, 1, 1, 1, 0, 0, 0],
                           [1, 1, 1, 0, 0, 0, 0],
                           [1, 1, 1, 1, 1, 0, 0]])
    """
    eos_mask = torch.isin(response_id, torch.tensor(eos_token, device=response_id.device)).int()
    return (eos_mask.cumsum(dim=1) - eos_mask).eq(0).to(dtype)


def deserialize_tool_call(string: str) -> ToolCall:
    message = json.loads(string)
    assert isinstance(message, dict)
    request = {
        'name': message['name'],
        'args': message.get('arguments', {}),
        'id': str(uuid.uuid4()),
        'type': 'tool_call',
    }
    return request
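

# Example (illustrative): a serialized call extracted from a <tool_call> block becomes a
# LangChain-style ToolCall dict; the id is a fresh uuid4, so it differs on every call.
#
#   >>> deserialize_tool_call('{"name": "search", "arguments": {"query": "verl"}}')
#   {'name': 'search', 'args': {'query': 'verl'}, 'id': '...', 'type': 'tool_call'}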


def validate_tool_calls(string: str) -> bool:
    balance = 0
    for match in re.finditer(r'</?tool_call>', string):
        match = match.group()
        if match == '<tool_call>':
            if balance > 0:
                return False
            balance += 1
        else:
            if balance < 1:
                return False
            balance -= 1
    return balance == 0
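

# Examples (illustrative): tags must come in non-nested, properly ordered pairs.
#
#   >>> validate_tool_calls('<tool_call>{"name": "a"}</tool_call>')
#   True
#   >>> validate_tool_calls('<tool_call><tool_call>{"name": "a"}</tool_call>')
#   False
#   >>> validate_tool_calls('</tool_call><tool_call>')
#   False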


def search_tool_calls(string: str) -> list[str]:
    if not validate_tool_calls(string):
        return []
    try:
        pattern = r'<tool_call>((?:(?!</tool_call>).)*)</tool_call>'
        matches = re.finditer(pattern, string, re.DOTALL)
        return [match.group(1).strip() for match in matches]
    except Exception:
        return []
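

# Example (illustrative): extracts the payload of every well-formed <tool_call> block; if
# the tags are unbalanced, the response is treated as containing no tool calls at all.
#
#   >>> search_tool_calls('<think>...</think><tool_call>{"name": "search", "arguments": {"query": "x"}}</tool_call>')
#   ['{"name": "search", "arguments": {"query": "x"}}']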


async def execute_tool_calls(tool_runner: ToolNode, B_tool_calls: list[list[str]]) -> list[list[str]]:
    """executes a batch of tool calls in parallel using the tool runner."""
    B_tool_responses = [[""] * len(_) for _ in B_tool_calls]

    scheduling = []
    tool_calls = []
    for i, strings in enumerate(B_tool_calls):
        for j, string in enumerate(strings):
            try:
                tool_calls.append(deserialize_tool_call(string))
                scheduling.append((i, j))
            except Exception:
                B_tool_responses[i][j] = json.dumps({
                    "status": "error",
                    "content": "tool call must be a JSON object with 'name' and (optional) 'arguments' fields",
                })

    message = AIMessage(content="", tool_calls=tool_calls)
    tool_responses = await tool_runner.ainvoke([message])

    for (i, j), tool_message in zip(scheduling, tool_responses):
        status, content = tool_message.status, tool_message.content
        if status == "error":
            content = content.replace("Error: ", "")
        content = content.strip()
        B_tool_responses[i][j] = json.dumps({"status": status, "content": content})

    return B_tool_responses
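

# Note: as used here, the ToolNode is invoked with a single AIMessage whose tool_calls field
# lists every deserialized call; LangGraph is expected to return one ToolMessage per call, in
# the same order, so the (i, j) scheduling above maps each ToolMessage back to its rollout and
# call slot. Calls that fail to deserialize never reach the ToolNode and keep their error JSON.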


def run_coroutine_sync(coroutine: typing.Awaitable[T]) -> T:
    try:
        eventloop = asyncio.get_running_loop()
    except RuntimeError:
        return asyncio.run(coroutine)
    else:
        if eventloop.is_running():
            future = asyncio.run_coroutine_threadsafe(coroutine, eventloop)
            return future.result()
        else:
            return eventloop.run_until_complete(coroutine)
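

# Dispatch summary: no running event loop -> asyncio.run; an already-running loop -> schedule
# on it with run_coroutine_threadsafe (which assumes that loop runs in another thread, since
# future.result() blocks the caller); a non-running loop object -> run_until_complete. Used
# below so the synchronous rollout loop can call the async ToolNode execution.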


class vLLMRollout(BaseRollout):

    def __init__(self, model_path: str, config: DictConfig, tokenizer, model_hf_config, **kwargs):
        """A vLLM rollout. It requires that the model is supported by vLLM.

        Args:
            model_path: path of the model, following huggingface conventions
            config: DictConfig
            tokenizer: the task/model tokenizer
            model_hf_config: the huggingface config used to initialize the generating model in vLLM
            **kwargs: train_tp, for the Megatron backend to initialize the hybrid engine (zero redundancy) process group
        """
        super().__init__()
        self.config = config
        assert not (not config.enforce_eager and config.free_cache_engine), \
            "disable CUDA graph (enforce_eager = False) if free cache engine"

        tensor_parallel_size = self.config.get('tensor_model_parallel_size', 1)
        assert tensor_parallel_size <= torch.distributed.get_world_size(), \
            "tensor parallel size should be less than or equal to the world size"
        max_num_batched_tokens = self.config.get('max_num_batched_tokens', 8192)

        if kwargs.get('train_tp', None) is not None:
            # deployed with megatron
            os.environ['CUDA_TIMER_STREAM_KAFKA_ENABLE'] = '0'
            os.environ['MEGATRON_IMPORT_TIMERS'] = '0'
            if vllm_version in ('0.3.1', '0.4.2', '0.5.4', '0.6.3'):
                train_tp = kwargs.get('train_tp', None)
                num_tp_per_train_tp = train_tp // tensor_parallel_size
                vllm_ps.initialize_parallel_state(tensor_model_parallel_size=tensor_parallel_size,
                                                  num_tp_per_train_tp=num_tp_per_train_tp)
            else:
                vllm_ps.initialize_model_parallel(tensor_model_parallel_size=tensor_parallel_size)

        assert model_hf_config.max_position_embeddings >= config.prompt_length + config.response_length, \
            "model context length should be greater than total sequence length"

        max_model_len = self.config.max_model_len if self.config.max_model_len \
            else config.prompt_length + config.response_length
        max_model_len = int(max_model_len)

        if max_num_batched_tokens < max_model_len and self.config.enable_chunked_prefill:
            raise ValueError('Chunked prefill is enabled but max_num_batched_tokens is smaller than max_model_len; '
                             'please increase max_num_batched_tokens or disable chunked prefill')

        trust_remote_code = kwargs.get('trust_remote_code', False)
        load_format = 'dummy' if config.load_format.startswith('dummy') else config.load_format

        self.inference_engine = LLM(
            model=model_path,
            enable_sleep_mode=True,
            tensor_parallel_size=tensor_parallel_size,
            distributed_executor_backend="external_launcher",
            dtype=config.dtype,
            enforce_eager=config.enforce_eager,
            gpu_memory_utilization=config.gpu_memory_utilization,
            disable_custom_all_reduce=True,
            disable_mm_preprocessor_cache=True,
            skip_tokenizer_init=False,
            max_model_len=max_model_len,
            load_format=load_format,
            disable_log_stats=config.disable_log_stats,
            max_num_batched_tokens=max_num_batched_tokens,
            enable_chunked_prefill=config.enable_chunked_prefill,
            enable_prefix_caching=True,
            trust_remote_code=trust_remote_code,
            seed=int(os.getenv("RANK", "0")) // tensor_parallel_size,
        )

        # offload the vLLM model to reduce peak memory usage
        self.inference_engine.sleep(level=1)

        kwargs = dict(
            n=1,
            logprobs=0,  # can be set to 0 and let the actor recompute
            max_tokens=config.response_length,
        )

        # we may detokenize the result all together later
        if vllm_version != '0.3.1':
            kwargs['detokenize'] = False

        # support adding any sampling params from the config file
        for k in config.keys():
            if hasattr(SamplingParams(), str(k)):
                kwargs[k] = config.get(k)

        print(f"kwargs: {kwargs}")
        self.sampling_params = SamplingParams(**kwargs)

        self.pad_token_id = tokenizer.pad_token_id

    @contextmanager
    def update_sampling_params(self, **kwargs):
        # update sampling params
        old_sampling_params_args = {}
        if kwargs:
            for key, value in kwargs.items():
                if hasattr(self.sampling_params, key):
                    old_value = getattr(self.sampling_params, key)
                    old_sampling_params_args[key] = old_value
                    setattr(self.sampling_params, key, value)
        yield
        # roll back to previous sampling params
        # if len(old_sampling_params_args):
        for key, value in old_sampling_params_args.items():
            setattr(self.sampling_params, key, value)
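
    # Usage sketch (illustrative): temporarily override sampling params for one generate
    # call, then restore the previous values when the context exits.
    #
    #   with self.update_sampling_params(n=1, temperature=0.0):
    #       outputs = self.inference_engine.generate(...)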

    @torch.no_grad()
    def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
        # rebuild vllm cache engine
        if vllm_version in ('0.3.1', '0.4.2', '0.5.4', '0.6.3') and self.config.free_cache_engine:
            self.inference_engine.init_cache_engine()

        idx = prompts.batch['input_ids']  # (bs, prompt_length)
        # left-padded attention_mask
        attention_mask = prompts.batch['attention_mask']
        position_ids = prompts.batch['position_ids']

        # used to construct attention_mask
        eos_token_id = prompts.meta_info['eos_token_id']

        batch_size = idx.size(0)

        non_tensor_batch = prompts.non_tensor_batch
        if 'raw_prompt_ids' not in non_tensor_batch:
            non_tensor_batch['raw_prompt_ids'] = np.array(
                [_pre_process_inputs(self.pad_token_id, idx[i]) for i in range(batch_size)], dtype=object)

        if batch_size != len(non_tensor_batch['raw_prompt_ids']):
            raise RuntimeError('vllm sharding manager is not working properly.')

        if 'multi_modal_data' in non_tensor_batch:
            vllm_inputs = []
            for raw_prompt_ids, multi_modal_data in zip(non_tensor_batch.pop('raw_prompt_ids'),
                                                        non_tensor_batch.pop('multi_modal_data')):
                vllm_inputs.append({'prompt_token_ids': raw_prompt_ids, 'multi_modal_data': multi_modal_data})
        else:
            vllm_inputs = [{
                'prompt_token_ids': raw_prompt_ids
            } for raw_prompt_ids in non_tensor_batch.pop('raw_prompt_ids')]

        # ensure the type of `prompt_token_ids` passed to vllm is list[int]
        # https://github.com/volcengine/verl/pull/772
        for input_data in vllm_inputs:
            if isinstance(input_data['prompt_token_ids'], np.ndarray):
                input_data['prompt_token_ids'] = input_data['prompt_token_ids'].tolist()
            elif not isinstance(input_data['prompt_token_ids'], list):
                raise TypeError(
                    f"prompt_token_ids must be a list or numpy array, got {type(input_data['prompt_token_ids'])}")

        do_sample = prompts.meta_info.get('do_sample', True)
        is_validate = prompts.meta_info.get('validate', False)
        if not do_sample:
            kwargs = {
                'best_of': 1,
                'top_p': 1.0,
                'top_k': -1,
                'min_p': 0.0,
                'temperature': 0,
                'n': 1,  # if greedy, only 1 response
            }
        elif is_validate:
            # TODO: try **
            kwargs = {
                'top_k': self.config.val_kwargs.top_k,
                'top_p': self.config.val_kwargs.top_p,
                'temperature': self.config.val_kwargs.temperature,
                'n': 1,  # if validate, already repeated in ray_trainer
            }

        # users can customize different sampling_params at different runs
        with self.update_sampling_params(**kwargs):
            outputs = self.inference_engine.generate(
                prompts=vllm_inputs,  # because we have already converted it to prompt token ids
                sampling_params=self.sampling_params,
                use_tqdm=False)

            # TODO(sgm): disable logprob when recompute_log_prob is enabled
            # if n = 1: (bs, response_length); if n > 1: (bs * n, response_length)
            response = []
            for output in outputs:
                for sample_id in range(len(output.outputs)):
                    response.append(output.outputs[sample_id].token_ids)

            response = pad_2d_list_to_length(response, self.pad_token_id,
                                             max_length=self.config.response_length).to(idx.device)

            if self.sampling_params.n > 1 and do_sample:
                idx = _repeat_interleave(idx, self.sampling_params.n)
                attention_mask = _repeat_interleave(attention_mask, self.sampling_params.n)
                position_ids = _repeat_interleave(position_ids, self.sampling_params.n)
                batch_size = batch_size * self.sampling_params.n
                if 'multi_modal_inputs' in non_tensor_batch.keys():
                    non_tensor_batch['multi_modal_inputs'] = _repeat_interleave(non_tensor_batch['multi_modal_inputs'],
                                                                                self.sampling_params.n)

        seq = torch.cat([idx, response], dim=-1)

        response_length = response.size(1)
        delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device)
        delta_position_id = delta_position_id.unsqueeze(0).expand(batch_size, -1)
        if position_ids.dim() == 3:  # qwen2vl mrope
            delta_position_id = delta_position_id.view(batch_size, 1, -1).expand(batch_size, 3, -1)

        # TODO(sgm): fix position_ids on right_pad
        # prompt: left pad + response: right pad
        # attention_mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0]
        # position_ids:   [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11]
        response_position_ids = position_ids[:, -1:] + delta_position_id
        position_ids = torch.cat([position_ids, response_position_ids], dim=-1)
        response_attention_mask = get_response_mask(response_id=response,
                                                    eos_token=eos_token_id,
                                                    dtype=attention_mask.dtype)
        attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1)

        # all the tp ranks should contain the same data here; data in all ranks are valid
        batch = TensorDict(
            {
                'prompts': idx,
                'responses': response,
                'input_ids': seq,  # here input_ids become the whole sentences
                # 'old_log_probs': log_probs,  # we will recompute old log prob with actor
                'attention_mask': attention_mask,
                'position_ids': position_ids,
            },
            batch_size=batch_size)

        # free vllm cache engine
        if vllm_version in ('0.3.1', '0.4.2', '0.5.4', '0.6.3') and self.config.free_cache_engine:
            self.inference_engine.free_cache_engine()

        return DataProto(batch=batch, non_tensor_batch=non_tensor_batch)


class vLLMRolloutWithTool(vLLMRollout):

    def __init__(
        self,
        model_path: str,
        config: DictConfig,
        tokenizer,
        model_hf_config,
        toolkit: list[BaseTool | typing.Callable[..., Any]],
        **kwargs
    ):
        super().__init__(model_path, config, tokenizer, model_hf_config, **kwargs)
        self.tokenizer = tokenizer
        self.tp_rank = vllm_ps.get_tensor_model_parallel_rank()
        self.gen_str = "\n<|im_start|>assistant\n<think>"
        self.gen_ids = self.tokenizer.encode(self.gen_str)
        self.tool_runner = ToolNode(toolkit)
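
    # Construction sketch (hypothetical tool; any LangChain BaseTool or plain callable that
    # ToolNode accepts can go in `toolkit`):
    #
    #   from langchain_core.tools import tool
    #
    #   @tool
    #   def search(query: str) -> str:
    #       """Hypothetical web search."""
    #       ...
    #
    #   rollout = vLLMRolloutWithTool(model_path, config, tokenizer, model_hf_config, toolkit=[search])
    #
    # Note: gen_str and the stop id 151644 used below assume a Qwen-style chat template,
    # where 151644 is the <|im_start|> special token.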

    @torch.no_grad()
    def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
        # rebuild vllm cache engine
        if vllm_version in ('0.3.1', '0.4.2', '0.5.4', '0.6.3') and self.config.free_cache_engine:
            self.inference_engine.init_cache_engine()

        ori_input_ids = prompts.batch['input_ids']  # (bs, prompt_length)

        # left-padded attention_mask
        attention_mask = prompts.batch['attention_mask']
        position_ids = prompts.batch['position_ids']

        # # used to construct attention_mask
        # eos_token_id = prompts.meta_info['eos_token_id']

        batch_size = ori_input_ids.size(0)

        idx_list = []
        # parse idx from torch.Tensor to list[list[int]]
        for i in range(batch_size):
            idx_list.append(_pre_process_inputs(self.pad_token_id, ori_input_ids[i]))

        do_sample = prompts.meta_info.get('do_sample', True)
        is_validate = prompts.meta_info.get('validate', False)
        if not do_sample:
            kwargs = {
                'best_of': 1,
                'top_p': 1.0,
                'top_k': -1,
                'min_p': 0.0,
                'temperature': 0,
                'n': 1,  # if greedy, only 1 response
            }
        elif is_validate:
            # TODO: try **
            kwargs = {
                'top_k': self.config.val_kwargs.top_k,
                'top_p': self.config.val_kwargs.top_p,
                'temperature': self.config.val_kwargs.temperature,
                'n': 1,  # if validate, already repeated in ray_trainer
            }

        with self.update_sampling_params(**kwargs):
            # prepare n copies of each input
            curr_inputs = []
            for input_ids in idx_list:
                for _ in range(self.sampling_params.n):
                    curr_inputs.append(input_ids.copy())
            init_inputs = [ids.copy() for ids in curr_inputs]

            # track the status of each input
            curr_max_tokens = [self.sampling_params.max_tokens] * len(curr_inputs)
            active_indices = list(range(len(curr_inputs)))

            # collect the result mask of each rollout, 1 for non-result, 0 for tool call result or pad
            result_mask_list = [[] for _ in range(len(curr_inputs))]

            # generate until all inputs are completed
            for step in range(self.config.max_turns):
                if len(active_indices) == 0:
                    break

                # only process the active inputs
                active_inputs = [curr_inputs[i] for i in active_indices]
                active_max_tokens = [curr_max_tokens[i] for i in active_indices]

                with self.update_sampling_params(
                    n=1,
                    max_tokens=min(512, max(active_max_tokens)),
                    stop_token_ids=[151644],
                    top_p=0.99,
                ):  # 512 at most, and add <|im_start|> as a stop for corner cases
                    vllm_inputs = [{
                        'prompt_token_ids': raw_prompt_ids
                    } for raw_prompt_ids in active_inputs]
                    outputs = self.inference_engine.generate(
                        prompts=vllm_inputs,
                        sampling_params=self.sampling_params,
                        use_tqdm=False
                    )

                # collect all tool calls
                tool_calls_list: list[list[str]] = []
                call_indices: list[int] = []

                # process each output
                new_active_indices = []
                for i, idx in enumerate(active_indices):
                    output_ids = outputs[i].outputs[0].token_ids
                    finish_reason = outputs[i].outputs[0].finish_reason
                    stop_reason = outputs[i].outputs[0].stop_reason
                    if finish_reason == 'stop' and (stop_reason is None or stop_reason == self.tokenizer.pad_token_id):
                        curr_inputs[idx] += output_ids
                        result_mask_list[idx] += [1] * len(output_ids)
                        output_str = self.tokenizer.decode(output_ids)
                        tool_calls = search_tool_calls(output_str)
                        if tool_calls:
                            tool_calls_list.append(tool_calls)
                            call_indices.append(idx)
                            new_active_indices.append(idx)
                        else:
                            pass  # no tool calls
                    elif finish_reason == 'length':
                        # output over max tokens
                        curr_inputs[idx] += output_ids
                        result_mask_list[idx] += [1] * len(output_ids)
                    elif finish_reason == 'stop' and stop_reason == 151644:
                        # 151644 is the id of <|im_start|>, an illegal stop; we end this rollout here
                        curr_inputs[idx] += output_ids
                        result_mask_list[idx] += [1] * len(output_ids)
                    else:
                        raise ValueError(f"unknown stop reason. finish_reason: {finish_reason}, stop_reason: {stop_reason}")

                # batch process tool calls
                if tool_calls_list:
                    # only tp_rank 0 executes the tools
                    if self.tp_rank == 0:
                        tool_responses_list = run_coroutine_sync(
                            execute_tool_calls(self.tool_runner, tool_calls_list)
                        )
                        # prepare data for broadcasting
                        broadcast_data = {
                            'tool_calls_list': tool_calls_list,
                            'call_indices': call_indices,
                            'tool_responses_list': tool_responses_list,
                        }
                    else:
                        broadcast_data = None
                    broadcast_data = vllm_ps._TP.broadcast_object(broadcast_data, src=0)

                    # all ranks process the broadcasted data
                    if broadcast_data is not None:
                        tool_calls_list = broadcast_data['tool_calls_list']
                        call_indices = broadcast_data['call_indices']
                        tool_responses_list = broadcast_data['tool_responses_list']

                        for idx, tool_calls, tool_responses in zip(call_indices, tool_calls_list, tool_responses_list):
                            tool_response_str = ''
                            for call, response in zip(tool_calls, tool_responses):
                                tool_response_str += f"<tool_response>{call}\n{response}\n</tool_response>\n"
                            tool_response_str = "\n<|im_start|>user\n" + tool_response_str + "<|im_end|>"
                            output_ids = self.tokenizer.encode(tool_response_str)
                            curr_inputs[idx] += output_ids
                            result_mask_list[idx] += [0] * len(output_ids)
                            curr_inputs[idx] += self.gen_ids
                            result_mask_list[idx] += [0] * len(self.gen_ids)

                # check if we need to truncate; if yes, truncate and remove from active; if no, update curr_max_tokens
                length_checked_active_indices = []
                for idx in active_indices:
                    assert len(curr_inputs[idx]) - len(init_inputs[idx]) == len(result_mask_list[idx]), \
                        f"curr_inputs: {len(curr_inputs[idx])}, init_inputs: {len(init_inputs[idx])}, result_mask_list: {len(result_mask_list[idx])}"
                    if len(curr_inputs[idx]) - len(init_inputs[idx]) >= self.config.response_length:
                        curr_inputs[idx] = init_inputs[idx] \
                            + curr_inputs[idx][len(init_inputs[idx]):len(init_inputs[idx]) + self.config.response_length]
                        result_mask_list[idx] = result_mask_list[idx][:self.config.response_length]
                    else:
                        curr_max_tokens[idx] = self.config.response_length - len(curr_inputs[idx]) + len(init_inputs[idx])
                        if idx in new_active_indices:
                            length_checked_active_indices.append(idx)
                active_indices = length_checked_active_indices

            output_ids_list = []
            # collect all the rollouts
            for i, input_ids in enumerate(idx_list):
                for j in range(self.sampling_params.n):
                    idx = i * self.sampling_params.n + j
                    input_len = len(input_ids)
                    output_ids_list.append(curr_inputs[idx][input_len:])

        response_attention_mask_list = []
        response_list = []
        result_mask_list_padded = []
        for output_ids, result_mask in zip(output_ids_list, result_mask_list):
            assert len(output_ids) == len(result_mask), \
                f"output_ids: {len(output_ids)}, result_mask: {len(result_mask)}"

            # to tensor
            response = torch.tensor(output_ids, device=ori_input_ids.device)
            result_mask = torch.tensor(result_mask, device=ori_input_ids.device)

            # response attention mask, 1 for valid, 0 for invalid
            response_attention_mask = torch.ones_like(response, dtype=torch.int64)
            response_attention_mask = pad_sequence_to_length(response_attention_mask, self.config.response_length, 0)
            response_attention_mask_list.append(response_attention_mask)

            # response, padded to response_length
            response = pad_sequence_to_length(response, self.config.response_length, self.pad_token_id)
            response_list.append(response)

            # result mask, 1 for non-result, 0 for result or pad
            result_mask = pad_sequence_to_length(result_mask, self.config.response_length, 0)
            result_mask_list_padded.append(result_mask)

        response_attention_mask = torch.stack(response_attention_mask_list, dim=0)
        response = torch.stack(response_list, dim=0)
        result_mask = torch.stack(result_mask_list_padded, dim=0)

        if self.config.n > 1 and do_sample:
            ori_input_ids = ori_input_ids.repeat_interleave(self.config.n, dim=0)
            attention_mask = attention_mask.repeat_interleave(self.config.n, dim=0)
            position_ids = position_ids.repeat_interleave(self.config.n, dim=0)
            batch_size = batch_size * self.config.n

        seq = torch.cat([ori_input_ids, response], dim=-1)

        response_length = response.size(1)
        delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device)
        delta_position_id = delta_position_id.unsqueeze(0).repeat(batch_size, 1)

        # TODO(sgm): fix position_ids on right_pad
        # prompt: left pad + response: right pad
        # attention_mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0]
        # position_ids:   [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11]
        response_position_ids = position_ids[:, -1:] + delta_position_id
        position_ids = torch.cat([position_ids, response_position_ids], dim=-1)

        # concat attention_mask for input and response
        attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1)

        # result mask: result part is 0, other parts are 1
        loss_mask = result_mask * response_attention_mask

        # all the tp ranks should contain the same data here; data in all ranks are valid
        batch = TensorDict({
            'prompts': ori_input_ids,
            'responses': response,
            'input_ids': seq,  # here input_ids become the whole sentences
            'attention_mask': attention_mask,
            'loss_mask': loss_mask,
            'position_ids': position_ids,
        }, batch_size=batch_size)

        # free vllm cache engine
        if vllm_version in ('0.3.1', '0.4.2', '0.5.4', '0.6.3') and self.config.free_cache_engine:
            self.inference_engine.free_cache_engine()

        return DataProto(batch=batch)
langgraph>0.4,<1
verl[vllm]==0.3.0.post1