@bbrowning
Last active May 7, 2025 15:26
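
The patch below changes vLLM's V1 Processor so that SamplingParams.allowed_token_ids is validated against the tokenizer of the active LoRA request, whose vocabulary can extend beyond the base model's, instead of against model_config.get_vocab_size(). The new test builds a fake facebook/opt-125m LoRA adapter whose tokenizer adds the tokens _A_ (id 50300) and _B_ (id 50301) and checks that those ids now pass validation.
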
diff --git a/tests/v1/engine/test_processor.py b/tests/v1/engine/test_processor.py
new file mode 100644
index 000000000..c793290d4
--- /dev/null
+++ b/tests/v1/engine/test_processor.py
@@ -0,0 +1,196 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import json
+import math
+import time
+from typing import Optional
+
+import pytest
+
+from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig,
+                         ModelConfig, VllmConfig)
+from vllm.lora.request import LoRARequest
+from vllm.sampling_params import SamplingParams
+from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
+from vllm.v1.engine.processor import Processor
+
+
+@pytest.fixture
+def opt_125m_huggingface_id():
+    return "facebook/opt-125m"
+
+
+@pytest.fixture
+def fake_lora_adapter_files(tmp_path, opt_125m_huggingface_id):
+    adapter_config = {
+        "base_model_name_or_path": opt_125m_huggingface_id,
+        "bias": "none",
+        "fan_in_fan_out": False,
+        "inference_mode": True,
+        "init_lora_weights": True,
+        "lora_alpha": 16,
+        "lora_dropout": 0.05,
+        "peft_type": "LORA",
+        "r": 8,
+        "target_modules": [
+            "q_proj",
+            "v_proj"
+        ],
+        "task_type": "CAUSAL_LM"
+    }
+    with open(tmp_path / "adapter_config.json", "w") as f:
+        json.dump(adapter_config, f)
+
+    tokenizer_config = {
+        "add_bos_token": True,
+        "add_eos_token": False,
+        "added_tokens_decoder": {
+            "50300": {
+                "content": "_A_",
+                "lstrip": False,
+                "normalized": True,
+                "rstrip": False,
+                "single_word": False,
+                "special": False
+            },
+            "50301": {
+                "content": "_B_",
+                "lstrip": False,
+                "normalized": True,
+                "rstrip": False,
+                "single_word": False,
+                "special": False
+            },
+        },
+        "additional_special_tokens": [],
+        "bos_token": "<s>",
+        "chat_template": "{{ bos_token }}{% for message in messages %}{% endfor %}",
+        "clean_up_tokenization_spaces": False,
+        "eos_token": "</s>",
+        "legacy": True,
+        "model_max_length": 1000000000000000019884624838656,
+        "pad_token": "</s>",
+        "sp_model_kwargs": {},
+        "spaces_between_special_tokens": False,
+        "tokenizer_class": "LlamaTokenizer",
+        "unk_token": "<unk>",
+        "use_default_system_prompt": False
+    }
+    with open(tmp_path / "tokenizer_config.json", "w") as f:
+        json.dump(tokenizer_config, f)
+
+    base_vocab = {f"tok_{i}": i for i in range(3, 50300)}
+    added_vocab = {
+        "<unk>": 0,
+        "<s>": 1,
+        "</s>": 2,
+        "_A_": 50300,
+        "_B_": 50301,
+    }
+    vocab = {**base_vocab, **added_vocab}
+    tokenizer = {
+        "version": "1.0",
+        "added_tokens": [
+            {
+                "id": 0,
+                "content": "<unk>",
+                "single_word": False,
+                "lstrip": False,
+                "rstrip": False,
+                "normalized": False,
+                "special": True
+            },
+            {
+                "id": 1,
+                "content": "<s>",
+                "single_word": False,
+                "lstrip": False,
+                "rstrip": False,
+                "normalized": False,
+                "special": True
+            },
+            {
+                "id": 2,
+                "content": "</s>",
+                "single_word": False,
+                "lstrip": False,
+                "rstrip": False,
+                "normalized": False,
+                "special": True
+            },
+            {
+                "id": 50300,
+                "content": "_A_",
+                "single_word": False,
+                "lstrip": False,
+                "rstrip": False,
+                "normalized": True,
+                "special": False
+            },
+            {
+                "id": 50301,
+                "content": "_B_",
+                "single_word": False,
+                "lstrip": False,
+                "rstrip": False,
+                "normalized": True,
+                "special": False
+            },
+        ],
+        "model": {
+            "type": "BPE",
+            "unk_token": "<unk>",
+            "fuse_unk": True,
+            "byte_fallback": True,
+            "ignore_merges": False,
+            "vocab": vocab,
+            "merges": [],
+        },
+    }
+    with open(tmp_path / "tokenizer.json", "w") as f:
+        json.dump(tokenizer, f)
+
+    with open(tmp_path / "adapter_model.bin", "wb") as f:
+        f.write("".encode("utf-8"))
+
+    return tmp_path
+
+
+def test_allowed_token_ids_with_lora_vocab(opt_125m_huggingface_id, fake_lora_adapter_files):
+    model = opt_125m_huggingface_id
+    model_config = ModelConfig(
+        model=model,
+        task="auto",
+        tokenizer=model,
+        tokenizer_mode="auto",
+        trust_remote_code=True,
+        dtype="float16",
+        seed=42,
+    )
+    cache_config = CacheConfig(
+        block_size=16,
+        gpu_memory_utilization=0.9,
+        swap_space=0,
+        cache_dtype="auto",
+    )
+    device_config = DeviceConfig()
+
+    lora_config = LoRAConfig()
+    vllm_config = VllmConfig(
+        model_config=model_config,
+        cache_config=cache_config,
+        device_config=device_config,
+        lora_config=lora_config,
+    )
+
+    tokenizer = init_tokenizer_from_configs(
+        model_config=vllm_config.model_config,
+        scheduler_config=vllm_config.scheduler_config,
+        lora_config=vllm_config.lora_config)
+    processor = Processor(vllm_config, tokenizer)
+
+    # We define tokens 50300, 50301 in our fake lora adapter
+    lora_token_ids = [50300, 50301]
+    sampling_params = SamplingParams(allowed_token_ids=lora_token_ids)
+    lora_request = LoRARequest("1", 1, str(fake_lora_adapter_files))
+    processor._validate_sampling_params(sampling_params, lora_request)
diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py
index 27d70a781..2aa19f8bb 100644
--- a/vllm/v1/engine/processor.py
+++ b/vllm/v1/engine/processor.py
@@ -74,6 +74,7 @@ class Processor:
     def _validate_sampling_params(
         self,
         params: SamplingParams,
+        lora_request: Optional[LoRARequest],
     ) -> None:
         self._validate_structured_output(params)
         self._validate_logit_bias(params)
@@ -82,7 +83,8 @@ class Processor:
             return
         if not params.allowed_token_ids:
             raise ValueError("allowed_token_ids is not None and empty!")
-        vocab_size = self.model_config.get_vocab_size()
+        tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)
+        vocab_size = len(tokenizer)
         if not all(0 <= tid < vocab_size for tid in params.allowed_token_ids):
             raise ValueError(
                 "allowed_token_ids contains out-of-vocab token id!")
@@ -122,6 +124,7 @@ class Processor:
     def _validate_params(
         self,
         params: Union[SamplingParams, PoolingParams],
+        lora_request: Optional[LoRARequest],
     ):
         """
         Validate supported SamplingParam.
@@ -132,7 +135,7 @@ class Processor:
raise ValueError("V1 does not yet support Pooling models.")
self._validate_logprobs(params)
- self._validate_sampling_params(params)
+ self._validate_sampling_params(params, lora_request)
self._validate_supported_sampling_params(params)
def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None:
@@ -207,7 +210,7 @@ class Processor:
         # TODO(woosuk): Support pooling models.
         # TODO(woosuk): Support encoder-decoder models.
         self._validate_lora(lora_request)
-        self._validate_params(params)
+        self._validate_params(params, lora_request)
         if priority != 0:
             raise ValueError("V1 does not support priority yet.")
         if trace_headers is not None:
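
For context, here is a minimal sketch of how this validation path is reached through vLLM's public API. It is not part of the patch: the adapter path ./lora_adapter is hypothetical and is assumed to ship a tokenizer that adds token ids 50300 and 50301, like the test's fake adapter.

    from vllm import LLM, SamplingParams
    from vllm.lora.request import LoRARequest

    # Base model plus a LoRA adapter whose tokenizer adds ids 50300 and 50301
    # (hypothetical local path; substitute a real adapter directory).
    llm = LLM(model="facebook/opt-125m", enable_lora=True)
    lora_request = LoRARequest("my_adapter", 1, "./lora_adapter")

    # Restrict decoding to the adapter-added token ids. Before this patch the
    # V1 Processor rejected these ids as out-of-vocab because it only checked
    # the base model's vocab size; with the patch they are validated against
    # the LoRA tokenizer and accepted.
    sampling_params = SamplingParams(allowed_token_ids=[50300, 50301])
    outputs = llm.generate(["Hello"], sampling_params, lora_request=lora_request)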