| """ | |
| Mixture-of-Experts Router with LangGraph | |
| An intelligent routing system that selects the cheapest model tier capable of | |
| handling each request, with automatic escalation when quality checks fail. | |
| Features: | |
| - LLM-based routing with confidence scoring | |
| - Graduated escalation ladder (small → reasoning → large) | |
| - LLM judge for quality gating | |
| - Retry with exponential backoff | |
| - Comprehensive cost tracking across all calls | |
| - Configurable via environment variables | |
| Environment Variables: | |
| OPENAI_BASE_URL: Base URL for OpenAI-compatible API | |
| OPENAI_API_KEY: API key | |
| MOE_MAX_ESCALATIONS: Max escalation attempts (default: 1) | |
| MOE_JUDGE_ENABLED: Enable/disable judge (default: true) | |
| MOE_TIMEOUT_SECONDS: Timeout per API call (default: 30) | |
| MOE_DEBUG: Enable debug logging (default: false) | |
| """ | |
import os
import sys
import time
import json
import random
import logging
from dataclasses import dataclass, field
from typing import TypedDict, List, Dict, Any, Optional, Literal, Tuple
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FuturesTimeout

from openai import OpenAI
from langgraph.graph import StateGraph, END

# ----------------------------
# Logging Setup
# ----------------------------
logging.basicConfig(
    level=logging.DEBUG if os.environ.get("MOE_DEBUG", "").lower() == "true" else logging.INFO,
    format="%(asctime)s | %(levelname)-8s | %(message)s",
    datefmt="%H:%M:%S",
)
logger = logging.getLogger("moe_router")
# ----------------------------
# Configuration
# ----------------------------
@dataclass
class MoEConfig:
    """Configuration for the MoE router, loaded from environment."""
    base_url: str = ""
    api_key: str = ""
    max_escalations: int = 1
    judge_enabled: bool = True
    timeout_seconds: float = 30.0
    max_retries: int = 3
    debug: bool = False

    @classmethod
    def from_env(cls) -> "MoEConfig":
        base_url = os.environ.get("OPENAI_BASE_URL", "")
        api_key = os.environ.get("OPENAI_API_KEY", "")
        if not base_url or not api_key:
            raise RuntimeError(
                "Missing required environment variables.\n"
                "Set OPENAI_BASE_URL and OPENAI_API_KEY."
            )
        return cls(
            base_url=base_url,
            api_key=api_key,
            max_escalations=int(os.environ.get("MOE_MAX_ESCALATIONS", "1")),
            judge_enabled=os.environ.get("MOE_JUDGE_ENABLED", "true").lower() == "true",
            timeout_seconds=float(os.environ.get("MOE_TIMEOUT_SECONDS", "30")),
            max_retries=int(os.environ.get("MOE_MAX_RETRIES", "3")),
            debug=os.environ.get("MOE_DEBUG", "").lower() == "true",
        )


# Load config and create client
config = MoEConfig.from_env()
client = OpenAI(base_url=config.base_url, api_key=config.api_key)
# ----------------------------
# Model Registry
# ----------------------------
Tier = Literal["small", "reasoning", "code", "large"]

MODEL_REGISTRY: Dict[str, Dict[str, Any]] = {
    "router": {
        "id": "phi35-mini-instruct",
        "max_output_tokens": 250,
        "temperature": 0.0,
        "relative_cost": 0.2,
    },
    "small": {
        "id": "phi35-mini-instruct",
        "max_output_tokens": 700,
        "temperature": 0.2,
        "relative_cost": 1.0,
    },
    "reasoning": {
        "id": "llama31-8b-instruct",
        "max_output_tokens": 900,
        "temperature": 0.2,
        "relative_cost": 2.5,
    },
    "code": {
        "id": "qwen25-coder-7b-instruct",
        "max_output_tokens": 1200,
        "temperature": 0.1,
        "relative_cost": 2.5,
    },
    "large": {
        "id": "qwen25-32b-instruct",
        "max_output_tokens": 1400,
        "temperature": 0.2,
        "relative_cost": 6.0,
    },
}

# Escalation paths: current tier → next tier
ESCALATION_MAP: Dict[Tier, Tier] = {
    "small": "reasoning",   # Try reasoning before jumping to large
    "code": "large",        # Code failures usually need more capability
    "reasoning": "large",   # Reasoning failures go to large
    "large": "large",       # Already at max
}
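# For example, with the map above a rejected "small" answer is retried on the
# "reasoning" tier, and a rejected "code" or "reasoning" answer is retried on
# "large"; "large" maps to itself, and the judge skips answers already produced
# at that tier, so there is nowhere further to escalate.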
# ----------------------------
# Prompts
# ----------------------------
ROUTER_SYSTEM_PROMPT = """You are a routing classifier for a multi-model system.
Pick the cheapest model tier that will likely succeed.

Tiers (ordered by cost, cheapest first):
- small: Simple tasks with clear answers
  Examples: summarization, text rewriting, extraction, factual Q&A, formatting, translation, simple drafting
- code: Programming-centric tasks
  Examples: debugging, writing/refactoring code, explaining code, SQL queries, regex, writing tests, stack traces
- reasoning: Tasks requiring multi-step thinking
  Examples: math word problems, planning, comparing tradeoffs, logic puzzles, constraint satisfaction, complex analysis
- large: Only when simpler tiers will likely fail
  Examples: highly ambiguous requests, tasks requiring broad knowledge synthesis, long-form content with many constraints, high-stakes correctness

Decision rules:
1. Default to "small" unless the task clearly requires more
2. If code/programming is central to the task, use "code"
3. If explicit multi-step reasoning is required (but not code), use "reasoning"
4. Use "large" sparingly—only for genuinely complex or ambiguous tasks

Return ONLY valid JSON with exactly this schema:
{"tier": "small|code|reasoning|large", "confidence": 0.0-1.0, "reason": "brief explanation"}
No markdown. No extra text. No extra keys."""

JUDGE_SYSTEM_PROMPT = """You are a QA judge evaluating if an answer adequately addresses a user's request.

Escalate (needs_escalation: true) ONLY if:
- Answer is factually wrong or contains nonsense
- Answer ignores key constraints from the request
- Answer refuses or deflects without good justification
- Answer is substantially incomplete (missing major requested parts)
- Answer shows clear signs of confusion or hallucination

Do NOT escalate for:
- Minor stylistic or formatting issues
- Verbose but correct answers
- Reasonable interpretations of ambiguous requests
- Answers that are correct but could be slightly better

Return ONLY valid JSON: {"needs_escalation": true|false, "reason": "brief explanation"}
No markdown. No extra text."""
# ----------------------------
# State Definition
# ----------------------------
class MoEState(TypedDict, total=False):
    # Input
    user_text: str
    messages: List[Dict[str, str]]  # Optional: full conversation history

    # Preprocessing features
    features: Dict[str, bool]

    # Router output
    selected_tier: Tier
    route_confidence: float
    route_reason: str

    # Execution
    response_text: str
    model_used: str
    tier_used: Tier
    usage: Dict[str, Any]
    latency_ms: int

    # Escalation
    needs_escalation: bool
    escalation_reason: Optional[str]
    attempts: int

    # Telemetry
    estimated_cost_units: float
    total_tokens: int
    trace: List[Dict[str, Any]]

    # Error handling
    error: Optional[str]


def _log(state: MoEState, node: str, **kv) -> MoEState:
    """Append an event to the trace log."""
    state.setdefault("trace", [])
    event = {"node": node, "timestamp": time.time(), **kv}
    state["trace"].append(event)
    logger.debug(f"[{node}] {json.dumps(kv, default=str)}")
    return state
# ----------------------------
# API Call Helper with Retry & Timeout
# ----------------------------
def _call_openai_compatible(
    *,
    model: str,
    messages: List[Dict[str, str]],
    max_output_tokens: int,
    temperature: float,
    timeout_seconds: Optional[float] = None,
    max_retries: Optional[int] = None,
) -> Tuple[str, Dict[str, Any]]:
    """
    Call OpenAI-compatible API with retry logic and timeout.
    Tries Responses API first, falls back to Chat Completions.

    Returns: (output_text, usage_dict)
    """
    timeout = timeout_seconds or config.timeout_seconds
    retries = max_retries or config.max_retries
    last_error: Optional[Exception] = None

    def _do_call() -> Tuple[str, Dict[str, Any]]:
        # Try Responses API first (if available)
        try:
            resp = client.responses.create(
                model=model,
                input=messages,
                max_output_tokens=max_output_tokens,
                temperature=temperature,
            )
            text_out = getattr(resp, "output_text", None)
            if text_out is None:
                text_out = json.dumps(resp.model_dump(), ensure_ascii=False)
            usage = getattr(resp, "usage", None)
            return text_out, dict(usage) if usage else {}
        except AttributeError:
            pass  # Responses API not available

        # Fall back to Chat Completions
        chat = client.chat.completions.create(
            model=model,
            messages=messages,
            max_tokens=max_output_tokens,
            temperature=temperature,
        )
        text_out = chat.choices[0].message.content or ""
        usage = chat.usage
        return text_out, dict(usage) if usage else {}

    for attempt in range(retries):
        try:
            # Execute with timeout
            with ThreadPoolExecutor(max_workers=1) as executor:
                future = executor.submit(_do_call)
                try:
                    return future.result(timeout=timeout)
                except FuturesTimeout:
                    raise TimeoutError(f"API call timed out after {timeout}s")
        except Exception as e:
            last_error = e
            if attempt < retries - 1:
                sleep_time = (2 ** attempt) + random.uniform(0, 1)
                logger.warning(f"Attempt {attempt + 1} failed: {e}. Retrying in {sleep_time:.1f}s...")
                time.sleep(sleep_time)
            else:
                logger.error(f"All {retries} attempts failed. Last error: {e}")

    raise last_error or RuntimeError("Unknown error in API call")
def _extract_json(text: str) -> Optional[Dict[str, Any]]:
    """
    Robustly extract JSON from LLM output.
    Handles pure JSON, JSON in markdown blocks, and JSON embedded in text.
    """
    text = text.strip()

    # Try direct parse first
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass

    # Try to extract from markdown code block
    if "```" in text:
        # Find content between ``` markers
        parts = text.split("```")
        for part in parts[1::2]:  # Odd indices are inside code blocks
            # Remove optional language identifier
            lines = part.strip().split("\n")
            if lines[0].lower() in ("json", ""):
                content = "\n".join(lines[1:]) if len(lines) > 1 else lines[0]
            else:
                content = part
            try:
                return json.loads(content.strip())
            except json.JSONDecodeError:
                continue

    # Try to find JSON object in text
    start = text.find("{")
    end = text.rfind("}")
    if start != -1 and end != -1 and end > start:
        snippet = text[start:end + 1]
        try:
            return json.loads(snippet)
        except json.JSONDecodeError:
            pass

    return None
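# Illustrative inputs _extract_json is meant to handle (a sketch, not a test suite):
#   '{"tier": "small", "confidence": 0.9, "reason": "simple"}'    -> parsed directly
#   '```json\n{"needs_escalation": false, "reason": "ok"}\n```'   -> pulled from the code fence
#   'Sure! {"tier": "code", "confidence": 0.8, "reason": "bug"}'  -> recovered by the brace scan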
# ----------------------------
# Node: Preprocess (feature extraction)
# ----------------------------
def preprocess(state: MoEState) -> MoEState:
    """Extract features from the input to assist routing."""
    text = state["user_text"]
    text_lower = text.lower()

    features = {
        # Code indicators
        "has_code_block": "```" in text,
        "has_code_keywords": any(kw in text_lower for kw in [
            "def ", "function ", "class ", "import ", "const ", "let ", "var ",
            "error:", "traceback", "exception", "stack trace"
        ]),
        "mentions_programming": any(kw in text_lower for kw in [
            "code", "debug", "bug", "fix", "refactor", "sql", "regex", "script",
            "python", "javascript", "java", "c++", "rust", "golang"
        ]),
        # Reasoning indicators
        "has_math_markers": any(kw in text_lower for kw in [
            "calculate", "compute", "solve", "equation", "prove", "derive",
            "∑", "∫", "∂", "√"
        ]),
        "asks_for_analysis": any(kw in text_lower for kw in [
            "analyze", "compare", "contrast", "evaluate", "assess", "pros and cons",
            "tradeoff", "trade-off"
        ]),
        "multi_step_indicators": any(kw in text_lower for kw in [
            "step by step", "first", "then", "finally", "plan", "strategy"
        ]),
        # Complexity indicators
        "is_short": len(text) < 100,
        "is_long": len(text) > 500,
        "has_multiple_questions": text.count("?") > 1,
        # Simple task indicators
        "asks_for_summary": any(kw in text_lower for kw in [
            "summarize", "summary", "tldr", "tl;dr", "brief", "shorten"
        ]),
        "asks_for_rewrite": any(kw in text_lower for kw in [
            "rewrite", "rephrase", "paraphrase", "reword"
        ]),
        "asks_for_translation": any(kw in text_lower for kw in [
            "translate", "translation", "in spanish", "in french", "in german",
            "to english", "to spanish"
        ]),
    }

    state["features"] = features
    return _log(state, "preprocess", **features)
# ----------------------------
# Node: LLM Router
# ----------------------------
def llm_route(state: MoEState) -> MoEState:
    """Use LLM to classify the request and select appropriate tier."""
    router_cfg = MODEL_REGISTRY["router"]
    router_model = router_cfg["id"]
    user_text = state["user_text"]
    features = state.get("features", {})

    # Build context hint from features
    hints = []
    if features.get("has_code_block") or features.get("has_code_keywords"):
        hints.append("Input contains code or error traces.")
    if features.get("has_math_markers"):
        hints.append("Input contains mathematical notation.")
    if features.get("is_short") and features.get("asks_for_summary"):
        hints.append("This appears to be a simple summarization request.")
    hint_text = " ".join(hints) if hints else ""

    messages = [
        {"role": "system", "content": ROUTER_SYSTEM_PROMPT},
        {"role": "user", "content": f"{hint_text}\n\nUser request:\n{user_text}" if hint_text else user_text},
    ]

    t0 = time.perf_counter()
    try:
        out_text, usage = _call_openai_compatible(
            model=router_model,
            messages=messages,
            max_output_tokens=router_cfg["max_output_tokens"],
            temperature=router_cfg["temperature"],
        )
        latency_ms = int((time.perf_counter() - t0) * 1000)
    except Exception as e:
        # Router failed - default to reasoning tier as safe fallback
        logger.error(f"Router call failed: {e}")
        state["selected_tier"] = "reasoning"
        state["route_confidence"] = 0.0
        state["route_reason"] = f"Router failed ({e}); defaulting to reasoning"
        return _log(state, "llm_route", error=str(e), fallback="reasoning")

    # Parse response
    parsed = _extract_json(out_text) or {}
    tier = parsed.get("tier")
    confidence = parsed.get("confidence")
    reason = parsed.get("reason", "")

    # Validate tier
    valid_tiers = ("small", "code", "reasoning", "large")
    if tier not in valid_tiers:
        invalid_tier = tier  # Remember the raw value before overwriting it
        # Use features to make an educated guess
        if features.get("has_code_block") or features.get("mentions_programming"):
            tier = "code"
            reason = f"Router output invalid ({invalid_tier!r}); inferred 'code' from features"
        elif features.get("has_math_markers") or features.get("asks_for_analysis"):
            tier = "reasoning"
            reason = f"Router output invalid ({invalid_tier!r}); inferred 'reasoning' from features"
        else:
            tier = "reasoning"  # Safe default
            reason = f"Router output invalid ({invalid_tier!r}); defaulting to 'reasoning'"
        confidence = 0.5

    # Validate confidence
    try:
        conf_f = float(confidence)
        conf_f = max(0.0, min(1.0, conf_f))  # Clamp to [0, 1]
    except (TypeError, ValueError):
        conf_f = 0.5

    state["selected_tier"] = tier  # type: ignore[assignment]
    state["route_confidence"] = conf_f
    state["route_reason"] = str(reason)

    return _log(
        state,
        "llm_route",
        router_model=router_model,
        latency_ms=latency_ms,
        usage=usage,
        selected_tier=tier,
        confidence=conf_f,
        reason=reason,
        raw_output=out_text[:200],
    )
# ----------------------------
# Node: Call Expert Model
# ----------------------------
def call_expert(state: MoEState) -> MoEState:
    """Execute the selected expert model."""
    tier: Tier = state["selected_tier"]
    cfg = MODEL_REGISTRY[tier]
    model_id = cfg["id"]

    # Use provided messages or construct from user_text
    messages = state.get("messages")
    if not messages:
        messages = [{"role": "user", "content": state["user_text"]}]

    t0 = time.perf_counter()
    try:
        out_text, usage = _call_openai_compatible(
            model=model_id,
            messages=messages,
            max_output_tokens=cfg["max_output_tokens"],
            temperature=cfg["temperature"],
        )
        latency_ms = int((time.perf_counter() - t0) * 1000)
    except Exception as e:
        logger.error(f"Expert call failed: {e}")
        state["response_text"] = ""
        state["error"] = str(e)
        state["attempts"] = state.get("attempts", 0) + 1
        return _log(state, "call_expert", error=str(e), tier=tier, model=model_id)

    state["response_text"] = out_text
    state["model_used"] = model_id
    state["tier_used"] = tier
    state["usage"] = usage
    state["latency_ms"] = latency_ms
    state["attempts"] = state.get("attempts", 0) + 1
    state["error"] = None

    return _log(
        state,
        "call_expert",
        tier=tier,
        model=model_id,
        latency_ms=latency_ms,
        usage=usage,
        response_length=len(out_text),
    )
# ----------------------------
# Node: LLM Judge
# ----------------------------
def llm_judge(state: MoEState) -> MoEState:
    """Evaluate response quality and decide whether to escalate."""
    # Skip judging if disabled
    if not config.judge_enabled:
        state["needs_escalation"] = False
        state["escalation_reason"] = "judge disabled"
        return _log(state, "llm_judge", skipped=True, reason="judge disabled")

    # Skip if already at max tier - no point escalating
    if state.get("tier_used") == "large":
        state["needs_escalation"] = False
        state["escalation_reason"] = "already at max tier"
        return _log(state, "llm_judge", skipped=True, reason="already at max tier")

    # Skip if there was an error - should escalate
    if state.get("error"):
        state["needs_escalation"] = True
        state["escalation_reason"] = f"previous call failed: {state['error']}"
        return _log(state, "llm_judge", needs_escalation=True, reason="error recovery")

    judge_cfg = MODEL_REGISTRY["router"]
    user_text = state["user_text"]
    answer = state.get("response_text", "")
    tier_used = state.get("tier_used", "small")
    confidence = state.get("route_confidence", 0.5)

    # Truncate long answers to save tokens
    answer_truncated = answer[:2000] + ("..." if len(answer) > 2000 else "")

    messages = [
        {"role": "system", "content": JUDGE_SYSTEM_PROMPT},
        {
            "role": "user",
            "content": (
                f"User request:\n{user_text}\n\n"
                f"Model tier used: {tier_used}\n"
                f"Router confidence: {confidence:.2f}\n\n"
                f"Answer:\n{answer_truncated}"
            ),
        },
    ]

    try:
        out_text, usage = _call_openai_compatible(
            model=judge_cfg["id"],
            messages=messages,
            max_output_tokens=150,
            temperature=0.0,
        )
    except Exception as e:
        # Judge failed - use conservative heuristic
        logger.warning(f"Judge call failed: {e}")
        needs = tier_used == "small" and confidence < 0.6
        state["needs_escalation"] = needs
        state["escalation_reason"] = f"judge failed ({e}); fallback heuristic"
        return _log(state, "llm_judge", error=str(e), needs_escalation=needs)

    parsed = _extract_json(out_text) or {}
    needs = parsed.get("needs_escalation")
    reason = parsed.get("reason", "")

    # Handle malformed output
    if not isinstance(needs, bool):
        # Conservative fallback: escalate cheap tiers with low confidence
        needs = (tier_used == "small" and confidence < 0.6)
        reason = f"judge parse failed; fallback policy (tier={tier_used}, conf={confidence:.2f})"
        logger.warning(f"Judge output malformed: {out_text[:100]}")

    state["needs_escalation"] = needs
    state["escalation_reason"] = reason

    return _log(
        state,
        "llm_judge",
        usage=usage,
        needs_escalation=needs,
        reason=reason,
        raw_output=out_text[:200],
    )
# ----------------------------
# Node: Escalate
# ----------------------------
def escalate(state: MoEState) -> MoEState:
    """Move to a more capable model tier."""
    current_tier: Tier = state.get("tier_used", state["selected_tier"])
    new_tier = ESCALATION_MAP[current_tier]
    state["selected_tier"] = new_tier
    return _log(
        state,
        "escalate",
        from_tier=current_tier,
        to_tier=new_tier,
        reason=state.get("escalation_reason"),
    )


def should_escalate(state: MoEState) -> str:
    """Conditional edge: decide whether to escalate or finalize."""
    attempts = state.get("attempts", 0)
    needs_escalation = state.get("needs_escalation", False)
    # Allow escalation only if under the limit
    if needs_escalation and attempts <= config.max_escalations:
        return "escalate"
    return "done"
# ----------------------------
# Node: Finalize
# ----------------------------
def finalize(state: MoEState) -> MoEState:
    """Calculate final cost and prepare output."""
    total_cost = 0.0
    total_tokens = 0

    # Aggregate costs from all trace events
    for event in state.get("trace", []):
        usage = event.get("usage", {})
        if not usage:
            continue
        tokens_in = usage.get("input_tokens") or usage.get("prompt_tokens") or 0
        tokens_out = usage.get("output_tokens") or usage.get("completion_tokens") or 0
        tokens = tokens_in + tokens_out
        total_tokens += tokens

        # Determine cost weight based on the node type
        node = event.get("node", "")
        if node in ("llm_route", "llm_judge"):
            cost_weight = MODEL_REGISTRY["router"]["relative_cost"]
        elif node == "call_expert":
            tier = event.get("tier", "small")
            cost_weight = MODEL_REGISTRY.get(tier, {}).get("relative_cost", 1.0)
        else:
            cost_weight = 1.0

        total_cost += cost_weight * tokens

    state["estimated_cost_units"] = total_cost
    state["total_tokens"] = total_tokens

    return _log(
        state,
        "finalize",
        estimated_cost_units=total_cost,
        total_tokens=total_tokens,
        final_tier=state.get("tier_used"),
        final_model=state.get("model_used"),
        attempts=state.get("attempts"),
        route_confidence=state.get("route_confidence"),
    )
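# Cost units are relative weights, not dollars. As a worked example: a "small"
# expert call that used 200 input + 300 output tokens contributes 1.0 * 500 = 500
# units, while a router call using 150 tokens contributes 0.2 * 150 = 30 units.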
# ----------------------------
# Build LangGraph Application
# ----------------------------
def build_app():
    """Construct and compile the LangGraph workflow."""
    graph = StateGraph(MoEState)

    # Add nodes
    graph.add_node("preprocess", preprocess)
    graph.add_node("llm_route", llm_route)
    graph.add_node("call_expert", call_expert)
    graph.add_node("llm_judge", llm_judge)
    graph.add_node("escalate", escalate)
    graph.add_node("finalize", finalize)

    # Define edges
    graph.set_entry_point("preprocess")
    graph.add_edge("preprocess", "llm_route")
    graph.add_edge("llm_route", "call_expert")
    graph.add_edge("call_expert", "llm_judge")
    graph.add_conditional_edges(
        "llm_judge",
        should_escalate,
        {"escalate": "escalate", "done": "finalize"},
    )
    graph.add_edge("escalate", "call_expert")
    graph.add_edge("finalize", END)

    return graph.compile()
# ----------------------------
# Convenience Functions
# ----------------------------
def run_query(user_text: str, messages: Optional[List[Dict[str, str]]] = None) -> MoEState:
    """
    Run a single query through the MoE router.

    Args:
        user_text: The user's input/question
        messages: Optional full conversation history (OpenAI format)

    Returns:
        Final state with response and telemetry
    """
    app = build_app()
    initial_state: MoEState = {
        "user_text": user_text,
        "attempts": 0,
        "trace": [],
    }
    if messages:
        initial_state["messages"] = messages
    return app.invoke(initial_state)
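# Conversation-history sketch (only uses what run_query already accepts; the
# message contents below are made up for illustration):
#
#   history = [
#       {"role": "user", "content": "We use PostgreSQL 15."},
#       {"role": "user", "content": "Write a query for the top 10 customers by revenue."},
#   ]
#   state = run_query(history[-1]["content"], messages=history)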
def print_result(state: MoEState, verbose: bool = False) -> None:
    """Pretty-print the result of a query."""
    print("\n" + "=" * 60)
    print("ROUTING DECISION")
    print("=" * 60)
    print(f"Selected Tier: {state.get('selected_tier')}")
    print(f"Confidence: {state.get('route_confidence', 0):.2f}")
    print(f"Reason: {state.get('route_reason', 'N/A')}")

    print("\n" + "-" * 60)
    print("EXECUTION")
    print("-" * 60)
    print(f"Final Tier: {state.get('tier_used')}")
    print(f"Model Used: {state.get('model_used')}")
    print(f"Attempts: {state.get('attempts')}")
    print(f"Latency: {state.get('latency_ms', 0)} ms")
    if state.get("needs_escalation") is not None:
        print(f"Escalated: {state.get('attempts', 1) > 1}")
    if state.get("escalation_reason"):
        print(f"Escalation: {state.get('escalation_reason')}")

    print("\n" + "-" * 60)
    print("COST ESTIMATE")
    print("-" * 60)
    print(f"Total Tokens: {state.get('total_tokens', 0)}")
    print(f"Cost Units: {state.get('estimated_cost_units', 0):.2f}")

    print("\n" + "-" * 60)
    print("RESPONSE")
    print("-" * 60)
    response = state.get("response_text", "")
    if len(response) > 1000 and not verbose:
        print(response[:1000] + "\n... [truncated]")
    else:
        print(response)

    if verbose:
        print("\n" + "-" * 60)
        print("TRACE")
        print("-" * 60)
        print(json.dumps(state.get("trace", []), indent=2, default=str))

    print("=" * 60 + "\n")
# ----------------------------
# CLI Entry Point
# ----------------------------
def main():
    """Command-line interface for testing the MoE router."""
    import argparse

    parser = argparse.ArgumentParser(
        description="MoE Router - Intelligent multi-model routing",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python moe_router.py "Summarize the benefits of exercise"
  python moe_router.py "Debug this Python traceback: ..." --verbose
  python moe_router.py --interactive
        """,
    )
    parser.add_argument(
        "query",
        nargs="?",
        help="Query to process (omit for interactive mode)",
    )
    parser.add_argument(
        "-v", "--verbose",
        action="store_true",
        help="Show full trace and response",
    )
    parser.add_argument(
        "-i", "--interactive",
        action="store_true",
        help="Run in interactive mode",
    )
    args = parser.parse_args()

    if args.interactive or args.query is None:
        print("MoE Router Interactive Mode")
        print("Type 'quit' or 'exit' to stop.\n")
        while True:
            try:
                query = input("You: ").strip()
            except (EOFError, KeyboardInterrupt):
                print("\nGoodbye!")
                break
            if not query:
                continue
            if query.lower() in ("quit", "exit", "q"):
                print("Goodbye!")
                break
            result = run_query(query)
            print_result(result, verbose=args.verbose)
    else:
        result = run_query(args.query)
        print_result(result, verbose=args.verbose)


if __name__ == "__main__":
    main()