#!/usr/bin/env python3
import os
import sys
import argparse

from dotenv import load_dotenv
from langchain.chains import RetrievalQA
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.vectorstores import Chroma
from langchain.llms import GPT4All, LlamaCpp

load_dotenv()

# Configuration comes from environment variables, typically loaded from a .env file
embeddings_model_name = os.environ.get("EMBEDDINGS_MODEL_NAME")
persist_directory = os.environ.get('PERSIST_DIRECTORY')
model_type = os.environ.get('MODEL_TYPE')
model_path = os.environ.get('MODEL_PATH')
model_n_ctx = int(os.environ.get('MODEL_N_CTX'))  # context size; cast to int for the LLM constructors
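# An example .env might look like this (values are illustrative assumptions,
# not taken from this gist):
#   EMBEDDINGS_MODEL_NAME=all-MiniLM-L6-v2
#   PERSIST_DIRECTORY=db
#   MODEL_TYPE=GPT4All
#   MODEL_PATH=models/ggml-gpt4all-j-v1.3-groovy.bin
#   MODEL_N_CTX=1000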

from constants import CHROMA_SETTINGS
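# constants.py is not part of this gist; it is expected to export CHROMA_SETTINGS.
# A minimal sketch, assuming the chromadb 0.3-era Settings API:
#   from chromadb.config import Settings
#   CHROMA_SETTINGS = Settings(chroma_db_impl='duckdb+parquet',
#                              persist_directory=persist_directory,
#                              anonymized_telemetry=False)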


def main(query):
    # Parse the command line arguments
    args = parse_arguments()
    embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)
    db = Chroma(persist_directory=persist_directory, embedding_function=embeddings, client_settings=CHROMA_SETTINGS)
    retriever = db.as_retriever()
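    # as_retriever() uses the store's default similarity search; the number of
    # retrieved chunks could be tuned here, e.g. (illustrative, not in the
    # original): retriever = db.as_retriever(search_kwargs={"k": 4})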
    # activate/deactivate the streaming StdOut callback for LLMs
    callbacks = [] if args.mute_stream else [StreamingStdOutCallbackHandler()]
    # Prepare the LLM
    match model_type:
        case "LlamaCpp":
            llm = LlamaCpp(model_path=model_path, n_ctx=model_n_ctx, callbacks=callbacks, verbose=False)
        case "GPT4All":
            llm = GPT4All(model=model_path, n_ctx=model_n_ctx, backend='gptj', callbacks=callbacks, verbose=False)
        case _:
            print(f"Model {model_type} not supported!")
            sys.exit(1)
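    # chain_type="stuff" concatenates all retrieved chunks into a single prompt;
    # LangChain also offers other chain types such as "map_reduce" and "refine",
    # which are not used here.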
    qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, return_source_documents=not args.hide_source)
    # Get the answer from the chain
    res = qa(query)
    answer, docs = res['result'], [] if args.hide_source else res['source_documents']

    # Print the result
    # print("\n\n> Question:")
    # print(query)
    print("\n> Answer:")
    print(answer)

    # Print the relevant sources used for the answer
    # for document in docs:
    #     print("\n> " + document.metadata["source"] + ":")
    #     print(document.page_content)
    return answer


def parse_arguments():
    parser = argparse.ArgumentParser(description='privateGPT: Ask questions to your documents without an internet connection, '
                                                 'using the power of LLMs.')
    parser.add_argument("--hide-source", "-S", action='store_true',
                        help='Use this flag to disable printing of source documents used for answers.')
    parser.add_argument("--mute-stream", "-M", action='store_true',
                        help='Use this flag to disable the streaming StdOut callback for LLMs.')
    parser.add_argument("--query", "-Q", required=True,
                        help='Use this argument to pass your question/query.')
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_arguments()
    main(args.query)
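
# Example invocation, assuming the file is saved as privateGPT.py (the filename
# is an assumption) and the .env and Chroma index are already in place:
#   python privateGPT.py --query "What does the document say about X?" --mute-stream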