Skip to content

Instantly share code, notes, and snippets.

@lululau
Last active February 11, 2025 01:52
Show Gist options
  • Save lululau/e11a1cd363e5becebbf83c25ac9f7244 to your computer and use it in GitHub Desktop.
ChatQwen — a Qwen client for LangChain that resolves JSON over-quoting of tool-call arguments.
from __future__ import annotations
from langchain_openai import ChatOpenAI
from langchain_openai.chat_models.base import _create_usage_metadata
import json
from typing import (
Any,
Dict,
Mapping,
Optional,
Union,
cast,
)
import openai
from langchain_core.messages import (
AIMessage,
BaseMessage,
ChatMessage,
FunctionMessage,
HumanMessage,
SystemMessage,
ToolMessage,
)
from langchain_core.output_parsers.openai_tools import (
make_invalid_tool_call,
)
from langchain_core.outputs import ChatGeneration, ChatResult
from json import JSONDecodeError
from langchain_core.exceptions import OutputParserException
from langchain_core.messages.tool import tool_call as create_tool_call
from langchain_core.utils.json import parse_partial_json
from pydantic import Field
class ChatQwen(ChatOpenAI):
    """A ``ChatOpenAI`` subclass tailored to Qwen models.

    Qwen sometimes returns tool-call arguments as JSON that has been
    serialized more than once ("over-quoted"): the ``arguments`` string
    decodes to *another* JSON string rather than to an object. This client
    overrides response parsing so such payloads are repeatedly decoded
    (up to ``max_json_unquote_recursion`` passes) until a non-string value
    is obtained.
    """

    # Upper bound on repeated json.loads passes when un-quoting tool args.
    max_json_unquote_recursion: int = Field(
        default=3, description="Tool Arguments JSON 最大去引用解析递归次数"
    )

    def _create_chat_result(
        self,
        response: Union[dict, openai.BaseModel],
        generation_info: Optional[Dict] = None,
    ) -> ChatResult:
        """Build a ``ChatResult`` from a raw completion response.

        Mirrors the upstream ``ChatOpenAI`` implementation, but routes
        message conversion through ``self._convert_dict_to_message`` so
        tool-call arguments go through the over-quoting-tolerant parser.

        Args:
            response: The raw response, either a dict or an openai model.
            generation_info: Optional base generation info to merge into.

        Returns:
            The parsed ``ChatResult``.

        Raises:
            ValueError: If the response carries an ``error`` payload.
        """
        generations = []
        response_dict = (
            response if isinstance(response, dict) else response.model_dump()
        )
        # Sometimes the AI Model calling will get error, we should raise it.
        # Otherwise, the next code 'choices.extend(response["choices"])'
        # will throw a "TypeError: 'NoneType' object is not iterable" error
        # to mask the true error. Because 'response["choices"]' is None.
        if response_dict.get("error"):
            raise ValueError(response_dict.get("error"))

        token_usage = response_dict.get("usage")
        for res in response_dict["choices"]:
            message = self._convert_dict_to_message(res["message"])
            if token_usage and isinstance(message, AIMessage):
                message.usage_metadata = _create_usage_metadata(token_usage)
            generation_info = generation_info or {}
            # Prefer the per-choice finish_reason; fall back to any value
            # already present in generation_info.
            generation_info["finish_reason"] = (
                res.get("finish_reason")
                if res.get("finish_reason") is not None
                else generation_info.get("finish_reason")
            )
            if "logprobs" in res:
                generation_info["logprobs"] = res["logprobs"]
            gen = ChatGeneration(message=message, generation_info=generation_info)
            generations.append(gen)
        llm_output = {
            "token_usage": token_usage,
            "model_name": response_dict.get("model", self.model_name),
            "system_fingerprint": response_dict.get("system_fingerprint", ""),
        }
        # Structured-output extras ("parsed"/"refusal") are only exposed on
        # the openai model object, not on plain dict responses.
        if isinstance(response, openai.BaseModel) and getattr(
            response, "choices", None
        ):
            message = response.choices[0].message  # type: ignore[attr-defined]
            if hasattr(message, "parsed"):
                generations[0].message.additional_kwargs["parsed"] = message.parsed
            if hasattr(message, "refusal"):
                generations[0].message.additional_kwargs["refusal"] = message.refusal
        return ChatResult(generations=generations, llm_output=llm_output)

    def _convert_dict_to_message(self, _dict: Mapping[str, Any]) -> BaseMessage:
        """Convert a dictionary to a LangChain message.

        Args:
            _dict: The dictionary.

        Returns:
            The LangChain message.
        """
        role = _dict.get("role")
        name = _dict.get("name")
        id_ = _dict.get("id")
        if role == "user":
            return HumanMessage(content=_dict.get("content", ""), id=id_, name=name)
        elif role == "assistant":
            # Fix for azure
            # Also OpenAI returns None for tool invocations
            content = _dict.get("content", "") or ""
            additional_kwargs: Dict = {}
            if function_call := _dict.get("function_call"):
                additional_kwargs["function_call"] = dict(function_call)
            tool_calls = []
            invalid_tool_calls = []
            if raw_tool_calls := _dict.get("tool_calls"):
                additional_kwargs["tool_calls"] = raw_tool_calls
                for raw_tool_call in raw_tool_calls:
                    try:
                        parsed_call = self.parse_tool_call(
                            raw_tool_call, return_id=True
                        )
                        # parse_tool_call returns None for entries without a
                        # "function" key; never append None to tool_calls.
                        if parsed_call is not None:
                            tool_calls.append(parsed_call)
                    except Exception as e:
                        invalid_tool_calls.append(
                            make_invalid_tool_call(raw_tool_call, str(e))
                        )
            if audio := _dict.get("audio"):
                additional_kwargs["audio"] = audio
            return AIMessage(
                content=content,
                additional_kwargs=additional_kwargs,
                name=name,
                id=id_,
                tool_calls=tool_calls,
                invalid_tool_calls=invalid_tool_calls,
            )
        elif role in ("system", "developer"):
            # "developer" is OpenAI's newer alias for system prompts; keep the
            # original role so it can be restored when serializing back.
            if role == "developer":
                additional_kwargs = {"__openai_role__": role}
            else:
                additional_kwargs = {}
            return SystemMessage(
                content=_dict.get("content", ""),
                name=name,
                id=id_,
                additional_kwargs=additional_kwargs,
            )
        elif role == "function":
            return FunctionMessage(
                content=_dict.get("content", ""),
                name=cast(str, _dict.get("name")),
                id=id_,
            )
        elif role == "tool":
            additional_kwargs = {}
            if "name" in _dict:
                additional_kwargs["name"] = _dict["name"]
            return ToolMessage(
                content=_dict.get("content", ""),
                tool_call_id=cast(str, _dict.get("tool_call_id")),
                additional_kwargs=additional_kwargs,
                name=name,
                id=id_,
            )
        else:
            return ChatMessage(content=_dict.get("content", ""), role=role, id=id_)  # type: ignore[arg-type]

    def parse_tool_call(
        self,
        raw_tool_call: dict[str, Any],
        *,
        partial: bool = False,
        strict: bool = False,
        return_id: bool = True,
    ) -> Optional[dict[str, Any]]:
        """Parse a single tool call.

        Args:
            raw_tool_call: The raw tool call to parse.
            partial: Whether to parse partial JSON. Default is False.
            strict: Whether to allow non-JSON-compliant strings.
                Default is False.
            return_id: Whether to return the tool call id. Default is True.

        Returns:
            The parsed tool call, or ``None`` if the entry has no
            ``"function"`` key (or partial parsing failed).

        Raises:
            OutputParserException: If the tool call is not valid JSON.
        """
        if "function" not in raw_tool_call:
            return None
        if partial:
            try:
                function_args = parse_partial_json(
                    raw_tool_call["function"]["arguments"], strict=strict
                )
            except (JSONDecodeError, TypeError):  # None args raise TypeError
                return None
        else:
            try:
                # Unlike upstream, tolerate arguments that were JSON-encoded
                # more than once (Qwen's "over-quoting").
                function_args = self.parse_over_quoted_json(
                    raw_tool_call["function"]["arguments"], strict=strict
                )
            except JSONDecodeError as e:
                msg = (
                    f"Function {raw_tool_call['function']['name']} arguments:\n\n"
                    f"{raw_tool_call['function']['arguments']}\n\nare not valid JSON. "
                    f"Received JSONDecodeError {e}"
                )
                raise OutputParserException(msg) from e
        parsed = {
            "name": raw_tool_call["function"]["name"] or "",
            "args": function_args or {},
        }
        if return_id:
            parsed["id"] = raw_tool_call.get("id")
            parsed = create_tool_call(**parsed)  # type: ignore
        return parsed

    def parse_over_quoted_json(self, text: str, strict: bool = False) -> Any:
        """Parse JSON that may have been serialized more than once.

        Repeatedly applies ``json.loads`` until the result is no longer a
        string, up to ``self.max_json_unquote_recursion`` passes.

        Args:
            text: The (possibly over-quoted) JSON text.
            strict: Passed to ``json.loads``; ``False`` permits control
                characters inside strings.

        Returns:
            The first non-string decoded value, or the last decoded string
            if a further decode pass fails (i.e. it was a plain string).

        Raises:
            JSONDecodeError: If the initial text is not valid JSON.
            OutputParserException: If every pass still yields a string
                after the recursion limit is reached.
        """
        decoded_text = text
        for i in range(self.max_json_unquote_recursion):
            try:
                decoded_text = json.loads(decoded_text, strict=strict)
                if not isinstance(decoded_text, str):
                    return decoded_text
            except JSONDecodeError:
                # A failure after at least one successful decode means the
                # previous pass already produced the final (string) value.
                if i > 0:
                    return decoded_text
                raise
        raise OutputParserException("Failed to parse over quoted JSON")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment