-
-
Save sergeyklay/4f46a44cbbee566e9956ef1b9c14add5 to your computer and use it in GitHub Desktop.
LangChain/LangGraph `@tool` support for class-level methods
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import functools | |
import inspect | |
from typing import Callable, Literal, Optional | |
from langchain_core.tools import BaseTool, StructuredTool | |
from langchain_core.tools.base import ArgsSchema | |
def class_bound_tool(
    *args,
    name: Optional[str] = None,
    description: Optional[str] = None,
    return_direct: bool = False,
    args_schema: Optional[ArgsSchema] = None,
    infer_schema: bool = True,
    response_format: Literal["content", "content_and_artifact"] = "content",
    parse_docstring: bool = False,
    error_on_invalid_docstring: bool = True,
) -> Callable:
    """Mark a method (or function) as a tool without converting it yet.

    Unlike ``langchain_core.tools.tool``, this decorator returns an ordinary
    function rather than a tool object. A method defined in a class body
    therefore still binds to its instance on attribute access.

    The ``@tool``-style options are stored as attributes on the returned
    wrapper, together with ``original_function`` (the undecorated callable).
    ``postprocess_tools`` reads these attributes to build the final
    ``StructuredTool`` from the bound method.

    Both usage forms are supported and produce identically-annotated wrappers::

        @class_bound_tool
        def method(self, ...): ...

        @class_bound_tool(description="...", return_direct=True)
        def method(self, ...): ...

    Args:
        *args: In the bare form, the single positional argument is the
            decorated callable. In the parameterized form it is empty (or
            ``None``).
        name, description, return_direct, args_schema, infer_schema,
        response_format, parse_docstring, error_on_invalid_docstring:
            Stored verbatim as attributes on the wrapper, for later
            forwarding to ``StructuredTool.from_function``.

    Returns:
        The wrapped callable (bare form) or a decorator (parameterized form).
    """
    decorator_attrs = {
        "name": name,
        "description": description,
        "return_direct": return_direct,
        "args_schema": args_schema,
        "infer_schema": infer_schema,
        "response_format": response_format,
        "parse_docstring": parse_docstring,
        "error_on_invalid_docstring": error_on_invalid_docstring,
    }

    def decorate(method: Callable) -> Callable:
        @functools.wraps(method)
        def actual_call(*call_args, **call_kwargs):
            return method(*call_args, **call_kwargs)

        for key, value in decorator_attrs.items():
            setattr(actual_call, key, value)
        # Recorded in BOTH usage forms. postprocess_tools relies on it to
        # hide ``self`` from the inferred schema and to detect async methods,
        # because this pass-through wrapper is always a plain ``def``.
        actual_call.original_function = method
        return actual_call

    if args and args[0] is not None:
        # Bare ``@class_bound_tool``: args[0] is the decorated callable.
        return decorate(args[0])
    # Called with keyword options: hand back the real decorator.
    return decorate
