from flask import Flask, jsonify, request
import requests
import PyPDF2
import tempfile
import pickle
import functools
import retrying
from langchain.llms import OpenAI
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
from langchain.docstore.document import Document
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores.faiss import FAISS
from langchain.text_splitter import CharacterTextSplitter
from langchain.prompts import PromptTemplate
import tiktoken
import hashlib
import os

app = Flask(__name__)  # Initialize the Flask app

# In a real application, you would want to use a more secure method for
# generating and storing tokens (e.g. environment variables or a secrets manager).
TOKEN = ""
openai_api_key = "sk-"
k = 3  # Default number of similar chunks to retrieve (overridden per request)

def remove_substring(url):
    """Strip the ".pdf" extension and "https://" prefix to get a bare file name."""
    try:
        url = url[:-4]  # Drop the 4-character ".pdf" extension
        nameoffile = url.replace("https://", "")
        return nameoffile
    except Exception:
        return url

def generate_hash(url):
    """Derive a stable, filesystem-safe cache name for a URL via SHA-256."""
    try:
        # Encode the URL as bytes
        url_bytes = url.encode('utf-8')
        # Generate a hash object using the SHA-256 algorithm
        hash_object = hashlib.sha256(url_bytes)
        # Convert the hash object to a hexadecimal string
        hex_dig = hash_object.hexdigest()
        # Split the hash string into 4 parts of equal length
        split_len = len(hex_dig) // 4
        hash_parts = [
            hex_dig[i:i + split_len] for i in range(0, len(hex_dig), split_len)
        ]
        # Join the hash parts with dashes and return as a string
        return '-'.join(hash_parts)
    except Exception as e:
        print(f"Error generating hash: {e}")
def num_tokens_from_string(string: str, encoding_name: str) -> int:
    """Returns the number of tokens in a text string."""
    encoding = tiktoken.get_encoding(encoding_name)
    num_tokens = len(encoding.encode(string))
    return num_tokens
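
# Example usage ("gpt2" is one of tiktoken's built-in encodings):
#   num_tokens_from_string("hello world", "gpt2")  # -> 2 under the GPT-2 BPE
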
# Decorator function to check if the request contains a valid token
def require_token(f):
    @functools.wraps(f)  # Preserve the wrapped view's name for Flask's routing
    def wrapper(*args, **kwargs):
        token = request.headers.get('Authorization')
        if not token or token != f'Token {TOKEN}':
            return jsonify({"error": "Unauthorized"}), 401
        return f(*args, **kwargs)
    return wrapper

@app.route('/', methods=["GET"])  # Home page
@require_token
def home():
    return "<h1>Home for 'insert api name here' api.</h1>"

template = """Given the following extracted parts of a long document and a question, create a final answer with references ("SOURCES"). | |
If you don't know the answer, just say that you don't know. Don't try to make up an answer. | |
ALWAYS return a list of "SOURCES" part in your answer. | |
QUESTION: {question} | |
========= | |
{context} | |
========= | |
FINAL ANSWER:""" | |
PROMPT = PromptTemplate(template=template, | |
input_variables=["context", "question"]) | |
@retrying.retry(retry_on_exception=lambda x: isinstance(x, requests.exceptions.RequestException),
                stop_max_attempt_number=5,
                wait_exponential_multiplier=1000,
                wait_exponential_max=10000)
def get_pdf_text(pdf_url):
    """Download a PDF and yield one Document per page (pages numbered from 1)."""
    with tempfile.NamedTemporaryFile(suffix=".pdf") as pdf_file:
        response = requests.get(pdf_url, timeout=30)
        response.raise_for_status()  # Fail fast on HTTP errors
        pdf_file.write(response.content)
        pdf_file.seek(0)
        pdf_reader = PyPDF2.PdfReader(pdf_file)
        for page_number in range(len(pdf_reader.pages)):
            page = pdf_reader.pages[page_number]
            yield Document(page_content=page.extract_text(),
                           metadata={"source": page_number + 1})

def get_file_text(file_url):
    """Download a PDF or plain-text file and yield Documents with source metadata."""
    _, ext = os.path.splitext(file_url)
    if ext.lower() == ".pdf":
        with tempfile.NamedTemporaryFile(suffix=".pdf") as pdf_file:
            response = requests.get(file_url, timeout=30)
            response.raise_for_status()  # Fail fast on HTTP errors
            pdf_file.write(response.content)
            pdf_file.seek(0)
            pdf_reader = PyPDF2.PdfReader(pdf_file)
            for page_number in range(len(pdf_reader.pages)):
                page = pdf_reader.pages[page_number]
                yield Document(page_content=page.extract_text(),
                               metadata={"source": file_url,
                                         "page_number": page_number + 1})
    elif ext.lower() == ".txt":
        with tempfile.NamedTemporaryFile(suffix=".txt") as txt_file:
            response = requests.get(file_url, timeout=30)
            response.raise_for_status()  # Fail fast on HTTP errors
            txt_file.write(response.content)
            txt_file.seek(0)
            # NamedTemporaryFile is opened in binary mode, so decode the bytes
            text = txt_file.read().decode("utf-8")
            yield Document(page_content=text,
                           metadata={"source": file_url})
    else:
        raise ValueError("Unsupported file type")
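
# Example usage (hypothetical URL; each page comes back as a langchain Document):
#   docs = list(get_file_text("https://example.com/report.pdf"))
#   docs[0].metadata  # -> {"source": "https://example.com/report.pdf", "page_number": 1}
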
@app.route('/v1/api/pdf_to_answer', methods=["POST"])  # Question-answering endpoint
@require_token
@retrying.retry(retry_on_exception=lambda x: isinstance(x, requests.exceptions.RequestException),
                stop_max_attempt_number=5,
                wait_exponential_multiplier=1000,
                wait_exponential_max=10000)
def pdf_to_answer():
    json_data = request.get_json()
    question = json_data.get('question')
    pdf_url = json_data.get('pdf_url')
    file_name = generate_hash(pdf_url)  # Cache file name derived from the URL
    k = 4  # Number of similar chunks to retrieve
    if not all([question, pdf_url, file_name]):
        return jsonify({"error": "Missing required parameters"}), 400
    search_index = None
    chain = load_qa_with_sources_chain(OpenAI(temperature=0,
                                              openai_api_key=openai_api_key),
                                       chain_type="stuff",
                                       prompt=PROMPT)
    try:
        # Reuse a previously pickled FAISS index for this URL if one exists
        with open(file_name, "rb") as f:
            search_index = pickle.load(f)
    except FileNotFoundError:
        # Otherwise download the file, split it into chunks, and build a new index
        source_docs = list(get_file_text(pdf_url))
        source_chunks = []
        splitter = CharacterTextSplitter(separator=" ",
                                         chunk_size=1024,
                                         chunk_overlap=0)
        for source in source_docs:
            for chunk in splitter.split_text(source.page_content):
                source_chunks.append(
                    Document(page_content=chunk, metadata=source.metadata))
        search_index = FAISS.from_documents(
            source_chunks, OpenAIEmbeddings(openai_api_key=openai_api_key))
        with open(file_name, "wb") as f:
            pickle.dump(search_index, f)
    input_documents = search_index.similarity_search(question, k=k)
    question_tokens_count = num_tokens_from_string(str(input_documents) + question, "gpt2")
    # Shrink k until the retrieved context plus the question fits the context window
    while question_tokens_count > 3500 and k > 1:
        k = k - 1
        input_documents = search_index.similarity_search(question, k=k)
        question_tokens_count = num_tokens_from_string(str(input_documents) + question, "gpt2")
    stuffchain = chain(
        {
            "input_documents": input_documents,
            "question": question,
        },
        return_only_outputs=True,
    )["output_text"]
    return {"stuffchain": stuffchain, "source": str(pdf_url)}


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=8080)  # Run the app
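
# Example client call (a minimal sketch; assumes the server is running locally on
# port 8080, TOKEN above has been set, and the PDF URL below is hypothetical):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:8080/v1/api/pdf_to_answer",
#       headers={"Authorization": "Token <TOKEN>"},
#       json={"question": "What is this document about?",
#             "pdf_url": "https://example.com/doc.pdf"},
#   )
#   print(resp.json())  # -> {"stuffchain": "...", "source": "..."}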