Created
November 15, 2024 03:34
-
-
Save phamson02/17976ecb266d7732aacbb2aca4b72b3a to your computer and use it in GitHub Desktop.
Triage agent with Phidata
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# pyright: reportPrivateImportUsage=false | |
import re | |
from typing import Any, Dict, Iterator, List, Optional, Sequence, Union | |
from phi.agent import Agent | |
from phi.embedder.azure_openai import AzureOpenAIEmbedder | |
from phi.knowledge.pdf import PDFUrlKnowledgeBase | |
from phi.memory.agent import AgentRun | |
from phi.model.azure import AzureOpenAIChat | |
from phi.model.message import Message | |
from phi.utils.log import logger | |
from phi.utils.pprint import pprint_run_response | |
from phi.vectordb.qdrant import Qdrant | |
from phi.model.response import ModelResponse | |
from phi.workflow import RunEvent, RunResponse, Workflow | |
from settings import settings | |
# Embedding model used for the recipe collection: an Azure OpenAI deployment,
# asked for 3072-dimensional vectors.
_embedder = AzureOpenAIEmbedder(
    api_key=settings.azure_openai_api_key,
    api_version=settings.azure_openai_api_version,
    azure_endpoint=settings.azure_openai_endpoint,
    azure_deployment=settings.azure_openai_embedding_deployment,
    dimensions=3072,
)

# Qdrant collection (reached over gRPC) that stores the recipe embeddings.
vector_db = Qdrant(
    collection="test_thai_recipes",
    url=settings.qdrant_url,
    embedder=_embedder,
    prefer_grpc=True,
)

# Knowledge base built from the Thai recipes PDF, stored in `vector_db`.
knowledge_base = PDFUrlKnowledgeBase(
    urls=["https://phi-public.s3.amazonaws.com/recipes/ThaiRecipes.pdf"],
    vector_db=vector_db,
)

# Uncomment for the first run to populate the collection, then comment it out
# again: once loaded, the knowledge base does not need to be reloaded.
# knowledge_base.load()
class TriageAgent(Agent):
    """Agent whose current run can be finished with messages produced elsewhere.

    The triage workflow interrupts this agent's streaming run when it starts a
    transfer, runs the member agent itself, and then calls ``_update_memory``
    so that the member's messages are recorded as this agent's run.
    """

    def _learns_from_user_input(self) -> bool:
        """Return True when memory is configured to update user memories after a run."""
        return bool(
            self.memory.create_user_memories
            and self.memory.update_user_memories_after_run
        )

    def _update_memory(
        self,
        run_messages: List[Message],
        message: Optional[Union[str, List, Dict, Message]] = None,
        *,
        stream: bool = False,
        images: Optional[Sequence[Union[str, Dict]]] = None,
        messages: Optional[Sequence[Union[Dict, Message]]] = None,
        stream_intermediate_steps: bool = False,
        **kwargs: Any,
    ) -> Iterator[RunResponse]:
        """Close out the current run using externally supplied ``run_messages``.

        In order, this:
        - stores ``run_messages`` and their aggregated metrics on
          ``self.run_response``;
        - adds the system message, ``run_messages`` and an ``AgentRun`` to
          memory, refreshing user memories and the session summary when
          memory is configured to do so;
        - writes to storage, saves the response to file and sets ``run_input``;
        - logs the run.

        NOTE(review): the numbered step comments suggest this was adapted from
        phi's own run loop; re-check it against the installed phi version when
        upgrading.

        This is a generator: nothing happens until it is iterated.

        Args:
            run_messages: Messages to record as this run's conversation.
            message: The run's user input. A ``str`` or ``Message`` is attached
                to the ``AgentRun`` (and may feed user memories).
            stream: Whether the caller streams. When streaming (and the agent is
                streamable), the final ``run_response`` is not yielded.
            images: Forwarded to ``get_messages_for_run``.
            messages: Forwarded to ``get_messages_for_run``. Also recorded on the
                ``AgentRun`` when ``message`` is None.
            stream_intermediate_steps: Yield a ``run_completed`` event at the end.
            **kwargs: Forwarded to ``get_messages_for_run``.

        Yields:
            A ``run_completed`` RunResponse if ``stream_intermediate_steps`` is
            set, then ``self.run_response`` unless the response is streamed.
        """
        streaming = stream and self.streamable

        # Only the system message is used here; the conversation itself comes
        # from run_messages.
        sys_msg, _, _ = self.get_messages_for_run(
            message=message, images=images, messages=messages, **kwargs
        )

        # Record the externally produced messages as this run's messages.
        self.run_response.messages = run_messages
        self.run_response.metrics = self._aggregate_metrics_from_run_messages(
            run_messages
        )

        if sys_msg is not None:
            self.memory.add_system_message(
                sys_msg, system_message_role=self.system_message_role
            )
        self.memory.add_messages(messages=run_messages)

        agent_run = AgentRun(response=self.run_response)

        if message is not None:
            if isinstance(message, Message):
                user_msg: Optional[Message] = message
            elif isinstance(message, str):
                user_msg = Message(role=self.user_message_role, content=message)  # type: ignore
            else:
                # Lists/dicts are not attached to the AgentRun.
                user_msg = None

            if user_msg is not None:
                agent_run.message = user_msg
                if self._learns_from_user_input():
                    self.memory.update_memory(input=user_msg.get_content_string())
        elif messages:
            for raw in messages:
                if isinstance(raw, Message):
                    parsed: Optional[Message] = raw
                elif isinstance(raw, dict):
                    parsed = None
                    try:
                        parsed = Message.model_validate(raw)
                    except Exception as e:
                        logger.warning(f"Failed to validate message: {e}")
                else:
                    logger.warning(f"Unsupported message type: {type(raw)}")
                    continue

                if not parsed:
                    logger.warning("Unable to add message to memory")
                    continue

                if agent_run.messages is None:
                    agent_run.messages = []
                agent_run.messages.append(parsed)
                if self._learns_from_user_input():
                    self.memory.update_memory(input=parsed.get_content_string())

        self.memory.add_run(agent_run)

        if (
            self.memory.create_session_summary
            and self.memory.update_session_summary_after_run
        ):
            self.memory.update_summary()

        # 7. Persist the session; 8. save the response to file if configured.
        self.write_to_storage()
        self.save_run_response_to_file(message=message)

        # 9. Remember the run input (Message objects are stored as dicts).
        if message is not None:
            self.run_input = (
                message.to_dict() if isinstance(message, Message) else message
            )
        elif messages is not None:
            self.run_input = [
                item.to_dict() if isinstance(item, Message) else item
                for item in messages
            ]

        self.log_agent_run()
        logger.debug(
            f"*********** Agent Run End: {self.run_response.run_id} ***********"
        )

        if stream_intermediate_steps:
            yield RunResponse(
                run_id=self.run_id,
                session_id=self.session_id,
                agent_id=self.agent_id,
                content=self.run_response.content,
                tools=self.run_response.tools,
                messages=self.run_response.messages,
                event=RunEvent.run_completed.value,
            )

        # A non-streaming caller (run()) reads the final response from here.
        if not streaming:
            yield self.run_response