def postprocess_tools(tools) -> list[StructuredTool]:
    """Convert a mixed collection of tools/callables into LangChain tools.

    Each entry of ``tools`` may be one of the following:

    * An existing ``BaseTool`` (e.g. produced by ``@tool``). It is passed
      through unchanged.
    * A callable decorated with ``class_bound_tool``, typically a bound
      method such as ``instance.method``. The options stored on it are
      forwarded to ``StructuredTool.from_function``.
    * Any other plain function or bound method. It is wrapped with the
      default ``from_function`` options.

    For a bound method that carries ``original_function``, the argument
    schema is inferred from a proxy whose signature omits the instance
    parameter. The resulting tool still executes the bound method itself,
    so the instance is supplied automatically.

    Args:
        tools: Iterable of tools and/or callables.

    Returns:
        One tool per input entry, in input order. Entries that were already
        ``BaseTool`` instances are returned as-is, so they need not be
        ``StructuredTool`` despite the annotation.
    """
    # Forward only the options class_bound_tool records. Harvesting every
    # public attribute via dir() would leak unrelated ones (e.g. a
    # functools.partial's ``func``/``args``/``keywords``) into from_function.
    decorator_attr_names = (
        "name",
        "description",
        "return_direct",
        "args_schema",
        "infer_schema",
        "response_format",
        "parse_docstring",
        "error_on_invalid_docstring",
    )

    wrapped_tools = []
    for tool in tools:
        if isinstance(tool, BaseTool):
            wrapped_tools.append(tool)
            continue

        decorator_attrs = {
            attr: getattr(tool, attr)
            for attr in decorator_attr_names
            if hasattr(tool, attr)
        }
        original_function = getattr(tool, "original_function", None)

        schema_source = tool
        # Hide the instance parameter only when the tool is actually bound.
        # A decorated plain function (or staticmethod) has no ``self`` to drop.
        if original_function is not None and inspect.ismethod(tool):
            schema_source = _create_method_proxy(original_function)

        # The decorator wrapper and the proxy are plain ``def``s, so detect
        # async-ness on the underlying function, or on the tool itself when
        # it was never decorated (e.g. a bare ``async def`` method).
        is_async = inspect.iscoroutinefunction(
            original_function if original_function is not None else tool
        )
        if is_async:
            func, coroutine = None, schema_source
        else:
            func, coroutine = schema_source, None

        wrapped_tool = StructuredTool.from_function(
            func=func, coroutine=coroutine, **decorator_attrs
        )
        if original_function is not None:
            # Execute the tool callable itself (the bound method, when bound),
            # so the instance arrives as ``self``. The proxy was only needed
            # for schema inference.
            if is_async:
                wrapped_tool.coroutine = tool
            else:
                wrapped_tool.func = tool
        wrapped_tools.append(wrapped_tool)
    return wrapped_tools
def _create_method_proxy(bound_method):
    """Return a schema-inference stand-in for a method, with ``self`` hidden.

    Despite the parameter name, ``bound_method`` is the plain function stored
    as ``original_function`` by ``class_bound_tool``. This is the method as
    written in the class body, which still expects the instance as its first
    argument.

    The returned proxy forwards every call to that function unchanged. It
    advertises a signature without the leading instance parameter, so the
    instance is not exposed as a tool argument by signature-inspecting code.

    Only a leading *positional* parameter is removed. A signature that starts
    with ``*args`` or keyword-only parameters has no dedicated instance slot
    and is kept intact, so no real argument disappears from the schema.
    """
    original_sig = inspect.signature(bound_method)
    params = list(original_sig.parameters.values())
    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    if params and params[0].kind in positional_kinds:
        # Drop the instance parameter (conventionally ``self``).
        params = params[1:]
    new_sig = original_sig.replace(parameters=params)

    @functools.wraps(bound_method)
    def proxy(*args, **kwargs):
        return bound_method(*args, **kwargs)

    # inspect.signature() stops at an explicit __signature__ instead of
    # following the __wrapped__ chain that functools.wraps installed.
    proxy.__signature__ = new_sig
    return proxy
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from langgraph.prebuilt import create_react_agent | |
class SomeClass:
    """Example class whose method is exposed to the agent as a tool."""

    def class_instance_operation(self):
        # Must accept ``self``: it is invoked as
        # ``self.class_instance_operation()``, which passes the instance as
        # the first argument. Without ``self`` the call raises TypeError.
        return "success"

    @class_bound_tool(description="Just call me", return_direct=True)  # Also supports other usual `@tool` params.
    def some_tool(self):
        """Call me little sunshine"""
        return self.class_instance_operation()
# Example wiring: expose a method of a live instance as an agent tool.
tool_class_instance = SomeClass()
tools = postprocess_tools([tool_class_instance.some_tool])

# Placeholder: substitute a real chat model that implements ``bind_tools``.
llm = ...
llm_with_tools = llm.bind_tools(tools, strict=True)

# The same tool list is handed to the agent again. Per the original note,
# this is required only so that the agent's sanity check passes.
graph = create_react_agent(model=llm_with_tools, tools=tools)
state = graph.invoke(
    {"messages": [{"role": "user", "content": "Use the tool"}]}
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment