Created
March 27, 2026 20:57
-
-
Save Jordanh1996/af56156da6cfab82917215dd340f76ac to your computer and use it in GitHub Desktop.
Reproduction: SummarizationMiddleware token underestimation with ChatAnthropicVertex
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """Reproduction: SummarizationMiddleware token underestimation with ChatAnthropicVertex. | |
| LangChain's _get_approximate_token_counter checks model._llm_type == "anthropic-chat", | |
| but ChatAnthropicVertex._llm_type returns "anthropic-chat-vertexai". This causes the | |
| token counter to use 4.0 chars/token instead of 3.3, underestimating by ~16%. | |
| The summarization middleware never triggers, and the API rejects the prompt. | |
| Additionally: | |
| - use_usage_metadata_scaling is gated on response_metadata["model_provider"], which | |
| ChatAnthropicVertex never sets. The scaling safety net is a no-op. | |
| - _should_summarize_based_on_reported_tokens fails for the same reason. | |
| Two test cases: | |
| A) Gradual conversation growth (multi-turn) | |
| B) Single large message | |
| Usage: | |
| pip install langchain langchain-core langchain-google-vertexai python-dotenv | |
| # Set your Vertex AI credentials: | |
| export VERTEX_AI_PROJECT=your-gcp-project | |
| export VERTEX_AI_LOCATION=us-east5 # or your region | |
| python repro_prompt_too_long.py | |
| # Run only one test: | |
| python repro_prompt_too_long.py --test gradual | |
| python repro_prompt_too_long.py --test single | |
| Exit codes: | |
| 0 = bug reproduced (prompt-too-long error despite middleware thinking it's under limit) | |
| 1 = bug NOT reproduced | |
| Tested with: | |
| langchain==1.2.13, langchain-core==1.2.22, langchain-google-vertexai==3.2.2 | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import logging | |
| import os | |
| import re | |
| import sys | |
| from functools import partial | |
| from typing import Any | |
| import langchain.agents.middleware.summarization as _summarization_module | |
| from langchain.agents.middleware.summarization import ( | |
| SummarizationMiddleware as LCSummarizationMiddleware, | |
| ) | |
| from langchain_core.messages import AIMessage, AnyMessage, HumanMessage | |
| from langchain_core.messages.utils import count_tokens_approximately | |
| from langchain_google_vertexai.model_garden import ChatAnthropicVertex | |
# Log bare messages only (no level or timestamp prefix) at INFO and above.
logging.basicConfig(format="%(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)

# -- Constants --
# Model under test, plus the Vertex AI project and region it is called through.
# Both environment variables are mandatory: a missing one raises KeyError
# as soon as this module is loaded.
MODEL_NAME = "claude-haiku-4-5"
PROJECT = os.environ["VERTEX_AI_PROJECT"]
LOCATION = os.environ["VERTEX_AI_LOCATION"]

# Input-token limit declared in the model profile, and the fraction of that
# limit at which the summarization middleware is configured to trigger.
MAX_INPUT_TOKENS = 200_000
TRIGGER_FRACTION = 0.85
TRIGGER_THRESHOLD = int(TRIGGER_FRACTION * MAX_INPUT_TOKENS)  # 170,000
# Building blocks for the synthetic filler text.
#
# Each template carries a unique ID placeholder plus varied vocabulary.
# Author's note (not verifiable from this file alone): repeated/identical text
# compresses to ~6+ chars/token in BPE and would mask the bug, whereas varied
# text with unique identifiers tokenizes at ~3.3 chars/token, which is
# realistic for Claude.
_TEMPLATES = [
    "The item {id} requires input from {dept} department by {date}.",
    "Assessment {id} identified {n} gaps in the {area} framework.",
    "Reviewer {name} checked {n} samples from the {area} population.",
    "Finding {id}: {area} status is {status} as of {date}.",
    "Action plan {id} targets {area} with {n} items due {date}.",
    "Document {id} was last updated on {date} by {name} in {dept}.",
    "Request {id} for {area} approved by {name} with {n} conditions.",
    "Vendor {name} scored {n} on the {area} evaluation completed {date}.",
    "Incident {id} in {area} affected {n} records and was escalated to {dept}.",
    "Procedure {id} for {area} ran {n} checks and completed on {date}.",
]

# Values substituted for the {dept} placeholder.
_DEPTS = [
    "IT",
    "HR",
    "Finance",
    "Legal",
    "Operations",
    "Security",
    "Engineering",
    "Support",
]

# Values substituted for the {area} placeholder.
_AREAS = [
    "access-control",
    "encryption",
    "backup",
    "monitoring",
    "incident-response",
    "change-management",
    "data-classification",
    "network-security",
    "logging",
    "identity",
]

# Values substituted for the {status} placeholder.
_STATUSES = [
    "active",
    "inactive",
    "in-progress",
    "needs-review",
    "resolved",
]
def _generate_sentence(index: int) -> str:
    """Render one deterministic filler sentence for ``index``.

    The template and every placeholder value are derived from ``index`` by
    modular arithmetic, so the same index always yields the same sentence.
    """
    month = index % 12 + 1
    day = index % 28 + 1
    fields = {
        "id": f"ITEM-{index:05d}",
        "dept": _DEPTS[index % len(_DEPTS)],
        "date": f"2026-{month:02d}-{day:02d}",
        "n": 10 + (index * 7) % 91,
        "area": _AREAS[index % len(_AREAS)],
        "name": f"User_{100 + (index * 13) % 900}",
        "status": _STATUSES[index % len(_STATUSES)],
    }
    template = _TEMPLATES[index % len(_TEMPLATES)]
    # str.format ignores keyword arguments the chosen template does not use.
    return template.format(**fields)
def _generate_varied_text(target_chars: int) -> str:
    """Return exactly ``target_chars`` characters of varied filler text.

    Sentences for indices 0, 1, 2, ... are collected until their combined
    length (counting one separator per sentence) reaches ``target_chars``;
    the space-joined result is then truncated to the requested length.
    The text is intended to tokenize at ~3.3 chars/token (realistic for Claude).
    """
    collected: list[str] = []
    running_length = 0
    while running_length < target_chars:
        # The next sentence index always equals the number collected so far.
        sentence = _generate_sentence(len(collected))
        collected.append(sentence)
        running_length += len(sentence) + 1
    joined = " ".join(collected)
    return joined[:target_chars]
def create_model() -> ChatAnthropicVertex:
    """Construct the ChatAnthropicVertex client used by every test.

    The model name, project and location come from the module constants.
    Output is capped at 256 tokens, and the attached profile declares
    ``max_input_tokens`` as ``MAX_INPUT_TOKENS``.
    """
    model_profile = {"max_input_tokens": MAX_INPUT_TOKENS}
    model = ChatAnthropicVertex(
        model_name=MODEL_NAME,
        project=PROJECT,
        location=LOCATION,
        max_output_tokens=256,
        profile=model_profile,
    )
    return model
def build_gradual_messages(model: ChatAnthropicVertex) -> list[AnyMessage]:
    """Build multi-turn conversation just below the middleware's trigger threshold.

    Human/AI message pairs are appended until the middleware's own approximate
    token counter reports at least ``TRIGGER_THRESHOLD - 2000`` tokens.

    Args:
        model: Model whose approximate token counter decides when to stop.

    Returns:
        The alternating Human/AI message list, always containing at least
        one turn.
    """
    token_counter = _summarization_module._get_approximate_token_counter(model)
    chars_per_msg = 1200
    # _generate_varied_text is deterministic, so calling it inside the loop
    # produced the same string twice per turn; generate the filler once.
    filler = _generate_varied_text(chars_per_msg)
    # Stop with a 2000-token margin so the estimate stays under the trigger.
    stop_at = TRIGGER_THRESHOLD - 2000

    messages: list[AnyMessage] = []
    turn = 0
    while True:
        messages.append(HumanMessage(content=f"Turn {turn} question: {filler}"))
        messages.append(AIMessage(content=f"Turn {turn} response: {filler}"))
        if token_counter(messages) >= stop_at:
            break
        turn += 1
    return messages
def build_single_large_message(model: ChatAnthropicVertex) -> list[AnyMessage]:
    """Build a single message just below the middleware's trigger threshold.

    Bisects over the character length (between 500,000 and 800,000) until the
    search window is at most 1000 characters wide, keeping the lower bound at
    a length whose approximate token count is under ``TRIGGER_THRESHOLD - 1000``.
    """
    token_counter = _summarization_module._get_approximate_token_counter(model)
    target_tokens = TRIGGER_THRESHOLD - 1000

    lower, upper = 500_000, 800_000
    while upper - lower > 1000:
        midpoint = (lower + upper) // 2
        probe = [HumanMessage(content=_generate_varied_text(midpoint))]
        if token_counter(probe) >= target_tokens:
            # Too large: shrink the window from above.
            upper = midpoint
        else:
            lower = midpoint

    return [HumanMessage(content=_generate_varied_text(lower))]
def analyze_token_counts(
    model: ChatAnthropicVertex,
    messages: list[AnyMessage],
) -> dict[str, Any]:
    """Compare the middleware's token estimate vs. the correct estimate.

    Args:
        model: Model passed to the middleware and to its token-counter factory.
        messages: Conversation to measure.

    Returns:
        A dict with the keys ``middleware_count``, ``correct_count``,
        ``trigger_threshold``, ``would_trigger``, ``underestimation_pct`` and
        ``message_count``. ``underestimation_pct`` is 0.0 when the reference
        count is zero (e.g. for an empty message list).
    """
    # What the middleware uses (4.0 chars/token — wrong for Claude via Vertex)
    middleware_counter = _summarization_module._get_approximate_token_counter(model)
    middleware_count = middleware_counter(messages)

    # What it SHOULD use (3.3 chars/token — correct for Claude)
    correct_counter = partial(
        count_tokens_approximately,
        use_usage_metadata_scaling=True,
        chars_per_token=3.3,
    )
    correct_count = correct_counter(messages)

    # Would the middleware trigger summarization?
    middleware = LCSummarizationMiddleware(
        model=model,
        trigger=("fraction", TRIGGER_FRACTION),
        keep=("fraction", 0.10),
    )
    would_trigger = middleware._should_summarize(messages, middleware_count)

    # Guard the ratio: a zero reference count would raise ZeroDivisionError.
    if correct_count:
        underestimation_pct = round((1 - middleware_count / correct_count) * 100, 1)
    else:
        underestimation_pct = 0.0

    return {
        "middleware_count": middleware_count,
        "correct_count": correct_count,
        "trigger_threshold": TRIGGER_THRESHOLD,
        "would_trigger": would_trigger,
        "underestimation_pct": underestimation_pct,
        "message_count": len(messages),
    }
def attempt_api_call(model: ChatAnthropicVertex, messages: list[AnyMessage]) -> dict[str, Any]:
    """Make an actual API call to demonstrate the rejection.

    Returns a dict with two keys:
      - ``error``: ``None`` on success, otherwise the exception text.
      - ``actual_tokens``: on success, the ``input_tokens`` entry of the
        response's usage metadata (``None`` if there is no metadata); on a
        "prompt is too long" failure, the first ``<digits> tokens`` figure
        found in the error text, if any; otherwise ``None``.
    """
    try:
        reply = model.invoke(messages)
        usage = reply.usage_metadata
        input_tokens = usage.get("input_tokens") if usage else None
        return {"error": None, "actual_tokens": input_tokens}
    except Exception as exc:
        # Deliberately broad: any failure is reported to the caller as text.
        message = str(exc)
        token_match = (
            re.search(r"(\d+) tokens", message)
            if "prompt is too long" in message
            else None
        )
        reported = int(token_match.group(1)) if token_match else None
        return {"error": message, "actual_tokens": reported}
def run_test(
    name: str,
    messages: list[AnyMessage],
    model: ChatAnthropicVertex,
) -> bool:
    """Run a single test case. Returns True if the bug was reproduced.

    The bug counts as reproduced only when the API rejects the prompt as too
    long AND the middleware would not have triggered summarization for the
    same messages. A rejection of input the middleware would have summarized
    does not demonstrate the underestimation.

    Args:
        name: Human-readable test name used in the log output.
        messages: Conversation to analyze and send to the model.
        model: Model used for both the analysis and the API call.

    Returns:
        True when the bug was reproduced, False otherwise (unexpected error,
        successful call, or a rejection the middleware would have prevented).
    """
    banner = "=" * 70
    logger.info(banner)
    logger.info("TEST: %s", name)
    logger.info(banner)

    analysis = analyze_token_counts(model, messages)
    logger.info("")
    logger.info("Token count analysis:")
    logger.info(" Messages: %d", analysis["message_count"])
    logger.info(" Middleware estimate: %d tokens (4.0 chars/token)", analysis["middleware_count"])
    logger.info(" Correct estimate: %d tokens (3.3 chars/token)", analysis["correct_count"])
    logger.info(
        " Trigger threshold: %d tokens (%.0f%% of %dk)",
        TRIGGER_THRESHOLD,
        TRIGGER_FRACTION * 100,
        MAX_INPUT_TOKENS // 1000,
    )
    logger.info(" Underestimation: %.1f%%", analysis["underestimation_pct"])
    logger.info(" Would trigger: %s", analysis["would_trigger"])
    logger.info("")
    logger.info("Root cause:")
    logger.info(" model._llm_type: %r", model._llm_type)
    logger.info(" Expected by LangChain: %r", "anthropic-chat")
    logger.info(" Match: %s", model._llm_type == "anthropic-chat")
    logger.info("")
    logger.info("Making API call...")
    result = attempt_api_call(model, messages)

    error = result["error"]
    if not error:
        logger.info(" API call succeeded (bug NOT reproduced)")
        logger.info(" Actual input tokens: %s", result["actual_tokens"])
        return False

    if "prompt is too long" not in error:
        logger.info(" Unexpected error: %s", error[:200])
        return False

    logger.info(" API rejected: prompt is too long")
    logger.info(" Actual token count (API): %s", result["actual_tokens"])
    if result["actual_tokens"] and analysis["middleware_count"]:
        real_underestimation = round(
            (1 - analysis["middleware_count"] / result["actual_tokens"]) * 100, 1
        )
        logger.info(" Real underestimation: %.1f%%", real_underestimation)
    logger.info("")

    if analysis["would_trigger"]:
        # The middleware would have summarized before the call, so this
        # rejection says nothing about the token underestimation.
        logger.info(" Bug NOT reproduced: the middleware would have triggered summarization")
        logger.info(" for these messages, so the rejection does not show an underestimate.")
        return False

    logger.info(" BUG REPRODUCED: Middleware estimated %d tokens (below %d threshold),",
                analysis["middleware_count"], TRIGGER_THRESHOLD)
    logger.info(" but actual prompt was %s tokens (above %dk limit).",
                result["actual_tokens"], MAX_INPUT_TOKENS // 1000)
    return True
def main() -> int:
    """Parse the CLI, run the selected test case(s) and report the verdict.

    Returns the process exit code: 0 when at least one test reproduced the
    bug, 1 when none did.
    """
    parser = argparse.ArgumentParser(
        description="Reproduce token counting underestimation with ChatAnthropicVertex + SummarizationMiddleware",
    )
    parser.add_argument(
        "--test",
        choices=["gradual", "single", "both"],
        default="both",
        help="Which test to run (default: both)",
    )
    args = parser.parse_args()

    model = create_model()
    logger.info("Reproduction: ChatAnthropicVertex token counting underestimation")
    logger.info("Model: %s | Project: %s | Location: %s", MODEL_NAME, PROJECT, LOCATION)
    logger.info(
        "Context window: %dk | Trigger: %.0f%% (%dk)",
        MAX_INPUT_TOKENS // 1000,
        TRIGGER_FRACTION * 100,
        TRIGGER_THRESHOLD // 1000,
    )
    logger.info("")

    # argparse restricts --test to gradual/single/both, so "not single" means
    # the gradual test is wanted and "not gradual" means the single test is.
    selected = []
    if args.test != "single":
        selected.append(("Gradual conversation growth", build_gradual_messages))
    if args.test != "gradual":
        selected.append(("Single large message", build_single_large_message))

    # Every selected test runs; outcomes are combined only afterwards.
    outcomes = []
    for name, builder in selected:
        logger.info("Building messages for: %s ...", name)
        conversation = builder(model)
        outcomes.append(run_test(name, conversation, model))
        logger.info("")

    logger.info("=" * 70)
    if not any(outcomes):
        logger.info("RESULT: Bug NOT reproduced.")
        return 1

    logger.info("RESULT: Bug reproduced. SummarizationMiddleware fails to trigger")
    logger.info(" for ChatAnthropicVertex due to _llm_type mismatch.")
    logger.info("")
    logger.info("The fix: _get_approximate_token_counter should check")
    logger.info(" model._llm_type.startswith('anthropic-chat')")
    logger.info("instead of")
    logger.info(" model._llm_type == 'anthropic-chat'")
    return 0


if __name__ == "__main__":
    sys.exit(main())
Author
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Output