Last active
February 11, 2025 01:52
-
-
Save lululau/e11a1cd363e5becebbf83c25ac9f7244 to your computer and use it in GitHub Desktop.
ChatQwen — a Qwen client for LangChain that resolves JSON over-quoting of tool-call arguments
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations | |
from langchain_openai import ChatOpenAI | |
from langchain_openai.chat_models.base import _create_usage_metadata | |
import json | |
from typing import ( | |
Any, | |
Dict, | |
Mapping, | |
Optional, | |
Union, | |
cast, | |
) | |
import openai | |
from langchain_core.messages import ( | |
AIMessage, | |
BaseMessage, | |
ChatMessage, | |
FunctionMessage, | |
HumanMessage, | |
SystemMessage, | |
ToolMessage, | |
) | |
from langchain_core.output_parsers.openai_tools import ( | |
make_invalid_tool_call, | |
) | |
from langchain_core.outputs import ChatGeneration, ChatResult | |
from json import JSONDecodeError | |
from langchain_core.exceptions import OutputParserException | |
from langchain_core.messages.tool import tool_call as create_tool_call | |
from langchain_core.utils.json import parse_partial_json | |
from pydantic import Field | |
class ChatQwen(ChatOpenAI):
    """ChatOpenAI subclass tailored to Qwen models.

    Qwen sometimes returns tool-call ``arguments`` as JSON that has been
    string-encoded one or more extra times ("over-quoted", e.g.
    ``"\\"{\\\\\\"x\\\\\\": 1}\\""``). This subclass routes message conversion
    through its own tool-call parser, which repeatedly un-quotes the
    arguments (up to ``max_json_unquote_recursion`` passes) until a
    non-string JSON value is obtained.
    """

    # Maximum number of json.loads passes applied when un-quoting
    # over-quoted tool-argument JSON. (Field description translates to:
    # "max de-quoting recursion depth for tool-arguments JSON".)
    max_json_unquote_recursion: int = Field(default=3, description="Tool Arguments JSON 最大去引用解析递归次数")

    def _create_chat_result(
        self,
        response: Union[dict, openai.BaseModel],
        generation_info: Optional[Dict] = None,
    ) -> ChatResult:
        """Build a ``ChatResult`` from a raw chat-completion response.

        Mirrors the upstream langchain-openai implementation, except that
        each choice's message is converted via ``self._convert_dict_to_message``
        and therefore through the over-quoting-tolerant tool-call parser.

        Args:
            response: Raw completion, either a plain dict or an
                ``openai.BaseModel`` (dumped to a dict for processing).
            generation_info: Optional seed metadata merged into each
                generation's info.

        Returns:
            The assembled ``ChatResult``.

        Raises:
            ValueError: If the response carries an ``error`` payload.
        """
        generations = []
        response_dict = (
            response if isinstance(response, dict) else response.model_dump()
        )
        # Sometimes the AI Model calling will get error, we should raise it.
        # Otherwise, the next code 'choices.extend(response["choices"])'
        # will throw a "TypeError: 'NoneType' object is not iterable" error
        # to mask the true error. Because 'response["choices"]' is None.
        if response_dict.get("error"):
            raise ValueError(response_dict.get("error"))
        token_usage = response_dict.get("usage")
        for res in response_dict["choices"]:
            message = self._convert_dict_to_message(res["message"])
            if token_usage and isinstance(message, AIMessage):
                message.usage_metadata = _create_usage_metadata(token_usage)
            # NOTE(review): this rebinds generation_info to one dict that is
            # then shared and mutated across iterations when the response has
            # multiple choices (all ChatGenerations reference the same dict,
            # and a later choice without "logprobs" inherits the previous
            # one's) — confirm this matches upstream intent before changing.
            generation_info = generation_info or {}
            generation_info["finish_reason"] = (
                res.get("finish_reason")
                if res.get("finish_reason") is not None
                else generation_info.get("finish_reason")
            )
            if "logprobs" in res:
                generation_info["logprobs"] = res["logprobs"]
            gen = ChatGeneration(message=message, generation_info=generation_info)
            generations.append(gen)
        llm_output = {
            "token_usage": token_usage,
            "model_name": response_dict.get("model", self.model_name),
            "system_fingerprint": response_dict.get("system_fingerprint", ""),
        }
        # Structured-output extras ("parsed"/"refusal") only exist on the
        # typed openai.BaseModel response objects, never on plain dicts.
        if isinstance(response, openai.BaseModel) and getattr(
            response, "choices", None
        ):
            message = response.choices[0].message  # type: ignore[attr-defined]
            if hasattr(message, "parsed"):
                generations[0].message.additional_kwargs["parsed"] = message.parsed
            if hasattr(message, "refusal"):
                generations[0].message.additional_kwargs["refusal"] = message.refusal
        return ChatResult(generations=generations, llm_output=llm_output)

    def _convert_dict_to_message(self, _dict: Mapping[str, Any]) -> BaseMessage:
        """Convert a raw OpenAI-format message dict to a LangChain message.

        Dispatches on the ``role`` field; for ``assistant`` messages the
        tool calls are parsed with ``self.parse_tool_call`` so that
        over-quoted argument JSON is handled.

        Args:
            _dict: The raw message dictionary.

        Returns:
            The corresponding LangChain message object.
        """
        role = _dict.get("role")
        name = _dict.get("name")
        id_ = _dict.get("id")
        if role == "user":
            return HumanMessage(content=_dict.get("content", ""), id=id_, name=name)
        elif role == "assistant":
            # Fix for azure
            # Also OpenAI returns None for tool invocations
            content = _dict.get("content", "") or ""
            additional_kwargs: Dict = {}
            if function_call := _dict.get("function_call"):
                additional_kwargs["function_call"] = dict(function_call)
            tool_calls = []
            invalid_tool_calls = []
            if raw_tool_calls := _dict.get("tool_calls"):
                additional_kwargs["tool_calls"] = raw_tool_calls
                for raw_tool_call in raw_tool_calls:
                    try:
                        # Any parse failure (including un-recoverable
                        # over-quoting) is recorded as an invalid tool call
                        # rather than aborting message conversion.
                        tool_calls.append(self.parse_tool_call(raw_tool_call, return_id=True))
                    except Exception as e:
                        invalid_tool_calls.append(
                            make_invalid_tool_call(raw_tool_call, str(e))
                        )
            if audio := _dict.get("audio"):
                additional_kwargs["audio"] = audio
            return AIMessage(
                content=content,
                additional_kwargs=additional_kwargs,
                name=name,
                id=id_,
                tool_calls=tool_calls,
                invalid_tool_calls=invalid_tool_calls,
            )
        elif role in ("system", "developer"):
            # "developer" is OpenAI's newer system-role alias; remember the
            # original role so it can be round-tripped.
            if role == "developer":
                additional_kwargs = {"__openai_role__": role}
            else:
                additional_kwargs = {}
            return SystemMessage(
                content=_dict.get("content", ""),
                name=name,
                id=id_,
                additional_kwargs=additional_kwargs,
            )
        elif role == "function":
            return FunctionMessage(
                content=_dict.get("content", ""), name=cast(str, _dict.get("name")), id=id_
            )
        elif role == "tool":
            additional_kwargs = {}
            if "name" in _dict:
                additional_kwargs["name"] = _dict["name"]
            return ToolMessage(
                content=_dict.get("content", ""),
                tool_call_id=cast(str, _dict.get("tool_call_id")),
                additional_kwargs=additional_kwargs,
                name=name,
                id=id_,
            )
        else:
            # Unknown roles fall through to a generic ChatMessage.
            return ChatMessage(content=_dict.get("content", ""), role=role, id=id_)  # type: ignore[arg-type]

    def parse_tool_call(
        self,
        raw_tool_call: dict[str, Any],
        *,
        partial: bool = False,
        strict: bool = False,
        return_id: bool = True,
    ) -> Optional[dict[str, Any]]:
        """Parse a single tool call, tolerating over-quoted argument JSON.

        Args:
            raw_tool_call: The raw tool call to parse.
            partial: Whether to parse partial JSON. Default is False.
            strict: Whether to allow non-JSON-compliant strings.
                Default is False.
            return_id: Whether to return the tool call id. Default is True.
        Returns:
            The parsed tool call, or None when there is no ``function``
            payload (or, in partial mode, when the arguments do not parse).
        Raises:
            OutputParserException: If the tool call is not valid JSON.
        """
        if "function" not in raw_tool_call:
            return None
        if partial:
            try:
                function_args = parse_partial_json(
                    raw_tool_call["function"]["arguments"], strict=strict
                )
            except (JSONDecodeError, TypeError):  # None args raise TypeError
                return None
        else:
            try:
                # Key difference from the stock langchain parser: arguments
                # are run through the over-quoting-tolerant decoder.
                function_args = self.parse_over_quoted_json(raw_tool_call["function"]["arguments"], strict=strict)
            except JSONDecodeError as e:
                msg = (
                    f"Function {raw_tool_call['function']['name']} arguments:\n\n"
                    f"{raw_tool_call['function']['arguments']}\n\nare not valid JSON. "
                    f"Received JSONDecodeError {e}"
                )
                raise OutputParserException(msg) from e
        parsed = {
            "name": raw_tool_call["function"]["name"] or "",
            "args": function_args or {},
        }
        if return_id:
            parsed["id"] = raw_tool_call.get("id")
            parsed = create_tool_call(**parsed)  # type: ignore
        return parsed

    def parse_over_quoted_json(self, text: str, strict: bool = False) -> Any:
        """Parse JSON that may have been string-encoded multiple times.

        Repeatedly applies ``json.loads`` (at most
        ``self.max_json_unquote_recursion`` times) until the decoded value
        is no longer a string, then returns it.

        Args:
            text: The possibly over-quoted JSON text.
            strict: Passed through to ``json.loads``.

        Returns:
            The first non-string value produced by repeated decoding.

        Raises:
            JSONDecodeError: If the very first decode pass fails.
            OutputParserException: If the recursion limit is exhausted while
                the value is still a string.
        """
        decoded_text = text
        for i in range(self.max_json_unquote_recursion):
            try:
                decoded_text = json.loads(decoded_text, strict=strict)
                # A non-string result means all quoting layers are stripped.
                if not isinstance(decoded_text, str):
                    return decoded_text
            except JSONDecodeError as e:
                if i > 0:
                    # NOTE(review): after at least one successful un-quote,
                    # a decode failure returns the partially decoded value —
                    # which is still a *string* here, while callers feed the
                    # result to create_tool_call as ``args`` (expected to be
                    # a mapping). Confirm this fallback is intentional.
                    return decoded_text
                raise e
        raise OutputParserException("Failed to parse over quoted JSON")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment