Last active
June 20, 2024 18:48
-
-
Save mrchief/ebb2cb8104800df3e06005104474e8d7 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" Chains and agents """ | |
import os | |
import re | |
from typing import Optional | |
import pandas as pd | |
from langchain.agents import AgentExecutor, ZeroShotAgent | |
from langchain.agents.agent_types import AgentType | |
from langchain.chains.conversation.memory import ConversationBufferMemory | |
from langchain.chains.llm import LLMChain | |
from langchain.tools import StructuredTool, Tool | |
from langchain_community.vectorstores import FAISS | |
from langchain_core.example_selectors import SemanticSimilarityExampleSelector | |
from langchain_core.prompts import ( | |
ChatPromptTemplate, | |
FewShotPromptTemplate, | |
MessagesPlaceholder, | |
PromptTemplate, | |
SystemMessagePromptTemplate, | |
) | |
from langchain_experimental.agents.agent_toolkits import create_csv_agent | |
from langchain_openai import ChatOpenAI, OpenAIEmbeddings | |
from openai import OpenAI as Openai | |
from pydantic import BaseModel, Field | |
from src.core.cloud_storage import RAW, pull_project_folder | |
from src.core.rag import IndexLLM | |
from src.utils.agents import ( | |
AGENT_SYSTEM_PREFIX, | |
CUSTOM_AGENT_FORMAT, | |
CUSTOM_PREFIX, | |
CUSTOM_SUFFIX, | |
apologize, | |
greets, | |
) | |
from src.utils.examples_creation import ( | |
build_examples_from_dataframe, | |
get_key_from_file_path, | |
) | |
from src.utils.settings import OPENAI_API_KEY, RAW_PATH, configure_logger | |
logger = configure_logger("Chain agent") | |
class StructuredFileAgentInput(BaseModel):
    # Input schema holding a single string field: the user's question.
    # A `#` comment is used instead of a docstring so the generated schema
    # is left untouched.
    question: str = Field()
class AgentSynthesizer:
    """Main class for agent synthesizer.

    Combines four tools into one ``ZeroShotAgent`` executor with a
    conversation buffer memory: two canned-response tools ("Greets" and
    "Apologize"), the document tool created by the index ("Read") and a
    CSV agent ("Query table").
    """

    def __init__(self, index: "IndexLLM"):
        """Init

        Args:
            index (IndexLLM): An IndexLLM object that
                contains llamatool.
        """
        self.index = index
        self.llama_tool = None
        self.agent_chain = None
        self.memory = None
        # Initialised here so the metadata helpers and `_enlist_tools` never
        # hit a missing attribute when called before an index is loaded.
        self.tools = []
        self.structured_filenames = []
        self.structured_user_contexts = []
        self.set_default_tools()

    def set_default_tools(self):
        """Create the "Greets" and "Apologize" tools.

        Both tools return their result directly to the user
        (``return_direct=True``).
        """
        self.greetings_tool = StructuredTool.from_function(
            name="Greets",
            func=greets,
            description="use this if there is no question to answer, invite the user to ask a question",
            args_schema=BaseModel,
            return_direct=True,
        )
        self.apologize_tool = StructuredTool.from_function(
            name="Apologize",
            func=apologize,
            description="use this if the other tools are not suitable to respond a question",
            args_schema=BaseModel,
            return_direct=True,
        )

    def load_index_if_not_loaded(
        self, path: Optional[str], index: Optional["IndexLLM"]
    ) -> None:
        """Load index if it is not loaded

        After the index is available, the project's raw files are pulled
        (if not already present locally), the file names and user contexts
        are read from the index metadata and the "Read" tool is created.

        Args:
            path (Optional[str]): The path of the project index. Its base
                name is used as the project name.
            index (Optional[IndexLLM]): An index to use instead of the one
                given at construction time. When None, the current
                ``self.index`` is (re)loaded from ``path``.

        Raises:
            Exception: Whatever ``load_index`` raised; it is logged and
                re-raised unchanged.
        """
        project_name = os.path.basename(path) if path is not None else ""

        if index is None:
            try:
                self.index.load_index(path=path)
                logger.info(f"Loaded index: {self.index.index_id}")
            except Exception as err:
                logger.error(f"Failed to load index from path: {path}", exc_info=err)
                raise
        elif hasattr(index, "index") and index.index_id == project_name:
            # NOTE(review): `self.index` is not replaced in this branch, so
            # the steps below run on the previous `self.index`. That is only
            # correct if `index` is the same object - confirm with callers.
            logger.info(f"Index {index.index_id} is already set")
        else:
            try:
                index.load_index(path)
                # Swap the index in only after a successful load so a failure
                # does not leave an unloaded index installed on the agent.
                self.index = index
                logger.info(f"Loaded index: {index.index_id}")
            except Exception as err:
                logger.error(
                    f"Failed to load index from path: {path}", exc_info=err
                )
                raise

        self._download_raw_from_project_name(project_name)
        files_dict = self.index.get_docs_filename_and_context_from_metadata()
        self.structured_filenames = list(files_dict.keys())
        self.structured_user_contexts = list(files_dict.values())
        # NOTE(review): `return_redirect` may be a typo for `return_direct`;
        # it is kept as-is because `create_agent_tool` is defined elsewhere.
        self.llama_tool = self.index.create_agent_tool(
            name="Read",
            description="use this to search for information using private documents and retrieve augmented information from there",
            return_redirect=True,
        )

    def _download_raw_from_project_name(self, name: str):
        """Download raw data from project name

        The download is skipped when the local folder already exists.

        Args:
            name (str): The name of the project (index) to be downloaded
        """
        # Joined the same way `create_structured_agent` builds its read path,
        # so both agree even when RAW_PATH has no trailing separator.
        raw_path = os.path.join(RAW_PATH, name)
        if os.path.exists(raw_path):
            logger.info(f"Already loaded {RAW} content into {raw_path}")
            return
        # makedirs also creates missing parent folders.
        os.makedirs(raw_path, exist_ok=True)
        logger.info(f"Pulling raw data for {name}")
        pull_project_folder(RAW, name, raw_path)

    def get_metadata_by_filetype(self, extension: str = "csv") -> tuple:
        """Get metadata by file type

        Args:
            extension (str, optional): The file extension. Defaults to "csv".

        Returns:
            tuple: A tuple containing the filenames and the user contexts,
                each as a list, in their original order.

        Raises:
            ValueError: If the stored file names and user contexts differ
                in length.
        """
        if len(self.structured_filenames) != len(self.structured_user_contexts):
            raise ValueError("Files and user contexts are not the same length")
        matches = [
            (filename, context)
            for filename, context in zip(
                self.structured_filenames, self.structured_user_contexts
            )
            if filename.endswith(extension)
        ]
        filenames = [filename for filename, _ in matches]
        u_contexts = [context for _, context in matches]
        return filenames, u_contexts

    @staticmethod
    def _parse_file_number(response_message: Optional[str], n_files: int) -> int:
        """Extract and validate the file number from a model reply.

        Args:
            response_message (Optional[str]): Raw text returned by the model.
            n_files (int): How many candidate files were offered.

        Returns:
            int: The first run of digits in the reply, as an integer.

        Raises:
            IndexError: If the reply has no digits or the number does not
                refer to one of the ``n_files`` files. IndexError is used
                because that is what an unusable reply raised before.
        """
        match = re.search(r"\d+", response_message or "")
        if match is None:
            raise IndexError(
                f"No file number found in the model reply: {response_message!r}"
            )
        file_number = int(match.group())
        if file_number >= n_files:
            raise IndexError(
                f"The model chose file number {file_number} "
                f"but only {n_files} files are available"
            )
        return file_number

    def resolve_which_is_the_relevant_file(self, question: str) -> int:
        """Resolve which is the most relevant file to answer a given question

        it uses two lists in the agent object:
        - structured_filenames
        - structured_user_contexts

        The csv contexts are numbered from 0 and sent to the chat model,
        which is asked to reply with the number of the most relevant file.

        Args:
            question (str): The question to be answered

        Returns:
            int: The index of the most relevant file among the csv files

        Raises:
            IndexError: If there are no csv files, or the model reply does
                not contain a usable file number.
        """
        filenames, contexts = self.get_metadata_by_filetype()
        logger.info(
            f"Resolving the most relevant file using contexts from {len(filenames)} files"
        )
        if not filenames:
            # Nothing to choose from: fail before spending an API call.
            raise IndexError("No csv files are available to answer the question")
        str_contexts = "".join(
            f"""'{context}' is the context of the file number {i}, """
            for i, context in enumerate(contexts)
        )
        client = Openai(api_key=OPENAI_API_KEY)
        response = client.chat.completions.create(
            model="gpt-3.5-turbo",
            temperature=0.01,
            messages=[
                {"role": "system", "content": "You are a helpful assistant."},
                {
                    "role": "user",
                    "content": f"Taking into consideration files tagged with an ordinal number and with the following contexts provided by a human: '{str_contexts}'",
                },
                {
                    "role": "user",
                    "content": f"Which file is the most relevant to answer to the question '{question}'? To respond, provide only the number of the most relevant file",
                },
            ],
        )
        response_message = response.choices[0].message.content
        logger.info(f"original message: {response_message}")
        return self._parse_file_number(response_message, len(filenames))

    def _define_few_shot_examples(self, file_path: str, question: str) -> str:
        """Define few shot examples for an specific question using a file

        Examples are built from the csv content, the 2 most similar ones to
        the question are selected, and the resulting chat prompt is rendered
        to a string.

        Args:
            file_path (str): File path to the csv file
            question (str): Question from the user

        Returns:
            str: A prompt prefix with the few shot examples
        """
        table_df = pd.read_csv(file_path)
        file_key = get_key_from_file_path(file_path)
        examples = build_examples_from_dataframe(df=table_df, key=file_key)
        example_selector = SemanticSimilarityExampleSelector.from_examples(
            examples=examples,
            # Same key the other OpenAI clients in this class are given.
            embeddings=OpenAIEmbeddings(openai_api_key=OPENAI_API_KEY),
            vectorstore_cls=FAISS,
            k=2,
            input_keys=["input"],
        )
        # NOTE(review): `input_variables` lists "output", but the values
        # passed to `invoke` below are input/top_k/dialect - confirm this
        # matches the placeholders used by AGENT_SYSTEM_PREFIX.
        few_shot_prompt = FewShotPromptTemplate(
            example_selector=example_selector,
            example_prompt=PromptTemplate.from_template(
                "User input: {input}\nAgent output:{output}"
            ),
            input_variables=["input", "output"],
            prefix=AGENT_SYSTEM_PREFIX,
            suffix="User input: {input}\n",
        )
        full_prompt = ChatPromptTemplate.from_messages(
            [
                SystemMessagePromptTemplate(prompt=few_shot_prompt),
                ("human", "{input}"),
                MessagesPlaceholder("agent_scratchpad"),
            ]
        )
        prompt_val = full_prompt.invoke(
            {
                "input": question,
                "top_k": 2,
                "dialect": "python_REPL_ast",
                "agent_scratchpad": [],
            }
        )
        return prompt_val.to_string()

    def create_structured_agent(self, question: str):
        """Create a structured agent from the index

        The csv file is chosen with `resolve_which_is_the_relevant_file`
        and read from ``RAW_PATH/<index_id>/``.

        Args:
            question (str): A question from the user

        Returns:
            An agent object

        Raises:
            IndexError: If no csv file can be selected for the question.
        """
        file_number = self.resolve_which_is_the_relevant_file(question=question)
        csv_files, csv_user_contexts = self.get_metadata_by_filetype()
        csv_file = csv_files[file_number]
        logger.info(
            f"The file number is {file_number} using it from "
            + f"{len(self.structured_filenames)} files : "
            + f"[{csv_files[file_number]} : {csv_user_contexts[file_number]}]"
        )
        file_path = os.path.join(RAW_PATH, self.index.index_id, csv_file)
        prompt_val = self._define_few_shot_examples(
            file_path=file_path, question=question
        )
        # SECURITY NOTE(review): the csv agent answers by running
        # model-generated Python over the dataframe; newer
        # langchain_experimental releases require an explicit opt-in flag
        # for this - check the installed version.
        agent_csv = create_csv_agent(
            ChatOpenAI(
                temperature=0.01,
                model="gpt-3.5-turbo",
                openai_api_key=OPENAI_API_KEY,
            ),
            file_path,
            verbose=False,
            agent_type=AgentType.OPENAI_FUNCTIONS,
            prefix=prompt_val,
        )
        return agent_csv

    def initialize_memory_conversation(
        self, question: Optional[str] = None
    ) -> "ConversationBufferMemory":
        """Initialize memory conversation

        Args:
            question (Optional[str]): Question used to pick the csv file for
                the "Query table" tool. When None, the file is picked each
                time the tool is called instead.

        Returns:
            ConversationBufferMemory: A buffer memory for a conversation
        """
        self._config_memory_conversation(question=question)
        return self.memory

    def query_agent(self, input_text: str) -> tuple:
        """Run a query within the agent

        The agent chain is configured on the first call, using
        ``input_text`` as the question.

        Args:
            input_text (str): A query text from the user

        Returns:
            tuple: Both, the output and the intermediate steps
                from the response object
        """
        if self.agent_chain is None:
            self._config_memory_conversation(question=input_text)
        response = self.agent_chain.invoke({"input": input_text})
        logger.info(f"Memory: {self.memory.chat_memory}")
        return (
            response["output"],
            response["intermediate_steps"],
        )

    def _enlist_tools(self, *args) -> None:
        """Set tools to be used"""
        self.tools = list(args)

    def _query_table(self, question: str):
        """Build the csv agent for ``question`` and run it.

        Used as the "Query table" tool when no question is known at
        configuration time.
        """
        return self.create_structured_agent(question=question).invoke(question)

    def _config_memory_conversation(self, question: Optional[str] = None):
        """Configure conversation memory

        Builds the tool list, the prompt, a fresh conversation memory and
        the agent executor, and stores them on the instance.

        Args:
            question (Optional[str]): When given, the csv agent is built
                once for this question. When None, it is built on every
                "Query table" call from the tool input.

        Raises:
            RuntimeError: If the "Read" tool has not been created yet.
        """
        if self.llama_tool is None:
            # A None entry in the tool list would only fail later, while the
            # prompt is being built, with a much less helpful error.
            raise RuntimeError(
                "The index tool is not set; call load_index_if_not_loaded() "
                "before configuring the agent"
            )
        if question is None:
            table_func = self._query_table
        else:
            # NOTE(review): the csv file stays fixed to this question for the
            # lifetime of the agent chain.
            table_func = self.create_structured_agent(question=question).invoke
        self._enlist_tools(
            self.greetings_tool,
            self.apologize_tool,
            self.llama_tool,
            Tool.from_function(
                func=table_func,
                name="Query table",
                description="useful for when you require to get specific data running queries for data sources that contain metrics in a tabular structure",
            ),
        )
        prompt = ZeroShotAgent.create_prompt(
            tools=self.tools,
            prefix=CUSTOM_PREFIX,
            suffix=CUSTOM_SUFFIX,
            format_instructions=CUSTOM_AGENT_FORMAT,
            input_variables=["input", "chat_history", "agent_scratchpad"],
        )
        self.memory = ConversationBufferMemory(
            memory_key="chat_history", output_key="output"
        )
        llm = ChatOpenAI(
            temperature=0.01, openai_api_key=OPENAI_API_KEY, model=self.index.llm.model
        )
        llm_chain = LLMChain(llm=llm, prompt=prompt)
        agent = ZeroShotAgent(llm_chain=llm_chain, tools=self.tools)
        self.agent_chain = AgentExecutor.from_agent_and_tools(
            agent=agent,
            tools=self.tools,
            memory=self.memory,
            handle_parsing_errors=True,
            return_intermediate_steps=True,
        )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment