Skip to content

Instantly share code, notes, and snippets.

@janakiramm
Created August 22, 2024 10:49
Show Gist options
  • Save janakiramm/5c4abaa9bbdabb4ed25317c0a75ebb4d to your computer and use it in GitHub Desktop.
Save janakiramm/5c4abaa9bbdabb4ed25317c0a75ebb4d to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "9bbd56b4-079b-4658-9690-8db19c602dd5",
"metadata": {},
"outputs": [],
"source": [
"from pymilvus import MilvusClient\n",
"from pymilvus import connections\n",
"from openai import OpenAI\n",
"from dotenv import load_dotenv\n",
"import os\n",
"import ast"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "68c521a0-e52d-48f6-b2b0-d9b78c010799",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"load_dotenv()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "85b861a4-53a8-400b-ad7d-81279d4a660b",
"metadata": {},
"outputs": [],
"source": [
"LLM_URI=os.getenv(\"LLM_URI\")\n",
"EMBED_URI=os.getenv(\"EMBED_URI\")\n",
"VECTORDB_URI=os.getenv(\"VECTORDB_URI\")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "a187cff2-91c5-4b60-819d-d5abb806bd95",
"metadata": {},
"outputs": [],
"source": [
"NIM_API_KEY=os.getenv(\"NIM_API_KEY\")\n",
"ZILIZ_API_KEY=os.getenv(\"ZILIZ_API_KEY\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "9976b4b6-adec-44ee-a719-1e3b866ff509",
"metadata": {},
"outputs": [],
"source": [
"llm_client = OpenAI(\n",
" api_key=NIM_API_KEY,\n",
" base_url=LLM_URI\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "01805361-4500-43ad-b043-5064ab1311f9",
"metadata": {},
"outputs": [],
"source": [
"embedding_client = OpenAI(\n",
" api_key=NIM_API_KEY,\n",
" base_url=EMBED_URI\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "9423cf23-d14d-456d-8c94-50c905bc52a2",
"metadata": {},
"outputs": [],
"source": [
"vectordb_client = MilvusClient(\n",
" uri=VECTORDB_URI,\n",
" token=ZILIZ_API_KEY\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "35a799db-667d-47b8-af07-b22205966765",
"metadata": {},
"outputs": [],
"source": [
"if vectordb_client.has_collection(collection_name=\"india_facts\"):\n",
" vectordb_client.drop_collection(collection_name=\"india_facts\")\n",
"\n",
"vectordb_client.create_collection(\n",
" collection_name=\"india_facts\",\n",
" dimension=1024, \n",
")"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "9133c393-805c-47aa-af09-81b4368fae5e",
"metadata": {},
"outputs": [],
"source": [
"docs = [\n",
" \"India is the seventh-largest country by land area in the world.\",\n",
" \"The Indus Valley Civilization, one of the world's oldest, originated in India around 3300 BCE.\",\n",
" \"The game of chess, originally called 'Chaturanga,' was invented in India during the Gupta Empire.\",\n",
" \"India is home to the world's largest democracy, with over 900 million eligible voters.\",\n",
" \"The Indian mathematician Aryabhata was the first to explain the concept of zero in the 5th century.\",\n",
" \"India has the second-largest population in the world, with over 1.4 billion people.\",\n",
" \"The Kumbh Mela, held every 12 years, is the largest religious gathering in the world, attracting millions of devotees.\",\n",
" \"India is the birthplace of four major world religions: Hinduism, Buddhism, Jainism, and Sikhism.\",\n",
" \"The Indian Space Research Organisation (ISRO) successfully sent a spacecraft to Mars on its first attempt in 2014.\",\n",
" \"India's Varanasi is considered one of the world's oldest continuously inhabited cities, with a history dating back over 3,000 years.\"\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "7ccfb283-15e6-43f7-8f1e-8ed61d0c2f28",
"metadata": {},
"outputs": [],
"source": [
"def embed(docs):\n",
" response = embedding_client.embeddings.create(\n",
" input=docs,\n",
" model=\"nvidia/nv-embedqa-e5-v5\",\n",
" encoding_format=\"float\",\n",
" extra_body={\"input_type\": \"query\", \"truncate\": \"NONE\"}\n",
" )\n",
" vectors = [embedding_data.embedding for embedding_data in response.data]\n",
" return vectors"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "efd2f98a-7534-4eb8-a1bd-e932e8f756b2",
"metadata": {},
"outputs": [],
"source": [
"vectors=embed(docs)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "ab81f2e7-7240-414d-a35a-f1b5775e8bbd",
"metadata": {},
"outputs": [],
"source": [
"data = [\n",
" {\"id\": i, \"vector\": vectors[i], \"text\": docs[i], \"subject\": \"history\"}\n",
" for i in range(len(vectors))\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "773869a6-d5e0-480a-9baf-b69191f0629f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'insert_count': 10, 'ids': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], 'cost': 0}"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"vectordb_client.insert(collection_name=\"india_facts\", data=data)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "7c7f9a08-0a09-4323-afff-406d9ce1a7f2",
"metadata": {},
"outputs": [],
"source": [
"query_vectors = embed([\"ISRO\"])\n",
"\n",
"res = vectordb_client.search(\n",
" collection_name=\"india_facts\", \n",
" data=query_vectors, \n",
" limit=2, \n",
" output_fields=[\"text\", \"subject\"],\n",
")\n",
"\n",
"#print(res)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "b37a0650-eb8a-4995-b8a9-d64260f2441b",
"metadata": {},
"outputs": [],
"source": [
"def retrieve(query):\n",
" query_vectors = embed([query])\n",
"\n",
" search_results = vectordb_client.search(\n",
" collection_name=\"india_facts\",\n",
" data=query_vectors,\n",
" output_fields=[\"text\", \"subject\"]\n",
" )\n",
"\n",
" all_texts = []\n",
" for item in search_results:\n",
" try:\n",
" evaluated_item = ast.literal_eval(item) if isinstance(item, str) else item\n",
" except:\n",
" evaluated_item = item\n",
" \n",
" if isinstance(evaluated_item, list):\n",
" all_texts.extend(subitem['entity']['text'] for subitem in evaluated_item if isinstance(subitem, dict) and 'entity' in subitem and 'text' in subitem['entity'])\n",
" elif isinstance(evaluated_item, dict) and 'entity' in evaluated_item and 'text' in evaluated_item['entity']:\n",
" all_texts.append(evaluated_item['entity']['text'])\n",
" \n",
" return \" \".join(all_texts)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "c4d75523-9538-48b1-93d7-764fe4d321e0",
"metadata": {},
"outputs": [],
"source": [
"def generate(context, question):\n",
" prompt = f'''\n",
" Based on the context: {context}\n",
" \n",
" Please answer the question: {question}\n",
" ''' \n",
" system_prompt='''\n",
" You are a helpful assistant that answers questions based on the given context.\\n\n",
" Don't add anything to the response. \\n\n",
" If you cannot find the answer within the context, say I do not know. \n",
" '''\n",
" completion = llm_client.chat.completions.create(\n",
" model=\"meta/llama3-8b-instruct\",\n",
" messages=[\n",
" {\"role\": \"system\", \"content\": system_prompt},\n",
" {\"role\": \"user\", \"content\": prompt}\n",
" ],\n",
" temperature=0,\n",
" top_p=1,\n",
" max_tokens=1024\n",
" )\n",
" return completion.choices[0].message.content"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "22cd1e12-5ac4-4620-8e1b-7dfe6a3e608c",
"metadata": {},
"outputs": [],
"source": [
"def chat(prompt):\n",
" context=retrieve(prompt)\n",
" response=generate(context,prompt)\n",
" return response"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "a5a047ba-087a-4e15-aee3-e26c06dce782",
"metadata": {},
"outputs": [],
"source": [
"#prompt=\"What is ISRO?\"\n",
"#prompt=\"What is chess originally called?\"\n",
"#prompt=\"When did Indus Valley Civilization orginate?\"\n",
"prompt=\"what are the four major world religions?\""
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "6a6cf3ab-77de-45ee-89d5-bfba76cadafb",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The four major world religions are: Hinduism, Buddhism, Jainism, and Sikhism.\n"
]
}
],
"source": [
"res=chat(prompt)\n",
"print(res)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment