@mikasenghaas
Last active November 1, 2025 00:47
Benchmarks for constructing OAI Pydantic models w/ many logprobs

To reproduce, download the .json files with raw responses from here.

Running in standard mode

uv run oai_pydantic.py
Processing oai_response_1024.json...
model_validate taken 5.02ms
model_construct taken 5.86ms
model_validate_json taken 2.18ms

Processing oai_response_8192.json...
model_validate taken 12.85ms
model_construct taken 37.14ms
model_validate_json taken 15.86ms

Processing oai_response_16384.json...
model_validate taken 54.59ms
model_construct taken 52.33ms
model_validate_json taken 31.51ms

Processing oai_response_32768.json...
model_validate taken 127.27ms
model_construct taken 133.14ms
model_validate_json taken 63.51ms
  • Indeed, model_validate_json is roughly 2x faster than model_validate and model_construct on most samples (the 8192 sample is an outlier, where model_validate is fastest). That's nice, but this would still be a huge bottleneck in RL training with 4k+ batch sizes.
  • As @baggiponte said, I am quite surprised that model_validate takes roughly the same time as model_construct in most runs. Shouldn't the latter skip validation entirely? Maybe this is not true for submodels? If so, is there a way to also skip validation of submodels? All this data is entirely trusted, so skipping validation altogether is fine.
  • Nit: I am not quite sure why I am now measuring 130ms instead of the 180ms I got a couple of days ago on the 32k sample. Maybe it's because I am using perf_counter now instead of plain time, or the CPU is having a better day today, who knows. But really it's about the order of magnitude, which is still the same.
  • Also, it does seem like constructing the Pydantic model scales ~linearly with the number of logprobs it has to parse.
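
As a quick sanity check on the linear-scaling observation, the model_validate_json timings above can be divided by the token count of each sample. This is only a sketch on the four reported numbers; the batch size of 4096 and the assumption that every response has 32k tokens and is parsed sequentially are illustrative, not measured.

```python
# model_validate_json timings (ms) copied from the standard-mode run above,
# keyed by the number of completion tokens in each sample.
timings_ms = {1024: 2.18, 8192: 15.86, 16384: 31.51, 32768: 63.51}

# Cost per token in microseconds; a roughly constant value suggests linear scaling.
per_token_us = {n: 1000 * ms / n for n, ms in timings_ms.items()}

for n, us in per_token_us.items():
    print(f"{n:>6} tokens: {us:.2f} us/token")

# Assumption for illustration: 4096 responses of 32k tokens, parsed one after another.
batch_size = 4096
batch_seconds = batch_size * timings_ms[32768] / 1000
print(f"sequential parsing of {batch_size} responses: ~{batch_seconds:.0f}s")
```

The per-token cost stays close to 2 microseconds across all four samples, which is consistent with linear scaling, and under the stated assumptions the batch would spend several minutes in parsing alone.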

Ignoring the logprobs field

uv run oai_pydantic.py --ignore-logprobs
Processing oai_response_1024.json...
model_validate taken 2.97ms
model_construct taken 3.38ms
model_validate_json taken 0.26ms

Processing oai_response_8192.json...
model_validate taken 0.03ms
model_construct taken 0.15ms
model_validate_json taken 0.10ms

Processing oai_response_16384.json...
model_validate taken 0.03ms
model_construct taken 0.13ms
model_validate_json taken 0.16ms

Processing oai_response_32768.json...
model_validate taken 0.03ms
model_construct taken 0.14ms
model_validate_json taken 0.26ms
  • It seems like it's only the logprob parsing that is taking a significant amount of time.

Using the hotfix that we use in prime-rl for now

uv run oai_pydantic.py --use-hotfix
Processing oai_response_1024.json...
model_validate taken 2.38ms
model_construct taken 3.03ms
model_validate_json taken 1.96ms

Processing oai_response_8192.json...
model_validate taken 0.17ms
model_construct taken 0.15ms
model_validate_json taken 13.58ms

Processing oai_response_16384.json...
model_validate taken 0.97ms
model_construct taken 0.14ms
model_validate_json taken 21.81ms

Processing oai_response_32768.json...
model_validate taken 1.83ms
model_construct taken 0.14ms
model_validate_json taken 43.73ms

Our hotfix essentially skips whatever Pydantic does to the logprobs field, so we are still quick. Interestingly (I hadn't tested this before), model_validate_json does not seem to benefit from it.
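
To illustrate what typing a field as Any changes, here is a minimal sketch with plain Pydantic models. The class names (TokenLogprob, Logprobs, TypedChoice, UntypedChoice) are hypothetical stand-ins, not the OpenAI SDK classes, and the payload is made up.

```python
from typing import Any, Optional

from pydantic import BaseModel


class TokenLogprob(BaseModel):
    token: str
    logprob: float


class Logprobs(BaseModel):
    content: list[TokenLogprob]


class TypedChoice(BaseModel):
    """Hypothetical choice model whose logprobs are fully validated."""

    index: int
    logprobs: Optional[Logprobs] = None


class UntypedChoice(BaseModel):
    """Same shape, but logprobs is typed as Any and therefore passed through."""

    index: int
    logprobs: Optional[Any] = None


raw = {"index": 0, "logprobs": {"content": [{"token": "hi", "logprob": -0.1}]}}

typed = TypedChoice.model_validate(raw)
untyped = UntypedChoice.model_validate(raw)

# The typed model builds one model instance per token entry ...
print(type(typed.logprobs.content[0]).__name__)
# ... while the Any-typed field keeps the plain dict it was given.
print(type(untyped.logprobs).__name__)
```

With the Any annotation, no per-token objects are created, which is presumably where the hotfix saves its time; the trade-off is that downstream code has to index into plain dicts instead of using attributes.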

Super interested to hear where people think the bottleneck is and whether we can find a more elegant general solution! :)

import argparse
import glob
import json
from time import perf_counter
from typing import Any, List, Optional

import openai.types.chat.chat_completion
from openai.types.chat.chat_completion import ChatCompletion, Choice


class ChoiceAny(Choice):
    """Same as openai.types.chat.chat_completion.Choice, but without type validation for logprobs field."""

    logprobs: Optional[Any] = None


class ChatCompletionAny(ChatCompletion):
    """Same as openai.types.chat.chat_completion.ChatCompletion, but using ChoiceAny instead of Choice."""

    choices: List[ChoiceAny]  # type: ignore


def main(args: argparse.Namespace):
    # Re-import inside main so that the --use-hotfix monkeypatch (applied below) is picked up.
    from openai.types.chat.chat_completion import ChatCompletion

    # Sort the response files by the token count encoded in the file name.
    oai_response_files = sorted(glob.glob("oai_response_*.json"), key=lambda x: int(x.split(".")[0].split("_")[-1]))
    for oai_response_file in oai_response_files:
        print(f"Processing {oai_response_file}...")
        num_completion_tokens = int(oai_response_file.split(".")[0].split("_")[-1])
        with open(oai_response_file, "r") as f:
            oai_response = json.load(f)
        if args.ignore_logprobs:
            for choice in oai_response["choices"]:
                choice["logprobs"] = None
        oai_response_json = json.dumps(oai_response)

        # Validate from an already-decoded dict.
        start_time = perf_counter()
        completion = ChatCompletion.model_validate(oai_response)
        assert completion.usage is not None and completion.usage.completion_tokens == num_completion_tokens
        print(f"model_validate taken {1000 * (perf_counter() - start_time):.2f}ms")

        # Construct from an already-decoded dict.
        start_time = perf_counter()
        completion = ChatCompletion.model_construct(**oai_response)
        assert completion.usage is not None and completion.usage.completion_tokens == num_completion_tokens
        print(f"model_construct taken {1000 * (perf_counter() - start_time):.2f}ms")

        # Validate from the raw JSON string (this timing includes JSON decoding).
        start_time = perf_counter()
        completion = ChatCompletion.model_validate_json(oai_response_json)
        assert completion.usage is not None and completion.usage.completion_tokens == num_completion_tokens
        print(f"model_validate_json taken {1000 * (perf_counter() - start_time):.2f}ms")
        print()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--ignore-logprobs", action="store_true")
    parser.add_argument("--use-hotfix", action="store_true")
    args = parser.parse_args()
    if args.use_hotfix:
        # Monkeypatch the module attribute so main() imports the relaxed model.
        openai.types.chat.chat_completion.ChatCompletion = ChatCompletionAny
    main(args)
@samuelcolvin

Thanks @rahuliyer95, what you've said makes loads of sense.

I've no idea why stainlessapi/openai have that model_construct method and don't use model_validate_json.

In a hurry to respond, I didn't check if it was a vanilla BaseModel.
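
For reference, this is how model_construct behaves on a vanilla Pydantic BaseModel, as a minimal sketch. The Usage and Response classes are hypothetical and deliberately tiny; they are not the OpenAI SDK models, which the comment above suggests ship their own model_construct.

```python
from pydantic import BaseModel, ValidationError


class Usage(BaseModel):
    completion_tokens: int


class Response(BaseModel):
    id: str
    usage: Usage


# Deliberately invalid: completion_tokens is not an integer.
raw = {"id": "resp-1", "usage": {"completion_tokens": "not-an-int"}}

# Vanilla model_construct assigns the values as given: no validation,
# and the nested dict is not turned into a Usage instance.
constructed = Response.model_construct(**raw)
print(type(constructed.usage).__name__)

# model_validate, by contrast, rejects the same input.
try:
    Response.model_validate(raw)
    validation_failed = False
except ValidationError:
    validation_failed = True
print(validation_failed)
```

On a vanilla model, construction is therefore cheap and does not recurse into submodels, so a model_construct that costs as much as model_validate points at extra work done by the SDK's own implementation.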

@mikasenghaas it's also worth noting that you're conflating the time taken to decode the JSON with the time taken to run model_construct - that won't have much effect here since model_construct is so slow, but if you used a saner approach like just calling model_validate, it makes quite a lot of difference:

this code
import json
import time
from pathlib import Path

from openai.types.chat.chat_completion import ChatCompletion

content = Path("oai_response_32768.json").read_bytes()

json_data = json.loads(content)

start = time.perf_counter()
ChatCompletion.model_validate(json_data)
end = time.perf_counter()
print(f"model_validate(json_data) taken: {1000 * (end - start): .2f}ms")

start = time.perf_counter()
ChatCompletion.model_validate(json.loads(content))
end = time.perf_counter()
print(f"model_validate(json.loads(content)) taken: {1000 * (end - start): .2f}ms")

start = time.perf_counter()
ChatCompletion.model_validate_json(content)
end = time.perf_counter()
print(f"model_validate_json taken: {1000 * (end - start): .2f}ms")

gives:

model_validate(json_data) taken:  28.24ms
model_validate(json.loads(content)) taken:  45.23ms
model_validate_json taken:  33.34ms
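
One way to keep the JSON decode step and the later processing step apart when benchmarking is to time them with a small helper. The sketch below uses only the standard library; the helper name best_of_ms and the synthetic payload are assumptions for illustration, not part of the original benchmark.

```python
import json
from time import perf_counter


def best_of_ms(fn, repeats=5):
    """Return the fastest of `repeats` wall-clock runs of fn(), in milliseconds."""
    best = float("inf")
    for _ in range(repeats):
        start = perf_counter()
        fn()
        best = min(best, 1000 * (perf_counter() - start))
    return best


# Synthetic stand-in payload (hypothetical, not a real API response).
entries = [{"token": "a", "logprob": -0.5}] * 10_000
payload = json.dumps({"choices": [{"logprobs": {"content": entries}}]}).encode()


def process(data):
    # Stand-in for whatever happens after decoding, e.g. validation.
    return sum(e["logprob"] for e in data["choices"][0]["logprobs"]["content"])


decoded = json.loads(payload)
decode_ms = best_of_ms(lambda: json.loads(payload))
process_ms = best_of_ms(lambda: process(decoded))
combined_ms = best_of_ms(lambda: process(json.loads(payload)))

print(f"decode only:      {decode_ms:.2f}ms")
print(f"process only:     {process_ms:.2f}ms")
print(f"decode + process: {combined_ms:.2f}ms")
```

Applied to the numbers reported above, the gap between the first two measurements (45.23ms versus 28.24ms) attributes roughly 17ms to json.loads alone.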

If you really care about performance but need validation, you can shave ~33% off validation time by using typed dicts and TypeAdapter:

typed dict and TypeAdapter example
import time
from pathlib import Path
from typing import Literal, NotRequired, TypedDict

from pydantic import TypeAdapter


class TopLogprob(TypedDict):
    token: str
    bytes: NotRequired[list[int] | None]
    logprob: float


class ChatCompletionTokenLogprob(TypedDict):
    token: str
    bytes: NotRequired[list[int] | None]
    logprob: float
    top_logprobs: list[TopLogprob]


class ChoiceLogprobs(TypedDict):
    content: NotRequired[list[ChatCompletionTokenLogprob] | None]
    """A list of message content tokens with log probability information."""
    refusal: NotRequired[list[ChatCompletionTokenLogprob] | None]
    """A list of message refusal tokens with log probability information."""


class AnnotationURLCitation(TypedDict):
    end_index: int
    """The index of the last character of the URL citation in the message."""
    start_index: int
    """The index of the first character of the URL citation in the message."""
    title: str
    """The title of the web resource."""
    url: str
    """The URL of the web resource."""


class Annotation(TypedDict):
    type: Literal["url_citation"]
    url_citation: AnnotationURLCitation


class ChatCompletionAudio(TypedDict):
    id: str
    data: str
    expires_at: int
    transcript: str


class ChatCompletionMessage(TypedDict):
    content: NotRequired[str | None]
    refusal: NotRequired[str | None]
    role: Literal["assistant"]
    annotations: NotRequired[list[Annotation] | None]
    audio: NotRequired[ChatCompletionAudio | None]


class Choice(TypedDict):
    finish_reason: Literal[
        "stop", "length", "tool_calls", "content_filter", "function_call"
    ]
    index: int
    logprobs: NotRequired[ChoiceLogprobs | None]
    message: ChatCompletionMessage


class CompletionTokensDetails(TypedDict):
    accepted_prediction_tokens: NotRequired[int | None]
    audio_tokens: NotRequired[int | None]
    reasoning_tokens: NotRequired[int | None]
    rejected_prediction_tokens: NotRequired[int | None]


class PromptTokensDetails(TypedDict):
    audio_tokens: NotRequired[int | None]
    cached_tokens: NotRequired[int | None]


class CompletionUsage(TypedDict):
    completion_tokens: int
    prompt_tokens: int
    total_tokens: int
    completion_tokens_details: NotRequired[CompletionTokensDetails | None]
    prompt_tokens_details: NotRequired[PromptTokensDetails | None]


class ChatCompletion(TypedDict):
    id: str
    choices: list[Choice]
    created: int
    model: str
    object: Literal["chat.completion"]
    service_tier: NotRequired[
        Literal["auto", "default", "flex", "scale", "priority"] | None
    ]
    system_fingerprint: NotRequired[str | None]
    usage: NotRequired[CompletionUsage | None]


content = Path("oai_response_32768.json").read_bytes()
ta = TypeAdapter(ChatCompletion)
start = time.perf_counter()
ta.validate_json(content)
end = time.perf_counter()
print(f"model_validate_json taken: {1000 * (end - start): .2f}ms")

gives:

model_validate_json taken:  23.08ms

Or if you just want to parse the JSON, you can almost halve the time again:

import time
from pathlib import Path

import pydantic_core

# warmup
pydantic_core.from_json(b"{}")
content = Path("oai_response_32768.json").read_bytes()
start = time.perf_counter()
pydantic_core.from_json(content)
end = time.perf_counter()
print(f"pydantic_core.from_json(content) taken: {1000 * (end - start): .2f}ms")
#> pydantic_core.from_json(content) taken:  13.07ms
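
Since a plain JSON parse returns ordinary dicts and lists, the per-token logprobs can then be pulled out with normal indexing. The following is a sketch under assumptions: the helper token_logprobs and the tiny payload are hypothetical, and json.loads is used as a dependency-free stand-in for pydantic_core.from_json.

```python
import json

# Tiny hypothetical payload following the chat-completion shape used above.
content = json.dumps(
    {
        "choices": [
            {
                "index": 0,
                "logprobs": {
                    "content": [
                        {"token": "Hello", "logprob": -0.25, "top_logprobs": []},
                        {"token": "!", "logprob": -1.5, "top_logprobs": []},
                    ]
                },
            }
        ],
        "usage": {"completion_tokens": 2},
    }
).encode()

# Stand-in for pydantic_core.from_json(content); both yield plain Python objects.
data = json.loads(content)


def token_logprobs(response: dict, choice_index: int = 0) -> list[float]:
    """Collect the logprob of each generated token for one choice."""
    logprobs = response["choices"][choice_index].get("logprobs") or {}
    return [entry["logprob"] for entry in (logprobs.get("content") or [])]


print(token_logprobs(data))
```

This skips model construction entirely, which only makes sense when the data is trusted, as is the case described in the gist.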
