Skip to content

Instantly share code, notes, and snippets.

@fzyzcjy
Last active April 26, 2026 13:01
Show Gist options
  • Select an option

  • Save fzyzcjy/224c60875da072107c119c9ad195e8d1 to your computer and use it in GitHub Desktop.

Select an option

Save fzyzcjy/224c60875da072107c119c9ad195e8d1 to your computer and use it in GitHub Desktop.
InferenceX bench backend_request_func.py with per-request status print (5 sites patched)
# SPDX-License-Identifier: Apache-2.0
import json
import os
import sys
import time
import traceback
from dataclasses import dataclass, field
from typing import List, Optional, Union
import aiohttp
import huggingface_hub.constants
from tqdm.asyncio import tqdm
from transformers import (AutoTokenizer, PreTrainedTokenizer,
PreTrainedTokenizerFast)
# Total per-request HTTP timeout handed to every aiohttp.ClientSession created
# in this file: 6 * 60 * 60 seconds, i.e. six hours.
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
@dataclass
class RequestFuncInput:
    """Parameters describing a single benchmark request to a backend."""

    # --- required fields -------------------------------------------------
    prompt: str  # prompt text placed in the request payload
    api_url: str  # endpoint the request is POSTed to
    prompt_len: int  # copied to the output; TGI also sends it as "truncate"
    output_len: int  # sent as the backend's max-token limit
    model: str  # model id used in OpenAI payloads when model_name is unset

    # --- optional fields -------------------------------------------------
    model_name: Optional[str] = None  # takes precedence over `model`
    best_of: int = 1  # TRT-LLM and DeepSpeed-MII clients assert this is 1
    logprobs: Optional[int] = None  # forwarded by the OpenAI completions client
    extra_body: Optional[dict] = None  # merged into OpenAI payloads last
    multi_modal_content: Optional[dict] = None  # appended to chat content
    ignore_eos: bool = False
@dataclass
class RequestFuncOutput:
    """Outcome and timing measurements collected for one request."""

    generated_text: str = ""
    success: bool = False
    latency: float = 0.0
    output_tokens: int = 0
    ttft: float = 0.0  # Time to first token
    # List of inter-token latencies (a fresh list per instance).
    itl: List[float] = field(default_factory=list)
    tpot: float = 0.0  # avg next-token latencies
    prompt_len: int = 0
    error: str = ""
async def async_request_tgi(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Send one streaming request to a TGI ``generate_stream`` endpoint.

    Args:
        request_func_input: Prompt, endpoint URL and generation limits.
        pbar: Optional progress bar. When truthy, a one-line status report
            is printed to stdout and the bar is advanced by one.

    Returns:
        A ``RequestFuncOutput``. On success it carries the generated text,
        TTFT, inter-token latencies and total latency. On any failure
        (non-200 status, empty stream, or an exception while talking to the
        server) ``success`` is False and ``error`` holds the reason;
        exceptions are captured, not propagated.

    Raises:
        AssertionError: If the API URL does not end with
            ``generate_stream``.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith("generate_stream")

    async with aiohttp.ClientSession(trust_env=True,
                                     timeout=AIOHTTP_TIMEOUT) as session:
        params = {
            "best_of": request_func_input.best_of,
            "max_new_tokens": request_func_input.output_len,
            "do_sample": True,
            "temperature": 0.01,  # TGI does not accept 0.0 temperature.
            "top_p": 0.99,  # TGI does not accept 1.0 top_p.
            "truncate": request_func_input.prompt_len,
            # TGI does not accept ignore_eos flag.
        }
        payload = {
            "inputs": request_func_input.prompt,
            "parameters": params,
        }
        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(url=api_url, json=payload) as response:
                if response.status == 200:
                    # Last decoded event; stays None if the stream is empty.
                    data = None
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue
                        chunk_bytes = chunk_bytes.decode("utf-8")

                        # NOTE: Sometimes TGI returns a ping response without
                        # any data, we should skip it.
                        if chunk_bytes.startswith(":"):
                            continue
                        chunk = chunk_bytes.removeprefix("data:")

                        data = json.loads(chunk)
                        timestamp = time.perf_counter()
                        # First token
                        if ttft == 0.0:
                            # Use the same clock reading as the ITL samples
                            # so that ttft + sum(itl) equals the latency.
                            ttft = timestamp - st
                            output.ttft = ttft

                        # Decoding phase
                        else:
                            output.itl.append(timestamp -
                                              most_recent_timestamp)

                        most_recent_timestamp = timestamp

                    output.latency = most_recent_timestamp - st
                    if data is None:
                        # A 200 response without a single data event has no
                        # generated text; report that explicitly instead of
                        # failing on the unbound `data`.
                        output.success = False
                        output.error = (
                            "TGI stream ended without any data chunk.")
                    else:
                        output.success = True
                        output.generated_text = data["generated_text"]
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            # Best-effort client: record the traceback instead of raising so
            # one failed request does not abort the whole benchmark run.
            output.success = False
            output.error = traceback.format_exc()

        if pbar:
            # Per-request status line; newlines in the error are flattened
            # so each request occupies exactly one log line.
            _err = (output.error or "").replace("\n", " | ").strip()
            print(
                f"[{time.strftime('%H:%M:%S')}] ok={output.success} "
                f"lat={output.latency:.2f}s "
                f"url={request_func_input.api_url} "
                f"plen={request_func_input.prompt_len} err={_err}",
                flush=True)
            pbar.update(1)
        return output
async def async_request_trt_llm(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Stream one generation from a TensorRT-LLM ``generate_stream`` URL.

    The text of every streamed event is accumulated, the first event sets
    the TTFT and later events add inter-token latencies. Failures (non-200
    status or any exception) are recorded on the returned output instead of
    being raised. When ``pbar`` is truthy a status line is printed and the
    bar is advanced by one.

    Raises:
        AssertionError: If the URL does not end with ``generate_stream`` or
            ``best_of`` is not 1.
    """
    endpoint = request_func_input.api_url
    assert endpoint.endswith("generate_stream")

    async with aiohttp.ClientSession(trust_env=True,
                                     timeout=AIOHTTP_TIMEOUT) as session:
        assert request_func_input.best_of == 1

        body = {
            "accumulate_tokens": True,
            "text_input": request_func_input.prompt,
            "temperature": 0.0,
            "top_p": 1.0,
            "max_tokens": request_func_input.output_len,
            "stream": True,
        }
        if request_func_input.ignore_eos:
            # Force the full output length by raising the minimum length.
            body["min_length"] = request_func_input.output_len

        output = RequestFuncOutput(prompt_len=request_func_input.prompt_len)

        started = time.perf_counter()
        previous = started
        try:
            async with session.post(url=endpoint, json=body) as response:
                if response.status != 200:
                    output.error = response.reason or ""
                    output.success = False
                else:
                    async for raw_line in response.content:
                        line = raw_line.strip()
                        if not line:
                            continue

                        event = json.loads(
                            line.decode("utf-8").removeprefix("data:"))
                        output.generated_text += event["text_output"]
                        now = time.perf_counter()
                        if output.ttft == 0.0:
                            # First token
                            output.ttft = now - started
                        else:
                            # Decoding phase
                            output.itl.append(now - previous)
                        previous = now

                    output.latency = previous - started
                    output.success = True
        except Exception:
            output.success = False
            output.error = traceback.format_exc()

        if pbar:
            flat_error = (output.error or "").replace("\n", " | ").strip()
            print(
                f"[{time.strftime('%H:%M:%S')}] ok={output.success} "
                f"lat={output.latency:.2f}s "
                f"url={request_func_input.api_url} "
                f"plen={request_func_input.prompt_len} err={flat_error}",
                flush=True)
            pbar.update(1)
        return output
async def async_request_deepspeed_mii(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Issue one non-streaming request to a DeepSpeed-MII endpoint.

    Only the end-to-end latency is measured; TTFT is left at the
    placeholder value 0. Failures (non-200 status or any exception) are
    recorded on the returned output instead of being raised. When ``pbar``
    is truthy a status line is printed and the bar is advanced by one.

    Raises:
        AssertionError: If ``best_of`` is not 1.
    """
    async with aiohttp.ClientSession(trust_env=True,
                                     timeout=AIOHTTP_TIMEOUT) as session:
        assert request_func_input.best_of == 1

        body = {
            "prompt": request_func_input.prompt,
            "max_tokens": request_func_input.output_len,
            "temperature": 0.01,  # deepspeed-mii does not accept 0.0 temp.
            "top_p": 1.0,
        }

        # NOTE: DeepSpeed-MII doesn't support streaming as of Jan 28 2024,
        # will use 0 as placeholder.
        # See https://github.com/microsoft/DeepSpeed-MII/pull/311
        output = RequestFuncOutput(
            prompt_len=request_func_input.prompt_len,
            ttft=0,
        )

        started = time.perf_counter()
        try:
            async with session.post(url=request_func_input.api_url,
                                    json=body) as response:
                if response.status != 200:
                    output.error = response.reason or ""
                    output.success = False
                else:
                    reply = await response.json()
                    output.latency = time.perf_counter() - started
                    output.generated_text = reply["text"][0]
                    output.success = True
        except Exception:
            output.success = False
            output.error = traceback.format_exc()

        if pbar:
            flat_error = (output.error or "").replace("\n", " | ").strip()
            print(
                f"[{time.strftime('%H:%M:%S')}] ok={output.success} "
                f"lat={output.latency:.2f}s "
                f"url={request_func_input.api_url} "
                f"plen={request_func_input.prompt_len} err={flat_error}",
                flush=True)
            pbar.update(1)
        return output
async def async_request_openai_completions(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Send one streaming request to an OpenAI-style Completions endpoint.

    Args:
        request_func_input: Prompt, endpoint URL, model and generation
            limits. ``extra_body`` is merged into the payload last, so it
            can override any default key.
        pbar: Optional progress bar. When truthy, a one-line status report
            is printed to stdout and the bar is advanced by one.

    Returns:
        A ``RequestFuncOutput``. The request counts as successful only if
        at least one chunk with ``choices`` was received; otherwise, or on
        a non-200 status or exception, ``success`` is False and ``error``
        holds the reason. Exceptions are captured, not propagated.

    Raises:
        AssertionError: If the API URL does not end with ``completions``
            or ``profile``.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith(
        ("completions", "profile")
    ), "OpenAI Completions API URL must end with 'completions' or 'profile'."

    async with aiohttp.ClientSession(trust_env=True,
                                     timeout=AIOHTTP_TIMEOUT) as session:
        payload = {
            "model": (request_func_input.model_name
                      or request_func_input.model),
            "prompt": request_func_input.prompt,
            "temperature": 0.0,
            "best_of": request_func_input.best_of,
            "max_tokens": request_func_input.output_len,
            "logprobs": request_func_input.logprobs,
            "stream": True,
            "stream_options": {
                "include_usage": True,
            },
        }
        if request_func_input.ignore_eos:
            payload["ignore_eos"] = request_func_input.ignore_eos
        if request_func_input.extra_body:
            payload.update(request_func_input.extra_body)
        # NOTE: when OPENAI_API_KEY is unset this sends "Bearer None".
        headers = {
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
        }

        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        generated_text = ""
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(url=api_url, json=payload,
                                    headers=headers) as response:
                if response.status == 200:
                    first_chunk_received = False
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = chunk_bytes.decode("utf-8").removeprefix(
                            "data: ")
                        if chunk != "[DONE]":
                            data = json.loads(chunk)

                            # NOTE: Some completion API might have a last
                            # usage summary response without a token so we
                            # want to check a token was generated
                            if choices := data.get("choices"):
                                # Note that text could be empty here
                                # e.g. for special tokens
                                text = choices[0].get("text")
                                timestamp = time.perf_counter()
                                # First token
                                if not first_chunk_received:
                                    first_chunk_received = True
                                    # Same clock reading as the ITL samples,
                                    # so ttft + sum(itl) equals the latency.
                                    output.ttft = timestamp - st

                                # Decoding phase
                                else:
                                    output.itl.append(timestamp -
                                                      most_recent_timestamp)

                                most_recent_timestamp = timestamp
                                generated_text += text or ""
                            elif usage := data.get("usage"):
                                output.output_tokens = usage.get(
                                    "completion_tokens")
                    if first_chunk_received:
                        output.success = True
                    else:
                        output.success = False
                        output.error = (
                            "Never received a valid chunk to calculate TTFT. "
                            "This response will be marked as failed!")
                    output.generated_text = generated_text
                    output.latency = most_recent_timestamp - st
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            # Best-effort client: record the traceback instead of raising so
            # one failed request does not abort the whole benchmark run.
            output.success = False
            output.error = traceback.format_exc()

        if pbar:
            # Per-request status line; newlines in the error are flattened
            # so each request occupies exactly one log line.
            _err = (output.error or "").replace("\n", " | ").strip()
            print(
                f"[{time.strftime('%H:%M:%S')}] ok={output.success} "
                f"lat={output.latency:.2f}s "
                f"url={request_func_input.api_url} "
                f"plen={request_func_input.prompt_len} err={_err}",
                flush=True)
            pbar.update(1)
        return output
async def async_request_openai_chat_completions(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Stream one generation from an OpenAI-style Chat Completions URL.

    The prompt is sent as a single user message whose content is a text
    part, optionally followed by ``multi_modal_content``. Any 200 response
    is reported as successful; a non-200 status or an exception is
    recorded on the returned output instead of being raised. When ``pbar``
    is truthy a status line is printed and the bar is advanced by one.

    Raises:
        AssertionError: If the URL does not end with ``chat/completions``.
    """
    endpoint = request_func_input.api_url
    assert endpoint.endswith(
        "chat/completions"
    ), "OpenAI Chat Completions API URL must end with 'chat/completions'."

    async with aiohttp.ClientSession(trust_env=True,
                                     timeout=AIOHTTP_TIMEOUT) as session:
        message_parts = [{"type": "text", "text": request_func_input.prompt}]
        if request_func_input.multi_modal_content:
            message_parts.append(request_func_input.multi_modal_content)

        body = {
            "model": (request_func_input.model_name
                      if request_func_input.model_name
                      else request_func_input.model),
            "messages": [
                {
                    "role": "user",
                    "content": message_parts
                },
            ],
            "temperature": 0.0,
            "max_completion_tokens": request_func_input.output_len,
            "stream": True,
            "stream_options": {
                "include_usage": True,
            },
        }
        if request_func_input.ignore_eos:
            body["ignore_eos"] = request_func_input.ignore_eos
        if request_func_input.extra_body:
            body.update(request_func_input.extra_body)

        request_headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        }

        output = RequestFuncOutput(prompt_len=request_func_input.prompt_len)

        pieces = ""
        started = time.perf_counter()
        previous = started
        try:
            async with session.post(url=endpoint, json=body,
                                    headers=request_headers) as response:
                if response.status != 200:
                    output.error = response.reason or ""
                    output.success = False
                else:
                    async for raw_line in response.content:
                        line = raw_line.strip()
                        if not line:
                            continue

                        event_text = line.decode("utf-8").removeprefix(
                            "data: ")
                        if event_text == "[DONE]":
                            continue

                        now = time.perf_counter()
                        event = json.loads(event_text)

                        if choices := event.get("choices"):
                            delta_text = choices[0]["delta"].get("content")
                            if output.ttft == 0.0:
                                # First token
                                output.ttft = now - started
                            else:
                                # Decoding phase
                                output.itl.append(now - previous)
                            pieces += delta_text or ""
                        elif usage := event.get("usage"):
                            output.output_tokens = usage.get(
                                "completion_tokens")

                        # Advanced for every decoded event, including the
                        # usage-only one.
                        previous = now

                    output.generated_text = pieces
                    output.success = True
                    output.latency = previous - started
        except Exception:
            output.success = False
            output.error = traceback.format_exc()

        if pbar:
            flat_error = (output.error or "").replace("\n", " | ").strip()
            print(
                f"[{time.strftime('%H:%M:%S')}] ok={output.success} "
                f"lat={output.latency:.2f}s "
                f"url={request_func_input.api_url} "
                f"plen={request_func_input.prompt_len} err={flat_error}",
                flush=True)
            pbar.update(1)
        return output
def get_model(pretrained_model_name_or_path: str) -> str:
    """Return the model location to load the tokenizer from.

    Unless the ``VLLM_USE_MODELSCOPE`` environment variable equals
    ``true`` (case-insensitive), the argument is returned unchanged.
    Otherwise the model is fetched through ModelScope's
    ``snapshot_download`` (weight files matching the ignore patterns are
    skipped) and the path it returns is used.
    """
    modelscope_flag = os.getenv('VLLM_USE_MODELSCOPE', 'False').lower()
    if modelscope_flag != 'true':
        return pretrained_model_name_or_path

    # Imported lazily so modelscope is only required when it is enabled.
    from modelscope import snapshot_download

    return snapshot_download(
        model_id=pretrained_model_name_or_path,
        local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
        ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"])
def _fix_tokenizer_for_sglang(tokenizer, model_path):
    """Fix transformers v5 tokenizer to match sglang server-side behavior.

    Root cause: transformers v5 (>= 5.0) changed how tokenizers are loaded.
    Specifically, LlamaTokenizerFast.__init__ in v5 rebuilds the pre_tokenizer
    and decoder from scratch using class-specific components, discarding the
    originals from tokenizer.json. For models like DeepSeek-R1 that declare
    LlamaTokenizerFast but actually use a ByteLevel/Sequence tokenizer
    architecture, v5 incorrectly replaces the original Sequence pre_tokenizer
    with Metaspace, and the original ByteLevel decoder with Sequence.

    See: https://github.com/sgl-project/sglang/blob/9238bd08a2895fa3b7ec79ea567e5c27ac951343/python/sglang/srt/utils/hf_transformers_utils.py#L836

    The sglang server applies fixes for this in hf_transformers_utils.py
    (_fix_v5_tokenizer_components and _fix_v5_add_bos_eos_token), but the
    benchmark client loads the tokenizer directly via AutoTokenizer without
    these fixes. This mismatch causes the client to encode text differently
    from the server -- e.g. a 7000-token prompt on the client becomes ~35000
    tokens on the server, leading to ~5x TTFT inflation and false performance
    regressions in benchmarks.

    This function replicates the same fixes so the benchmark client tokenizes
    identically to the sglang server. It is a no-op on transformers v4.

    Both fixes read files under ``model_path`` (``tokenizer.json`` and
    ``tokenizer_config.json``); when those files do not exist there -- for
    example because ``model_path`` is not a local directory -- nothing is
    changed.

    Args:
        tokenizer: The loaded tokenizer; it is modified in place.
        model_path: Directory expected to hold the tokenizer files.

    Returns:
        The same ``tokenizer`` object.

    The fixes are best-effort: a failure never raises, but it is reported
    through ``warnings.warn`` because a silently skipped fix produces the
    misleading benchmark numbers described above.
    """
    import json
    import warnings
    from pathlib import Path

    # Fix 1: restore the pre_tokenizer/decoder declared in tokenizer.json
    # when the loaded backend ended up with a different pre_tokenizer type.
    backend = getattr(tokenizer, "_tokenizer", None)
    if backend is not None:
        try:
            from tokenizers import Tokenizer as RawTokenizer
            tok_file = Path(model_path) / "tokenizer.json"
            if tok_file.is_file():
                raw = RawTokenizer.from_file(str(tok_file))
                raw_pre = (type(raw.pre_tokenizer).__name__
                           if raw.pre_tokenizer else None)
                loaded_pre = (type(backend.pre_tokenizer).__name__
                              if backend.pre_tokenizer else None)
                if raw_pre and loaded_pre and raw_pre != loaded_pre:
                    backend.pre_tokenizer = raw.pre_tokenizer
                    backend.decoder = raw.decoder
        except Exception as exc:
            warnings.warn(
                "Could not restore tokenizer components from "
                f"tokenizer.json under {model_path!r}: {exc!r}",
                RuntimeWarning,
                stacklevel=2)

    # Fix 2: make add_bos_token/add_eos_token agree with
    # tokenizer_config.json for the tokenizer classes listed below.
    try:
        config_file = Path(model_path) / "tokenizer_config.json"
        if config_file.is_file():
            # JSON is UTF-8; do not depend on the locale's default encoding.
            with open(config_file, encoding="utf-8") as f:
                config = json.load(f)
            tok_class = config.get("tokenizer_class", "")
            bos_eos_classes = {
                "LlamaTokenizer", "LlamaTokenizerFast",
                "CodeLlamaTokenizer", "CodeLlamaTokenizerFast",
                "GemmaTokenizer", "GemmaTokenizerFast", "CohereTokenizerFast",
            }
            if tok_class in bos_eos_classes:
                # Values assumed when the config file omits the key.
                defaults = {"add_bos_token": True, "add_eos_token": False}
                changed = False
                for attr in ("add_bos_token", "add_eos_token"):
                    val = config.get(attr)
                    if val is None:
                        val = defaults.get(attr, False)
                    if getattr(tokenizer, attr, None) != val:
                        # Written to the underscore-prefixed attribute.
                        setattr(tokenizer, f"_{attr}", val)
                        changed = True
                if changed and hasattr(tokenizer, "update_post_processor"):
                    tokenizer.update_post_processor()
    except Exception as exc:
        warnings.warn(
            "Could not apply add_bos_token/add_eos_token settings from "
            f"tokenizer_config.json under {model_path!r}: {exc!r}",
            RuntimeWarning,
            stacklevel=2)
    return tokenizer
def get_tokenizer(
    pretrained_model_name_or_path: str,
    tokenizer_mode: str = "auto",
    trust_remote_code: bool = False,
    **kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
    """Load the tokenizer used by the benchmark client.

    A name that is not an existing local path is first passed through
    ``get_model``. ``tokenizer_mode="slow"`` forces ``use_fast=False``;
    ``tokenizer_mode="mistral"`` returns vLLM's ``MistralTokenizer``
    directly; every other mode loads through ``AutoTokenizer`` and then
    applies ``_fix_tokenizer_for_sglang``. Extra keyword arguments are
    forwarded to ``AutoTokenizer.from_pretrained``.

    Raises:
        ValueError: If ``use_fast`` is requested together with slow mode.
        ImportError: If mistral mode is requested but vllm is missing.
    """
    needs_resolution = (pretrained_model_name_or_path is not None
                        and not os.path.exists(pretrained_model_name_or_path))
    if needs_resolution:
        pretrained_model_name_or_path = get_model(
            pretrained_model_name_or_path)

    if tokenizer_mode == "slow":
        if kwargs.get("use_fast", False):
            raise ValueError(
                "Cannot use the fast tokenizer in slow tokenizer mode.")
        kwargs["use_fast"] = False

    if tokenizer_mode == "mistral":
        try:
            from vllm.transformers_utils.tokenizer import MistralTokenizer
        except ImportError as e:
            raise ImportError("MistralTokenizer requires vllm package.\n"
                              "Please install it with `pip install vllm` "
                              "to use mistral tokenizer mode.") from e
        return MistralTokenizer.from_pretrained(
            str(pretrained_model_name_or_path))

    loaded = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path,
        trust_remote_code=trust_remote_code,
        **kwargs,
    )
    return _fix_tokenizer_for_sglang(loaded, pretrained_model_name_or_path)
# Maps a backend name to the coroutine function that issues one request
# to it. Several backends share the OpenAI Completions client.
ASYNC_REQUEST_FUNCS = {
    "tgi": async_request_tgi,
    "vllm": async_request_openai_completions,
    "lmdeploy": async_request_openai_completions,
    "deepspeed-mii": async_request_deepspeed_mii,
    "openai": async_request_openai_completions,
    "openai-chat": async_request_openai_chat_completions,
    "tensorrt-llm": async_request_trt_llm,
    "scalellm": async_request_openai_completions,
    "sglang": async_request_openai_completions,
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment