A canonical RAG sample application using a furiosa-llm server, OpenAI embeddings, the Chroma vector DB, and the LangChain framework.
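Prerequisites (a sketch of assumptions, not stated in the gist): the LangChain-era packages the imports reference, e.g. pip install langchain openai chromadb python-dotenv tiktoken; an OPENAI_API_KEY entry in a local .env file for the embedding calls; and a furiosa-llm server already serving an OpenAI-compatible API at http://localhost:8000/v1 (see the FuriosaAI docs for how to launch it).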
# A canonical RAG sample application
# Uses (furiosa-llm server + OpenAI embeddings + Chroma Vector DB + LangChain framework)
# 100% cursor-generated code
# Needs a text file in the documents directory;
# I used the doc, https://gist.github.com/wey-gu/75d49362d011a0f0354d39e396404ba2
from typing import Dict, Optional

# langchain<0.1-style imports; newer releases move these into
# langchain_community / langchain_openai
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import Chroma
from langchain.chat_models import ChatOpenAI
from langchain.chains import RetrievalQA
from langchain.document_loaders import TextLoader, DirectoryLoader
from dotenv import load_dotenv

# Load environment variables
load_dotenv()
class SimpleRAG:
    def __init__(self,
                 openai_api_key: Optional[str] = None,
                 embedding_model: str = "text-embedding-ada-002",
                 persist_directory: str = "chroma_db"):
        """
        Initialize the RAG system.

        Args:
            openai_api_key: OpenAI API key (optional if set in the environment)
            embedding_model: Name of the embedding model to use
            persist_directory: Directory to persist the vector database
        """
        # Wire the optional key through; OpenAIEmbeddings falls back to the
        # OPENAI_API_KEY environment variable when this is None.
        self.embeddings = OpenAIEmbeddings(model=embedding_model,
                                           openai_api_key=openai_api_key)
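        # The chat model targets a local furiosa-llm server exposing an
        # OpenAI-compatible API. Such servers typically don't check
        # credentials, so "EMPTY" is a conventional placeholder API key; the
        # "EMPTY" model name assumes the server ignores or defaults it to the
        # model chosen at launch (an assumption, not stated in the gist).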
        self.llm = ChatOpenAI(
            model_name="EMPTY",
            temperature=0.7,
            openai_api_key="EMPTY",
            openai_api_base="http://localhost:8000/v1"
        )
        self.persist_directory = persist_directory
        self.vectorstore = None

        # Initialize the text splitter (1000-char chunks, 200-char overlap)
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
            length_function=len,
        )
    def load_documents(self, directory_path: str) -> None:
        """
        Load text documents from a directory and create embeddings.

        Args:
            directory_path: Path to the directory containing text files
        """
        loader = DirectoryLoader(directory_path, glob="**/*.txt", loader_cls=TextLoader)
        documents = loader.load()
        texts = self.text_splitter.split_documents(documents)

        # Create and persist the vector store
        self.vectorstore = Chroma.from_documents(
            documents=texts,
            embedding=self.embeddings,
            persist_directory=self.persist_directory
        )
        self.vectorstore.persist()
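        # Note (sketch, not in the original gist): on later runs the persisted
        # store could be reloaded without re-embedding, e.g.
        #   self.vectorstore = Chroma(persist_directory=self.persist_directory,
        #                             embedding_function=self.embeddings)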
    def query(self, question: str, k: int = 4) -> Dict:
        """
        Query the RAG system with a question.

        Args:
            question: The question to ask
            k: Number of relevant documents to retrieve

        Returns:
            Dict containing the answer and source documents
        """
        if not self.vectorstore:
            raise ValueError("No documents loaded. Please load documents first.")

        # Create the retriever
        retriever = self.vectorstore.as_retriever(
            search_type="similarity",
            search_kwargs={"k": k}
        )

        # Create the QA chain ("stuff" packs all retrieved chunks into one prompt)
        qa_chain = RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type="stuff",
            retriever=retriever,
            return_source_documents=True
        )

        # Get the response
        result = qa_chain({"query": question})
        return {
            "answer": result["result"],
            "sources": [doc.page_content for doc in result["source_documents"]]
        }
# Example usage
if __name__ == "__main__":
    # Initialize the RAG system
    rag = SimpleRAG()

    # Load documents from a directory
    rag.load_documents("documents")

    # question = "What is the main topic of the documents?"
    question = "What did Paul do in 2015?"
    result = rag.query(question)

    print(f"Question: {question}")
    print(f"\nAnswer: {result['answer']}")
    print("\nSources:")
    for i, source in enumerate(result['sources'], 1):
        print(f"\n{i}. {source[:200]}...")