# Parses the tool-call text reported when the triage agent starts a transfer,
# e.g. "transfer_task_to_knowledge_agent(task_description=..., expected_output=...)".
# - DOTALL lets argument values span several lines.
# - The optional group captures an `additional_information=` argument instead
#   of folding it into expected_output.
# - The call's closing ")" is the last ")" in the text, so parentheses inside
#   argument values do not truncate them.
_TRANSFER_CALL_RE = re.compile(
    r"transfer_task_to_knowledge_agent\("
    r"task_description=(?P<task>.*?),\s*"
    r"expected_output=(?P<expected>.*?)"
    r"(?:,\s*additional_information=(?P<extra>.*?))?"
    r"\)[^)]*$",
    re.DOTALL,
)


class TriageWorkflow(Workflow):
    """Route each user query through a triage agent.

    Cooking questions are handed to the knowledge agent when the triage agent
    starts a transfer. The member's answer is streamed back and then written
    into the triage agent's memory.
    """

    # Member agent: answers briefly, searching the recipes knowledge base.
    knowledge_agent: Agent = Agent(
        agent_id="knowledge_agent",
        name="Knowledge Agent",
        role="Answers cooking questions in short conversational style",
        provider=AzureOpenAIChat(
            id="knowledge_agent",
            api_version=settings.azure_openai_api_version,
            azure_endpoint=settings.azure_openai_endpoint,
            azure_deployment=settings.azure_openai_deployment,
            api_key=settings.azure_openai_api_key,
            temperature=settings.temperature,
            max_tokens=200,
        ),
        instructions=[
            "Go through the retrieved relevant information and summarize it in conversational style",
            "Keep your answers in one or two sentences only in conversational style",
            "Do not return Markdown formatted text, keep it like a dialogue",
        ],
        tools=None,
        search_knowledge=True,
        knowledge_base=knowledge_base,
        markdown=False,
    )  # type: ignore

    # Leader agent: its team (the knowledge agent) gives it the transfer tool
    # the workflow intercepts; the last 3 responses are kept as chat history.
    triage_agent: TriageAgent = TriageAgent(
        agent_id="triage_agent",
        name="Triage Agent",
        provider=AzureOpenAIChat(
            id="triage_agent",
            api_version=settings.azure_openai_api_version,
            azure_endpoint=settings.azure_openai_endpoint,
            azure_deployment=settings.azure_openai_deployment,
            api_key=settings.azure_openai_api_key,
            temperature=settings.temperature,
        ),
        team=[knowledge_agent],
        instructions=[
            "Keep your answers in one or two sentences only in conversational style",
        ],
        num_history_responses=3,
        add_chat_history_to_messages=True,
        show_tool_calls=True,
    )  # type: ignore

    def transfer_task_to_agent(
        self,
        member_agent: Agent,
        task_description: str,
        expected_output: str,
        extra_data: Optional[str] = None,
    ) -> Iterator[RunResponse]:
        """Run ``member_agent`` on a delegated task, streaming its responses.

        Before the run, the member's ``session_data`` is linked to this
        workflow (leader session, agent and run ids). After the run, the
        member's session/agent ids are added, without duplicates, to
        ``self.session_data["members"]``.

        Args:
            member_agent: Agent that performs the task.
            task_description: What the member should do.
            expected_output: Description of the desired answer.
            extra_data: Optional extra context. It is only added to the prompt
                when non-empty.

        Yields:
            Every RunResponse produced by ``member_agent.run(..., stream=True)``.
        """
        if member_agent.session_data is None:
            member_agent.session_data = {}
        member_agent.session_data["leader_session_id"] = self.session_id
        member_agent.session_data["leader_agent_id"] = self.triage_agent.agent_id
        member_agent.session_data["leader_run_id"] = self.run_id

        prompt = f"{task_description}\n\nThe expected output is: {expected_output}"
        if extra_data:
            # Only include this section when there is something to say;
            # otherwise the member would literally see "Additional information: None".
            prompt += f"\n\nAdditional information: {extra_data}"

        responses: Iterator[RunResponse] = member_agent.run(prompt, stream=True)
        # next() with a default: inside a generator, a bare StopIteration from an
        # empty stream would surface as RuntimeError (PEP 479).
        first = next(responses, None)
        if first is None:
            logger.warning(
                f"{member_agent.name} produced no response for the transferred task"
            )
            return
        yield first
        yield from responses

        # Remember which member sessions this workflow has delegated to.
        member_info = {"session_id": first.session_id, "agent_id": first.agent_id}
        if self.session_data is None:
            self.session_data = {}
        members = self.session_data.setdefault("members", [])
        if member_info not in members:
            members.append(member_info)

    def run(self, query: str) -> Iterator[RunResponse]:  # type: ignore
        """Answer ``query``, delegating to the knowledge agent on transfer.

        Streams the triage agent's ``run_response`` chunks. When a parseable
        ``transfer_task_to_knowledge_agent`` tool call starts, the triage run
        is abandoned. The knowledge agent is run directly and its chunks are
        streamed instead. Its messages (minus its leading system message) are
        then recorded in the triage agent's memory.

        An unparseable tool call is logged and the triage agent's own stream
        continues to be forwarded, rather than being silently dropped.

        NOTE(review): the call arguments are parsed from the tool_call_started
        event's text content. If phi abbreviates long argument values in that
        text, the parsed task will be abbreviated too. Reading them from the
        event's structured tool data may be more reliable; verify against the
        installed phi version.
        """
        print(f"Running triage workflow with query: {query}")
        events: Iterator[RunResponse] = self.triage_agent.run(
            message=query,
            stream=True,
            stream_intermediate_steps=True,
        )
        for event in events:
            if event.event == RunEvent.tool_call_started.value:
                call_text = event.content
                match = (
                    _TRANSFER_CALL_RE.search(call_text)
                    if isinstance(call_text, str)
                    else None
                )
                if match is None:
                    logger.warning(
                        f"Could not parse tool call; continuing triage run: {call_text!r}"
                    )
                    continue

                extra = (match.group("extra") or "").strip()
                yield from self.transfer_task_to_agent(
                    member_agent=self.knowledge_agent,
                    task_description=match.group("task").strip(),
                    expected_output=match.group("expected").strip(),
                    extra_data=extra or None,
                )

                # Drop the member's system message; keep its user/assistant turns.
                member_messages = self.knowledge_agent.run_response.messages or []
                yield from self.triage_agent._update_memory(
                    run_messages=member_messages[1:],
                    message=query,
                    stream=True,
                )
                break
            elif event.event == RunEvent.run_response.value:
                yield event
# Build the workflow and replay a short scripted conversation through it,
# pretty-printing each streamed answer in turn.
workflow = TriageWorkflow(
    workflow_id="triage_workflow",
    session_id="triage_workflow",
    debug_mode=False,
)

print("Running triage workflow")
for _query in (
    "Hi",
    "What is the recipe for Pad Thai?",
    "Tell me more about the history of the dish",
    "Is there any health benefits that come with this dish?",
    "What are the ingredients for Tom Yum Goong?",
):
    response = workflow.run(_query)
    pprint_run_response(response)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment