|
import os |
|
from typing import Optional |
|
|
|
import boto3 |
|
from langchain.chains import RetrievalQA |
|
from langchain.document_loaders import ConfluenceLoader |
|
from langchain.embeddings import BedrockEmbeddings |
|
from langchain.llms.bedrock import Bedrock |
|
from langchain.schema import Document |
|
from langchain.text_splitter import RecursiveCharacterTextSplitter |
|
from langchain.vectorstores import Chroma |
|
|
|
|
|
class ConfluenceQA:
    """Question answering over a Confluence space using AWS Bedrock models.

    Builds (or reloads) a Chroma vector store from Confluence pages, then
    answers questions with a RetrievalQA chain backed by Anthropic Claude v2
    on Bedrock, using Titan embeddings for retrieval.
    """

    def __init__(
        self,
        profile_name: str,
    ):
        """Create the Bedrock client, LLM, and embedding model.

        Args:
            profile_name: AWS named profile used to build the Bedrock client.
        """
        # Populated later by load_confluence_documents().
        self.vectordb = None
        self.client = self._set_client(profile_name)
        self.llm = Bedrock(
            model_id="anthropic.claude-v2",
            client=self.client,
            model_kwargs={"max_tokens_to_sample": 1000},
        )
        self.embedding = BedrockEmbeddings(
            model_id="amazon.titan-embed-text-v1", client=self.client
        )

    def _set_client(
        self,
        profile_name: str,
        region_name: Optional[str] = None,
        endpoint_url: Optional[str] = None,
    ):
        """Create and cache a boto3 ``bedrock-runtime`` client.

        Args:
            profile_name: AWS named profile; when falsy, falls back to the
                default boto3 credential chain.
            region_name: Optional region override.
            endpoint_url: Optional endpoint override (e.g. a VPC endpoint).

        Returns:
            The boto3 ``bedrock-runtime`` client (also stored on ``self.client``).

        Raises:
            ValueError: If the session or client cannot be created
                (chained from the underlying boto3 error).
        """
        try:
            if profile_name:
                session = boto3.Session(profile_name=profile_name)
            else:
                session = boto3.Session()

            client_params = {}
            if region_name:
                client_params["region_name"] = region_name
            if endpoint_url:
                client_params["endpoint_url"] = endpoint_url
            self.client = session.client("bedrock-runtime", **client_params)

            return self.client

        except Exception as e:
            raise ValueError(
                "Could not load credentials to authenticate with AWS client."
            ) from e

    def load_confluence_documents(
        self,
        persist_directory: str,
        space_key: Optional[str] = None,
        username: Optional[str] = None,
        token: Optional[str] = None,
        max_pages: int = 2000,
        force_reload: bool = False,
        confluence_url: str = "https://treasure-data.atlassian.net/wiki",
    ) -> None:
        """Populate ``self.vectordb`` from Confluence or an existing store.

        If ``persist_directory`` already exists and ``force_reload`` is False,
        the persisted Chroma store is reopened; otherwise pages are fetched
        from Confluence, chunked, embedded, and written into a new store.

        Args:
            persist_directory: Directory where the Chroma store lives.
            space_key: Confluence space to load (passed to the loader).
            username: Confluence account email for API authentication.
            token: Confluence API token for ``username``.
            max_pages: Maximum number of pages to fetch.
            force_reload: When True, rebuild the store even if one exists.
            confluence_url: Base Confluence wiki URL; defaults to the
                original hard-coded instance for backward compatibility.
        """
        # NOTE: the original annotated this as -> list[Document], but no
        # branch ever returned a value; the method only sets self.vectordb.
        if persist_directory and os.path.exists(persist_directory) and not force_reload:
            self.vectordb = Chroma(
                persist_directory=persist_directory, embedding_function=self.embedding
            )
        else:
            loader = ConfluenceLoader(
                url=confluence_url,
                username=username,  # "[email protected]"
                api_key=token,
            )
            documents = loader.load(space_key=space_key, max_pages=max_pages)
            # Default splitter settings used by load_and_split.
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=1000,
                chunk_overlap=100,
            )
            docs = text_splitter.split_documents(documents)

            # NOTE(review): older chromadb versions may require an explicit
            # self.vectordb.persist() to flush to disk — confirm against the
            # pinned langchain/chromadb versions.
            self.vectordb = Chroma.from_documents(
                documents=docs,
                embedding=self.embedding,
                persist_directory=persist_directory,
            )

    def retrieval_qa_chain(self):
        """Build the RetrievalQA chain over the loaded vector store.

        Must be called after ``load_confluence_documents``; sets
        ``self.retriever`` and ``self.qa``.
        """
        self.retriever = self.vectordb.as_retriever()
        self.qa = RetrievalQA.from_chain_type(
            llm=self.llm,
            retriever=self.retriever,
            return_source_documents=True,
        )

    def answer_confluence(self, question: str) -> str:
        """Return the chain's answer to ``question``.

        With ``return_source_documents=True`` the chain has two output keys
        ("result" and "source_documents"), so ``Chain.run`` raises
        ValueError; call the chain directly and extract "result" instead.
        """
        return self.qa({"query": question})["result"]