Last active
July 11, 2025 07:53
-
-
Save duyixian1234/7272241b789d4091a2537a2f887b5dd6 to your computer and use it in GitHub Desktop.
AI Agent
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json | |
from collections.abc import Callable | |
import sys | |
from typing import Annotated | |
from function_schema import get_function_schema | |
from openai import OpenAI | |
from openai.types.chat import ChatCompletionMessageParam | |
class Chat:
    """Stateful chat session with OpenAI-style tool (function) calling.

    The whole conversation is kept in ``self.messages``. Each call sends one
    user message, executes any tool calls the model requests, feeds their
    results back, and repeats until the model replies without tool calls.
    """

    def __init__(
        self,
        *,
        model: str,
        client: OpenAI | None = None,
        system: str | None = None,
        tools: dict[str, tuple[Callable, dict]] | None = None,
    ):
        """Create a chat session.

        Args:
            model: Model name passed to ``chat.completions.create``.
            client: Client used for requests; a default ``OpenAI()`` is
                created when omitted.
            system: Optional system prompt. When ``None`` no system message
                is added: the API expects string content for system
                messages, so a ``None`` content would be rejected.
            tools: Mapping of tool name to ``(callable, function_schema)``.
                Each schema is sent as the ``function`` part of a tool
                definition; the callable is invoked with the keyword
                arguments the model supplies. Keys must match the function
                names in the schemas, because tool calls are looked up by
                the name the model returns.
        """
        self.messages: list[ChatCompletionMessageParam] = []
        if system is not None:
            self.messages.append({"role": "system", "content": system})
        self.client = client or OpenAI()
        self.model = model
        self.tools = tools or {}
        self.tool_definitions = [
            {"type": "function", "function": schema}
            for _, schema in self.tools.values()
        ]

    def _complete(self):
        """Send the current history to the model and return its reply message."""
        request = {"messages": self.messages, "model": self.model}
        # The API rejects an empty ``tools`` array, so only send it when at
        # least one tool is registered.
        if self.tool_definitions:
            request["tools"] = self.tool_definitions
        return self.client.chat.completions.create(**request).choices[0].message

    def _run_tool(self, tool_call) -> str:
        """Execute one model-requested tool call and return its result as text.

        Mistakes made by the model (unknown tool name, arguments that are not
        a JSON object) are returned as an error string so the model can
        recover. Exceptions raised by the tool itself propagate to the caller.
        """
        name = tool_call.function.name
        entry = self.tools.get(name)
        if entry is None:
            return f"Error: unknown tool {name!r}"
        tool, _ = entry
        try:
            # Some providers send "" (or nothing) for argument-less calls.
            arguments = json.loads(tool_call.function.arguments or "{}")
        except json.JSONDecodeError as err:
            return f"Error: invalid JSON arguments for {name!r}: {err}"
        if not isinstance(arguments, dict):
            return f"Error: arguments for {name!r} must be a JSON object"
        return str(tool(**arguments))

    def __call__(self, content: str) -> str:
        """Send a user message and return the model's final text reply.

        Tool calls requested by the model are executed and their results fed
        back until the model answers without requesting tools.

        If the API or a tool raises, ``self.messages`` is restored to its
        state before this call and the exception is re-raised. Otherwise an
        assistant message with ``tool_calls`` but without matching ``tool``
        replies would stay in the history and make every later request fail.

        Returns:
            The final assistant text, or ``""`` if the reply had no content.
        """
        checkpoint = len(self.messages)
        self.messages.append({"role": "user", "content": content})
        try:
            message = self._complete()
            while message.tool_calls:
                # Drop null-valued response fields so they are not echoed
                # back in the next request.
                self.messages.append(message.model_dump(exclude_none=True))
                for tool_call in message.tool_calls:
                    self.messages.append(
                        {
                            "role": "tool",
                            "content": self._run_tool(tool_call),
                            "tool_call_id": tool_call.id,
                        }
                    )
                message = self._complete()
        except BaseException:
            # Keep the session usable: discard this turn's partial history.
            del self.messages[checkpoint:]
            raise
        message_content = message.content or ""
        self.messages.append({"role": "assistant", "content": message_content})
        return message_content
def get_weather(city: Annotated[str, "city"]) -> str:
    """Get the weather for a given city."""
    # Demo stub: `city` is accepted for the tool schema but ignored; every
    # city gets the same canned report.
    forecast = "Sunny, 20 degrees Celsius"
    return forecast
def get_location() -> str:
    """Get the current location"""
    # Demo stub: always reports the same fixed city.
    current_city = "Beijing"
    return current_city
import logging
import random

# Emit INFO-level records so the tool functions' log lines are visible.
logging.basicConfig(level=logging.INFO)
# Silence the "httpx" logger (presumably the HTTP layer under the OpenAI
# client) so per-request log lines don't drown out the tool logs.
logging.getLogger(name="httpx").disabled = True
def get_random_number(lower: int = 0, upper: int = 1000) -> int:
    """Get a random number between lower and upper"""
    # NOTE: the docstring above is left unchanged because it is presumably
    # used by get_function_schema as the tool description sent to the model.
    # Both bounds are inclusive (random.randint). The model may supply them
    # in either order; normalise them instead of letting randint raise
    # ValueError("empty range") and abort the whole chat turn.
    if lower > upper:
        lower, upper = upper, lower
    logging.info("Generating a random number between %s and %s", lower, upper)
    return random.randint(lower, upper)
def square(x: int) -> int:
    """Return the square of x"""
    # Log each invocation so calls made by the model are visible.
    logging.info(f"Calculating the square of {x}")
    squared = x * x
    return squared
# Demo 1: with tools registered, the model can delegate random-number
# generation and squaring to the local Python functions.
# System prompt: "Answer the user's questions briefly and clearly; prefer the
# provided tool functions."
# User prompt: "Generate a five-digit random number and compute its square."
chat = Chat(
    model="glm-4",
    system="回答用户的问题,简短明了;优先使用提供的工具函数",
    tools={
        tool_fn.__name__: (tool_fn, get_function_schema(tool_fn))
        for tool_fn in (get_random_number, square)
    },
)
print(chat("生成一个五位的随机数并计算平方"))

# Demo 2: the same question with no tools, for comparison.
# System prompt: "Answer the user's questions briefly and clearly."
chat = Chat(model="glm-4", system="回答用户的问题,简短明了")
print(chat("生成一个五位的随机数并计算平方"))

# Demo 3 (disabled): location + weather tools, which the model must chain.
# The prompt asks: "What's the weather like at the current location?"
# chat = Chat(
#     model="glm-4",
#     tools={
#         "get_weather": (get_weather, get_function_schema(get_weather)),
#         "get_location": (get_location, get_function_schema(get_location)),
#     },
# )
# print(chat("当前位置的天气怎么样?"))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment