Created
November 7, 2024 05:25
-
-
Save lemassykoi/bf094e9ba3de94c0b844f63844b4470f to your computer and use it in GitHub Desktop.
Ollama Langchain Structured Chat Agent en Français
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Script de gestion d'interaction avec un modèle linguistique d'intelligence artificielle (LLM). | |
Ce script permet l'échange bidirectionnel de messages entre un utilisateur humain et une intelligence artificielle, | |
en utilisant le modèle linguistique de son choix, avec Ollama. | |
Il incorpore des fonctionnalités telles que la recherche web, la résolution d'informations en temps réel, | |
et l'intégration de mémoires internes pour améliorer les capacités de réponse et d'apprentissage continu du système. | |
Auteurs: | |
[Clément PAPPALARDO] | |
Date de création: [2024/11/03] | |
Dernière mise à jour: [2024/11/07] | |
Licence: MIT License | |
Fonctionnalités clés: | |
- Interaction en direct avec l'utilisateur | |
- Gestion des sessions et conversations | |
- Utilisation d'outils intégrés pour la recherche web, les informations météorologiques et plus encore | |
- Sauvegarde et chargement de l'historique des conversations | |
""" | |
import os | |
import time | |
import logging | |
import requests | |
from colorama import Back, Style, Fore | |
from langchain.agents import AgentExecutor, create_structured_chat_agent | |
from langchain.agents.agent_toolkits import create_retriever_tool | |
from langchain_chroma import Chroma | |
from langchain_ollama import ChatOllama, OllamaEmbeddings | |
from langchain_community.utilities import GoogleSerperAPIWrapper | |
from langchain_community.document_loaders import TextLoader | |
from langchain_community.agent_toolkits.load_tools import load_tools | |
from langchain_community.retrievers import WikipediaRetriever | |
from langchain_core.messages import BaseMessage | |
from langchain_core.tools import Tool, StructuredTool | |
from langchain_core.tools import tool | |
from langchain_core.chat_history import BaseChatMessageHistory | |
from langchain_core.runnables.history import RunnableWithMessageHistory | |
from langchain.globals import set_debug | |
from langchain_text_splitters import RecursiveCharacterTextSplitter | |
from langchain.tools.render import render_text_description | |
from langgraph.graph import add_messages | |
from typing import TypedDict, Annotated, Sequence | |
import datetime | |
from typing import List | |
from pydantic import BaseModel, Field | |
from langchain_core.runnables import ConfigurableFieldSpec | |
from typing import Dict, Tuple | |
import atexit | |
import pickle | |
from langchain_core.prompts.chat import ChatPromptTemplate | |
# --- API credentials / telemetry --------------------------------------------
# NOTE(review): these assignments overwrite whatever the process environment
# already holds for the same variables; 'xxx' is a placeholder value.
os.environ["SERPER_API_KEY"] = "xxx"
os.environ["WOLFRAM_ALPHA_APPID"] = 'xxx'
# NOTE(review): presumably read by a dependency to disable telemetry - confirm which one.
os.environ["ANONYMIZED_TELEMETRY"] = 'False'
# NOTE(review): the two constants below are not referenced in the visible code.
GOOGLE_API_KEY = 'xxx'
GOOGLE_CSE_ID = 'xxx'
# Directory that receives the per-run log file; created on start-up if missing.
LOCAL_DIR = './save/'
os.makedirs(LOCAL_DIR, exist_ok=True)
# One log file per run, prefixed with the start-up timestamp.
LOG_SAVE_PATH = os.path.join(LOCAL_DIR, f"{str(datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S'))}_tests.log")
LOG_FILE = LOG_SAVE_PATH
# LANGCHAIN GLOBAL
set_debug(False)
# AGENT VERBOSE (passed as `verbose=` to the LLM, the tools and the executor)
is_verbose = False
## LOGGING LEVEL (INFO/DEBUG) - selects the console handler level
is_DEBUG = False
### LOGGER BEGIN | |
class CustomFormatter(logging.Formatter):
    """Colourised console formatter that pushes call-site info to the right.

    Each record is rendered as
    ``<date> - <message><padding>[<function>:<line>] - <LEVEL>`` wrapped in
    the ANSI colour sequence associated with the record's level. The padding
    is computed from the terminal width so that the call-site information is
    aligned towards the right-hand side of the terminal.
    """

    # ANSI escape sequences used to colour the output.
    grey = "\x1b[38;20m"
    yellow = "\x1b[33;20m"
    red = "\x1b[31;20m"
    bold_red = "\x1b[31;1m"
    reset = "\x1b[0m"
    F_LightGreen = "\x1b[92m"
    F_LightBlue = "\x1b[94m"
    underline = "\x1b[4m"

    # Placeholder format; replaced per record by the padded variant in format().
    base_format = "%(asctime)s - %(message)s"
    datefmt = '%d/%m/%Y | %H:%M:%S'

    # Level -> coloured format template.
    FORMATS = {
        logging.DEBUG: underline + grey + base_format + reset,
        logging.INFO: F_LightGreen + base_format + reset,
        logging.WARNING: yellow + base_format + reset,
        logging.ERROR: underline + red + base_format + reset,
        logging.CRITICAL: bold_red + base_format + reset
    }

    def format(self, record):
        """Return *record* rendered as a coloured, right-aligned log line.

        Args:
            record: the ``logging.LogRecord`` to render.

        Returns:
            The formatted log line as a string.
        """
        # Local import so the block does not depend on a file-level import.
        import shutil

        # shutil.get_terminal_size() falls back to a default size instead of
        # raising OSError when stdout is not attached to a terminal (pipe,
        # redirection, service, test runner), unlike os.get_terminal_size().
        terminal_width = shutil.get_terminal_size().columns
        # Call-site information appended after the message.
        func_info = f"[{record.funcName}:{record.lineno}]"
        # 20 is the approximate width of the rendered date; 16 is a manual adjustment.
        base_message_length = 20 + len(record.getMessage())
        # Never use a negative padding (message wider than the terminal).
        padding = max(terminal_width - base_message_length - len(func_info) - 16, 0)
        adjusted_format = f"%(asctime)s - %(message)s{' ' * padding}{func_info} - %(levelname)s"
        # Unknown (custom) levels fall back to the uncoloured base format
        # instead of crashing on None.replace(...).
        log_fmt = self.FORMATS.get(record.levelno, self.base_format)
        log_fmt = log_fmt.replace(self.base_format, adjusted_format)
        # Delegate the actual rendering to a standard Formatter with our date format.
        formatter = logging.Formatter(log_fmt, datefmt=self.datefmt)
        return formatter.format(record)
# Root logger: let every record through; the handlers do the filtering.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
# Only errors from these third-party loggers are of interest.
logging.getLogger("httpx").setLevel(logging.ERROR)
logging.getLogger("langchain_community.retrievers.web_research").setLevel(logging.ERROR)
# Console handler: coloured output, level driven by is_DEBUG.
ch = logging.StreamHandler()
ch.setLevel(logging.DEBUG if is_DEBUG else logging.INFO)
ch.setFormatter(CustomFormatter())
# File handler: plain text, always at DEBUG level.
fh = logging.FileHandler(LOG_FILE, encoding='utf-8')
fh.setLevel(logging.DEBUG)
fh.setFormatter(logging.Formatter(fmt="%(asctime)s - %(message)s"))
# Attach both handlers to the root logger (console first, then file).
logger.addHandler(ch)
logger.addHandler(fh)
### LOGGER END
print(f"\n{Style.RESET_ALL}")
logger.info(Back.RED + 'Loading Start...' + Style.RESET_ALL)
class AgentState(TypedDict):
    """Typed dictionary holding a sequence of chat messages.

    NOTE(review): not referenced elsewhere in the visible code.
    """
    # Message sequence annotated with `add_messages` (imported from langgraph.graph).
    messages: Annotated[Sequence[BaseMessage], add_messages]
class InMemoryHistory(BaseChatMessageHistory, BaseModel):
    """In memory implementation of chat message history."""
    # Messages recorded so far, in insertion order; a fresh list per instance.
    messages: List[BaseMessage] = Field(default_factory=list)

    def add_messages(self, messages: List[BaseMessage]) -> None:
        """Add a list of messages to the store"""
        self.messages.extend(messages)

    def clear(self) -> None:
        """Forget every stored message by rebinding `messages` to a new empty list."""
        self.messages = []
# Argument schema (`args_schema`) shared by the two Google Serper search tools.
class WebSearchArgs(BaseModel):
    # Search query text.
    query: str
# VARIABLES
# Base URL of the Ollama HTTP API.
ollama_base_url = 'http://127.0.0.1:11434'
# Chat model and embedding model names.
current_model_name = "qwen2.5:7b-instruct-q6_K"
current_embed_model = "snowflake-arctic-embed:latest"
# Sampling temperature passed to the chat model.
model_temperature = 0.3
# Store Variable: histories keyed by (session_id, thread_id)
store: Dict[Tuple[str, str], InMemoryHistory] = {}
# Filename for Text MEMO to incorporate to LLM Knowledge
file_path = 'vault.txt'
# Define the filename for history
HISTORY_FILE = "chat_history.pkl"
def get_num_ctx(model: str) -> int:
    """Return the context length Ollama reports for the chat model *model*.

    Sends ``POST <ollama_base_url>/api/show`` and returns the first
    ``model_info`` value whose key contains ``'context_length'``.

    Args:
        model: name of the Ollama model to inspect.

    Returns:
        The context length as an int. When the response holds no matching
        entry, an error is logged and ``None`` is returned.

    Raises:
        requests.RequestException: on connection failure or timeout.
    """
    # A timeout keeps start-up from hanging forever on an unresponsive server.
    response = requests.post(
        ollama_base_url + '/api/show',
        json={"name": model},
        timeout=30,
    )
    # `model_info` may be missing from the payload (presumably for an unknown
    # model); treat that as "nothing found" rather than crashing on None.
    model_info = response.json().get('model_info') or {}
    for key, value in model_info.items():
        if 'context_length' in key:
            return int(value)
    # Report the model actually queried, not the module-level default.
    logger.error(f'No num_ctx found for chat model name {model}')
    return None
def get_embed_num_ctx(model: str) -> int:
    """Return the ``embedding_length`` Ollama reports for the embedding model *model*.

    Sends ``POST <ollama_base_url>/api/show`` and returns the first
    ``model_info`` value whose key contains ``'embedding_length'``.

    NOTE(review): 'embedding_length' looks like the embedding vector size
    rather than a context window, although the function name says num_ctx -
    confirm this is the intended value.

    Args:
        model: name of the Ollama embedding model to inspect.

    Returns:
        The value as an int. When the response holds no matching entry, an
        error is logged and ``None`` is returned.

    Raises:
        requests.RequestException: on connection failure or timeout.
    """
    # A timeout keeps start-up from hanging forever on an unresponsive server.
    response = requests.post(
        ollama_base_url + '/api/show',
        json={"name": model},
        timeout=30,
    )
    # `model_info` may be missing from the payload (presumably for an unknown
    # model); treat that as "nothing found" rather than crashing on None.
    model_info = response.json().get('model_info') or {}
    for key, value in model_info.items():
        if 'embedding_length' in key:
            return int(value)
    # Report the model actually queried, not the module-level default.
    logger.error(f'No num_ctx found for embedding model name {model}')
    return None
# Query Ollama once at start-up for the sizes of both models.
model_ctx = get_num_ctx(current_model_name)
embed_ctx = get_embed_num_ctx(current_embed_model)
# Chat model used by the agent; output is requested in JSON format.
llm_func = ChatOllama(
    model = current_model_name,
    base_url = ollama_base_url,
    format = "json",
    num_ctx = model_ctx,
    temperature = model_temperature,
    verbose = is_verbose,
)
# Embedding model used to index the local memo file.
embed_func = OllamaEmbeddings(
    model = current_embed_model,
    base_url = ollama_base_url,
)
def get_session_history(session_id: str, thread_id: str) -> BaseChatMessageHistory:
    """Return the history stored for (session_id, thread_id), creating it on first use."""
    history_key = (session_id, thread_id)
    if history_key in store:
        return store[history_key]
    # First message of this conversation: start with an empty history.
    logger.warning('conversation not in store -> creating store for current thread and session')
    store[history_key] = InMemoryHistory()
    return store[history_key]
# Load the conversation store from the pickle file, when one is available.
def load_history():
    """Restore the global ``store`` from ``HISTORY_FILE`` if that file exists.

    When the file is missing, unreadable, corrupt, or does not contain a
    dict, ``store`` is left untouched and the problem is logged, so a damaged
    history file cannot prevent the script from starting.
    """
    global store
    if not os.path.exists(HISTORY_FILE):
        logger.warning(Fore.RED + f"No history file found: {HISTORY_FILE}" + Style.RESET_ALL)
        return
    # SECURITY: unpickling can execute arbitrary code contained in the file.
    # Only load history files written by this script on a trusted machine.
    try:
        with open(HISTORY_FILE, "rb") as file:
            loaded = pickle.load(file)
    except (OSError, EOFError, pickle.UnpicklingError,
            AttributeError, ImportError, IndexError) as err:
        # Exception types the pickle documentation lists for damaged input.
        logger.error(f"Could not load history file {HISTORY_FILE}: {err!r} -> keeping current history")
        return
    if not isinstance(loaded, dict):
        # `store` is used as a dict everywhere else; refuse anything different.
        logger.error(f"History file {HISTORY_FILE} does not contain a dict -> keeping current history")
        return
    store = loaded
    logger.warning(Fore.YELLOW + f"Found and loaded History file: {HISTORY_FILE}" + Style.RESET_ALL)
    logger.debug(store)
# Persist `store` into the pickle file.
def save_history_on_exit():
    """Write the global ``store`` to ``HISTORY_FILE``.

    The data is first pickled into a temporary sibling file which then
    replaces ``HISTORY_FILE`` via ``os.replace``; if pickling or writing
    fails midway, the previously saved history is left untouched.

    Raises:
        Exception: whatever ``open`` / ``pickle.dump`` / ``os.replace`` raised,
            re-raised after the temporary file has been removed.
    """
    tmp_path = HISTORY_FILE + ".tmp"
    try:
        with open(tmp_path, "wb") as file:
            pickle.dump(store, file)
        os.replace(tmp_path, HISTORY_FILE)
    except Exception:
        # Do not leave a partial temporary file behind.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    logger.warning(f"Historique sauvegardé dans {HISTORY_FILE}")
    logger.debug(store)
def get_current_datetime():
    """Return the local date and time formatted as 'YYYY-MM-DD HH:MM:SS'."""
    now = datetime.datetime.now()
    return f"{now:%Y-%m-%d %H:%M:%S}"
# Log a summary of the active configuration.
logger.info(Back.RED + 'Funct Model :' + Style.RESET_ALL + ' ' + f'{current_model_name}')
logger.info(Back.RED + 'Funct num_ctx :' + Style.RESET_ALL + ' ' + f'{str(model_ctx)}')
logger.info(Back.RED + 'Embed Model :' + Style.RESET_ALL + ' ' + f'{current_embed_model}')
logger.info(Back.RED + 'Embed num_ctx :' + Style.RESET_ALL + ' ' + f'{str(embed_ctx)}')
logger.info(Back.RED + 'Temperature :' + Style.RESET_ALL + ' ' + f'{str(model_temperature)}')
# Vault Embedding: load the memo file and split it into overlapping chunks.
logger.info(f'Reading content of file {file_path}')
text_loader = TextLoader(file_path, encoding='UTF-8')
docs = text_loader.load()
# NOTE(review): chunk_size comes from get_embed_num_ctx(), which may return
# None when the model reports no matching entry - confirm this cannot happen.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=embed_ctx, chunk_overlap=128)
splits = text_splitter.split_documents(docs)
# Initialize Embedding: build a Chroma vector store from the chunks.
logger.info('Embedding...')
chroma_vector_store_docs = Chroma.from_documents(
    documents = splits,
    embedding = embed_func,
)
# Build Retrievers
logger.info('Retrievers...')
# Memo retriever: configured with k=3.
retriever_docs = chroma_vector_store_docs.as_retriever(search_kwargs={'k': 3})
# Wikipedia retrievers, configured with top_k_results=3 per language.
retriever_wikipedia_fr = WikipediaRetriever(top_k_results=3, lang='fr', name='French Wikipedia Retriever')
# NOTE(review): the English retriever is not wired into any tool in the visible code.
retriever_wikipedia_en = WikipediaRetriever(top_k_results=3, lang='en', name='English Wikipedia Retriever')
# Google Serper wrappers configured with gl='fr' / hl='fr': default search and news search.
google_search_fr = GoogleSerperAPIWrapper(gl='fr', hl='fr')
google_news_fr = GoogleSerperAPIWrapper(gl='fr', hl='fr', type='news')
# Tools Part
logger.info('Tools...')
def getPrice(input) -> float:
    """Return the USD price of a crypto asset as reported by the CoinCap API.

    Args:
        input: asset identifier appended to the CoinCap ``/v2/assets/`` URL
            (presumably an id such as "bitcoin" - confirm against the API).
            Surrounding whitespace is ignored and the value is lower-cased.

    Returns:
        The ``data.priceUsd`` field of the response, converted to float.

    Raises:
        requests.RequestException: on connection failure or timeout.
        KeyError: when the response has no ``data`` / ``priceUsd`` entry.
        ValueError: when the body is not JSON or the price is not numeric.
    """
    # Models sometimes pass padded values; normalise before building the URL.
    asset_id = str(input).strip().lower()
    url = "https://api.coincap.io/v2/assets/" + asset_id
    # A timeout keeps a stalled API call from blocking the agent forever.
    response = requests.get(url, timeout=10).json()
    price = response["data"]["priceUsd"]
    # Convert so the runtime value matches the declared float return type
    # (the API is understood to deliver the price as a string - confirm).
    return float(price)
# Tool exposing getPrice() to the agent. The name and description strings are
# kept in French on purpose: they are rendered into the prompt.
tool_get_crypto_price = [
    Tool(
        name = "Obtenir le cours d'une cryptomonnaie, en USD.",
        func = getPrice,
        description = "Utilisez cette fonction pour obtenir le prix d'une cryptomonnaie donnée à partir de la saisie de l'utilisateur. Renvoie le prix en USD.",
    )
]
# General web search through Google Serper; takes a single `query` argument.
complex_tool_google_search_fr = [
    StructuredTool.from_function(
        name = "Recherche Web sur Google Serper en Français (recherche globale)",
        description = "Utilisez cette fonction lorsque vous devez effectuer une recherche en ligne sur Internet General, en français. Renvoie un résultat JSON.",
        args_schema = WebSearchArgs,
        return_direct = False,
        verbose = is_verbose,
        parse_docstring = True,
        func = google_search_fr.results, ## with metadatas
    )
]
# News search through Google Serper; same argument schema as the general search.
complex_tool_google_news_fr = [
    StructuredTool.from_function(
        name = "Recherche Web sur Google Serper en Français (recherche parmi les Actualités)",
        description = "Utilisez cette fonction lorsque vous devez effectuer une recherche en ligne sur Internet à propos des Actualités, en langue française. Renvoie un résultat JSON.",
        args_schema = WebSearchArgs,
        return_direct = False,
        verbose = is_verbose,
        parse_docstring = True,
        func = google_news_fr.results, ## with metadatas
    )
]
# Retriever tool over the French Wikipedia retriever (arguments: retriever, name, description).
tool_query_wiki_fr = create_retriever_tool(
    retriever_wikipedia_fr,
    "Recherche sur Wikipedia FR",
    "Recherche et renvoie des documents depuis le Wikipedia Français."
)
# Retriever tool over the local memo vector store.
tool_parse_txt = create_retriever_tool(
    retriever_docs,
    "Recherche dans le Memo Interne",
    "Recherche et renvoie des documents à partir du mémo interne. Utile lorsque vous devez répondre à des questions sur Jouques, Claude ou sur un sujet lié à l'assistance informatique locale.",
)
# Tool returning the current local date/time in ISO format.
# NOTE(review): the French docstring below is presumably used by @tool as the
# tool description shown to the model, so it is runtime text and is left as is.
# The `input` parameter is accepted but never read.
@tool("obtenir_date_et_heure_courante", return_direct=False)
def obtenir_date_et_heure_courante(input: str) -> str:
    """
    Utilisez cette fonction pour obtenir la date et l'heure actuelle, selon le fuseau horaire local du système en cours d'exécution.
    Date et heure actuelles au format ISO. Vous devrez les formater dans un format compréhensible par l'homme.
    Arguments:
    Aucun.
    """
    return datetime.datetime.now().isoformat() # '%Y-%m-%dT%H:%M:%S'
# Complete tool list handed to the agent, in the order the tools are rendered.
all_tools = [
    *load_tools(['wolfram-alpha'], llm=llm_func),
    *complex_tool_google_search_fr,
    *complex_tool_google_news_fr,
    *tool_get_crypto_price,
    obtenir_date_et_heure_courante,
    tool_parse_txt,
    tool_query_wiki_fr,
]
# Save STORE to disk at script exit
atexit.register(save_history_on_exit)
## PROMPT ##
# Chat prompt for the structured chat agent. Template variables:
#   {tools}, {tool_names}                     - tool descriptions and names
#   {current_date_time}, {additional_orders}  - supplied with each query
#   {chat_history}                            - placeholder for earlier messages
#   {input}, {agent_scratchpad}               - user query and agent working notes
# Doubled braces ({{ }}) are literal braces in the rendered prompt.
prompt = ChatPromptTemplate.from_messages(
    [
        ("system",
"""Respond to the human as helpfully and accurately as possible. You have access to the following tools:
{tools}
Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
Valid "action" values: "Final Answer" or {tool_names}
Provide only ONE action per $JSON_BLOB, as shown:
```
{{
"action": $TOOL_NAME,
"action_input": $INPUT
}}
```
Follow this format:
Question: input question to answer
Thought: consider previous and subsequent steps
Action:
```
$JSON_BLOB
```
Observation: action result
... (repeat Thought/Action/Observation N times)
Thought: I know what to respond
Action:
```
{{
"action": "Final Answer",
"action_input": "Final response to human"
}}
Begin! Reminder to ALWAYS respond with a valid JSON blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB``` then Observation:\n
Nota Bene: The current ISO Date/Time is: {current_date_time}\n
Additional Directives (if any): {additional_orders}
"""),
        ("placeholder", "{chat_history}"),
        ("human", "{input}\n{agent_scratchpad}\n(reminder to respond in a JSON blob no matter what)"),
    ]
)
# Extra directive injected into the system prompt with every query.
additional_orders = "Your name is Jarvis, inspired by Iron Man's Assistant."
# Dump the raw prompt template when verbose mode is on.
if is_verbose is True:
    print(type(prompt))
    print(prompt)
## Edit Prompt: bind the tool descriptions and names once, up front.
prompt = prompt.partial(
    tools=render_text_description(all_tools),
    tool_names=", ".join(t.name for t in all_tools),
)
## AGENT ##
# Build the structured chat agent from the LLM, the tool list and the prompt.
structured_agent = create_structured_chat_agent(llm_func, all_tools, prompt)
# CONVERSATION ID
thread_id = "TESTING"
# USER ID
# NOTE(review): this is an int although the history helpers annotate
# session_id as str; changing its type would change the key used in `store`
# (and in any previously saved history file).
session_id = 123456789
# Load the history when the script starts
logger.info("Loading Chat History from disk, if any")
load_history()
def construct_chain():
    """Create a new AgentExecutor and wrap it with per-conversation message history.

    Returns:
        A RunnableWithMessageHistory whose history is looked up through
        get_session_history using the configurable ``session_id`` and
        ``thread_id`` fields.
    """
    def _history_key_spec(field_id, label, summary):
        # Both history keys share everything except id, name and description.
        return ConfigurableFieldSpec(
            id=field_id,
            annotation=str,
            name=label,
            description=summary,
            default="",
            is_shared=True,
        )

    executor = AgentExecutor(
        agent=structured_agent,
        tools=all_tools,
        handle_parsing_errors=False,
        return_intermediate_steps=True,
        verbose=is_verbose,
    )
    history_specs = [
        _history_key_spec("session_id", "User ID", "Unique identifier for the user."),
        _history_key_spec("thread_id", "Conversation ID", "Unique identifier for the conversation."),
    ]
    return RunnableWithMessageHistory(
        executor,
        get_session_history,
        input_messages_key="input",
        history_messages_key="chat_history",
        history_factory_config=history_specs,
    )
def chain_to_llm(query: str, session_id: str, thread_id: str) -> str:
    """Send *query* to the agent and return its answer as display text.

    Args:
        query: the user's message.
        session_id: user identifier, first half of the history key.
        thread_id: conversation identifier, second half of the history key.

    Returns:
        The agent's ``output`` followed by a "Duration" line, or a fixed
        apology message when anything goes wrong while running the agent.
        Failures are logged and never propagate to the caller.
    """
    logger.debug(f'LLM Query : {query}')
    config = {
        "configurable": {
            "session_id": session_id,
            "thread_id": thread_id,
        }
    }
    try:
        start_time = time.perf_counter()
        agent_with_chat_history = construct_chain()
        response = agent_with_chat_history.invoke(
            {
                "input": query,
                "current_date_time": get_current_datetime(),
                "additional_orders": additional_orders,
            },
            config = config,
        )
        elapsed_time = time.perf_counter() - start_time
        answer = response["output"] + f"\nDuration : {elapsed_time:.2f} seconds"
    except Exception as e:
        # Top-level boundary of the chat loop: keep the session alive, but log
        # the full traceback so the failure can actually be diagnosed.
        logger.exception(f'=== ERROR : {e}')
        answer = "Je suis désolé, j'ai rencontré une erreur :( Essayez à nouveau svp."
    logger.debug(f'OUTGOING Answer for {session_id} :\n{answer}')
    return answer
logger.info(Back.RED + 'Loading Done.' + Style.RESET_ALL)
logger.info('Running...')
# Interactive loop: read a line from the user, print the agent's answer.
# Ends on CTRL+C (also while an answer is being generated) or end of input.
while True:
    try:
        query = input(f"\n{Fore.BLUE}Humain : ")
        print(f"\n{Style.RESET_ALL}")
        # Nothing to ask: prompt again instead of sending a blank query.
        if not query.strip():
            continue
        print(Fore.GREEN + "\nI.A. : " + chain_to_llm(query, session_id, thread_id) + Style.RESET_ALL)
    except KeyboardInterrupt:
        # Reset the colour so the terminal is not left tinted after exit.
        print(Style.RESET_ALL + "\nCTRL+C - Exiting\n")
        # SystemExit instead of exit(): exit() only exists when the `site`
        # module is loaded. Registered atexit handlers still run.
        raise SystemExit(0)
    except EOFError:
        # stdin closed (CTRL+D or end of piped input): leave cleanly as well.
        print(Style.RESET_ALL + "\nEOF - Exiting\n")
        raise SystemExit(0)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment