Created
August 22, 2024 10:49
-
-
Save janakiramm/5c4abaa9bbdabb4ed25317c0a75ebb4d to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "9bbd56b4-079b-4658-9690-8db19c602dd5", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from pymilvus import MilvusClient\n", | |
"from pymilvus import connections\n", | |
"from openai import OpenAI\n", | |
"from dotenv import load_dotenv\n", | |
"import os\n", | |
"import ast" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "68c521a0-e52d-48f6-b2b0-d9b78c010799", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"True" | |
] | |
}, | |
"execution_count": 2, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"load_dotenv()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "85b861a4-53a8-400b-ad7d-81279d4a660b", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Service endpoints are read from the environment; a name that is not set yields None.\n", | |
"LLM_URI, EMBED_URI, VECTORDB_URI = (\n", | |
"    os.getenv(name) for name in (\"LLM_URI\", \"EMBED_URI\", \"VECTORDB_URI\")\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "a187cff2-91c5-4b60-819d-d5abb806bd95", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# API credentials are read from the environment rather than written into the notebook.\n", | |
"NIM_API_KEY, ZILIZ_API_KEY = map(os.getenv, (\"NIM_API_KEY\", \"ZILIZ_API_KEY\"))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "9976b4b6-adec-44ee-a719-1e3b866ff509", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# OpenAI-style client used for chat completions; it targets LLM_URI and authenticates with NIM_API_KEY.\n", | |
"llm_client = OpenAI(base_url=LLM_URI, api_key=NIM_API_KEY)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "01805361-4500-43ad-b043-5064ab1311f9", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# OpenAI-style client used for embeddings; same API key as the chat client, but pointed at EMBED_URI.\n", | |
"embedding_client = OpenAI(base_url=EMBED_URI, api_key=NIM_API_KEY)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "9423cf23-d14d-456d-8c94-50c905bc52a2", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Vector database client; VECTORDB_URI selects the endpoint and ZILIZ_API_KEY is passed as its token.\n", | |
"vectordb_client = MilvusClient(token=ZILIZ_API_KEY, uri=VECTORDB_URI)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "35a799db-667d-47b8-af07-b22205966765", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Start every run from an empty collection: an existing \"india_facts\" is dropped first.\n", | |
"if vectordb_client.has_collection(collection_name=\"india_facts\"):\n", | |
"    vectordb_client.drop_collection(collection_name=\"india_facts\")\n", | |
"\n", | |
"# NOTE(review): 1024 is assumed to be the vector width of the embedding model used below -- confirm if the model changes.\n", | |
"vectordb_client.create_collection(collection_name=\"india_facts\", dimension=1024)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"id": "9133c393-805c-47aa-af09-81b4368fae5e", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"docs = [\n", | |
" \"India is the seventh-largest country by land area in the world.\",\n", | |
" \"The Indus Valley Civilization, one of the world's oldest, originated in India around 3300 BCE.\",\n", | |
" \"The game of chess, originally called 'Chaturanga,' was invented in India during the Gupta Empire.\",\n", | |
" \"India is home to the world's largest democracy, with over 900 million eligible voters.\",\n", | |
" \"The Indian mathematician Aryabhata was the first to explain the concept of zero in the 5th century.\",\n", | |
" \"India has the second-largest population in the world, with over 1.4 billion people.\",\n", | |
" \"The Kumbh Mela, held every 12 years, is the largest religious gathering in the world, attracting millions of devotees.\",\n", | |
" \"India is the birthplace of four major world religions: Hinduism, Buddhism, Jainism, and Sikhism.\",\n", | |
" \"The Indian Space Research Organisation (ISRO) successfully sent a spacecraft to Mars on its first attempt in 2014.\",\n", | |
" \"India's Varanasi is considered one of the world's oldest continuously inhabited cities, with a history dating back over 3,000 years.\"\n", | |
"]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"id": "7ccfb283-15e6-43f7-8f1e-8ed61d0c2f28", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def embed(docs, input_type=\"query\"):\n", | |
"    \"\"\"Return one embedding vector per input text.\n", | |
"\n", | |
"    Args:\n", | |
"        docs: list of strings to embed.\n", | |
"        input_type: value sent to the embedding service as ``input_type``.\n", | |
"            Use \"query\" (the default, matching the previous behaviour) for\n", | |
"            search queries and \"passage\" for texts that are stored and\n", | |
"            searched against.\n", | |
"\n", | |
"    Returns:\n", | |
"        A list with the ``embedding`` of every item in the response, in\n", | |
"        response order.\n", | |
"    \"\"\"\n", | |
"    response = embedding_client.embeddings.create(\n", | |
"        input=docs,\n", | |
"        model=\"nvidia/nv-embedqa-e5-v5\",\n", | |
"        encoding_format=\"float\",\n", | |
"        # NOTE(review): truncate=NONE presumably makes the service reject over-long input instead of cutting it -- confirm.\n", | |
"        extra_body={\"input_type\": input_type, \"truncate\": \"NONE\"}\n", | |
"    )\n", | |
"    return [embedding_data.embedding for embedding_data in response.data]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"id": "efd2f98a-7534-4eb8-a1bd-e932e8f756b2", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Stored facts are embedded as passages; queries keep embed()'s default \"query\" type.\n", | |
"vectors=embed(docs, input_type=\"passage\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"id": "ab81f2e7-7240-414d-a35a-f1b5775e8bbd", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# One row per vector: the list position doubles as the id, and every row gets the same subject tag.\n", | |
"data = [\n", | |
"    {\"id\": idx, \"vector\": vec, \"text\": docs[idx], \"subject\": \"history\"}\n", | |
"    for idx, vec in enumerate(vectors)\n", | |
"]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"id": "773869a6-d5e0-480a-9baf-b69191f0629f", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"{'insert_count': 10, 'ids': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], 'cost': 0}" | |
] | |
}, | |
"execution_count": 13, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"vectordb_client.insert(collection_name=\"india_facts\", data=data)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"id": "7c7f9a08-0a09-4323-afff-406d9ce1a7f2", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Sanity-check the collection with a sample query: ask for at most 2 hits and their stored fields.\n", | |
"query_vectors = embed([\"ISRO\"])\n", | |
"\n", | |
"res = vectordb_client.search(\n", | |
"    collection_name=\"india_facts\", data=query_vectors,\n", | |
"    output_fields=[\"text\", \"subject\"], limit=2,\n", | |
")\n", | |
"\n", | |
"# `res` is not displayed here; print it to inspect the raw hits." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"id": "b37a0650-eb8a-4995-b8a9-d64260f2441b", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def retrieve(query, limit=10):\n", | |
"    \"\"\"Return the text of the stored facts that match ``query``.\n", | |
"\n", | |
"    Args:\n", | |
"        query: question or keyword string to search for.\n", | |
"        limit: maximum number of hits to request. The default of 10 mirrors\n", | |
"            the value MilvusClient.search applies when no limit is passed,\n", | |
"            so existing callers see the same results; pass a smaller value\n", | |
"            to get a more focused context.\n", | |
"\n", | |
"    Returns:\n", | |
"        The ``text`` field of every usable hit, joined by single spaces\n", | |
"        (an empty string when nothing usable is returned).\n", | |
"    \"\"\"\n", | |
"    query_vectors = embed([query])\n", | |
"\n", | |
"    search_results = vectordb_client.search(\n", | |
"        collection_name=\"india_facts\",\n", | |
"        data=query_vectors,\n", | |
"        limit=limit,\n", | |
"        output_fields=[\"text\", \"subject\"]\n", | |
"    )\n", | |
"\n", | |
"    def has_text(hit):\n", | |
"        # A usable hit is a dict carrying entity -> text.\n", | |
"        return isinstance(hit, dict) and 'entity' in hit and 'text' in hit['entity']\n", | |
"\n", | |
"    all_texts = []\n", | |
"    for item in search_results:\n", | |
"        if isinstance(item, str):\n", | |
"            # Results that arrive as a string are parsed as a Python literal on a best-effort basis.\n", | |
"            try:\n", | |
"                item = ast.literal_eval(item)\n", | |
"            except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):\n", | |
"                pass  # not a literal: the raw string is kept and then skipped below\n", | |
"\n", | |
"        # Treat a single hit like a one-element list so both shapes share one code path.\n", | |
"        hits = item if isinstance(item, list) else [item]\n", | |
"        all_texts.extend(hit['entity']['text'] for hit in hits if has_text(hit))\n", | |
"\n", | |
"    return \" \".join(all_texts)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"id": "c4d75523-9538-48b1-93d7-764fe4d321e0", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def generate(context, question):\n", | |
"    \"\"\"Ask the chat model to answer ``question`` from ``context``.\n", | |
"\n", | |
"    Sends a fixed system instruction plus a user message that embeds the\n", | |
"    context and the question, and returns the message content of the first\n", | |
"    choice in the completion.\n", | |
"    \"\"\"\n", | |
"    # The prompt texts below are kept verbatim, including their whitespace.\n", | |
"    user_message = f'''\n", | |
" Based on the context: {context}\n", | |
" \n", | |
" Please answer the question: {question}\n", | |
" '''\n", | |
"    instructions = '''\n", | |
" You are a helpful assistant that answers questions based on the given context.\\n\n", | |
" Don't add anything to the response. \\n\n", | |
" If you cannot find the answer within the context, say I do not know. \n", | |
" '''\n", | |
"    messages = [\n", | |
"        {\"role\": \"system\", \"content\": instructions},\n", | |
"        {\"role\": \"user\", \"content\": user_message},\n", | |
"    ]\n", | |
"    completion = llm_client.chat.completions.create(\n", | |
"        model=\"meta/llama3-8b-instruct\",\n", | |
"        messages=messages,\n", | |
"        max_tokens=1024,\n", | |
"        temperature=0,\n", | |
"        top_p=1,\n", | |
"    )\n", | |
"    first_choice = completion.choices[0]\n", | |
"    return first_choice.message.content" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"id": "22cd1e12-5ac4-4620-8e1b-7dfe6a3e608c", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def chat(prompt):\n", | |
"    \"\"\"Answer ``prompt``: retrieve matching facts, then hand them to generate() as context.\"\"\"\n", | |
"    return generate(retrieve(prompt), prompt)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"id": "a5a047ba-087a-4e15-aee3-e26c06dce782", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Other questions to try (uncomment one to use it instead):\n", | |
"#prompt=\"What is ISRO?\"\n", | |
"#prompt=\"What is chess originally called?\"\n", | |
"#prompt=\"When did Indus Valley Civilization originate?\"\n", | |
"prompt = \"what are the four major world religions?\"" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 23, | |
"id": "6a6cf3ab-77de-45ee-89d5-bfba76cadafb", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"The four major world religions are: Hinduism, Buddhism, Jainism, and Sikhism.\n" | |
] | |
} | |
], | |
"source": [ | |
"res=chat(prompt)\n", | |
"print(res)" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.10.13" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment