query-rewriting-gpt-mistral-cohere.ipynb
{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/gist/virattt/d897059f9f50f9b0b0b1295246c2455f/query-rewriting-gpt-mistral-cohere.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Create Query and Prompt"
      ],
      "metadata": {
        "id": "m8HqBNyYrDHb"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "query = \"What's going on with Airbnb's numbers?\"\n",
        "\n",
        "prompt = \"\"\"\n",
        "Rewrite the following user query into a clear, specific, and\n",
        "formal request suitable for retrieving relevant information from a vector database.\n",
        "Keep in mind that your rewritten query will be sent to a vector database, which\n",
        "does similarity search for retrieving documents.\n",
        "\"\"\""
      ],
      "metadata": {
        "id": "3qZTrAtXLPl1"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "!pip install openai"
      ],
      "metadata": {
        "id": "2bY0NapN_z98"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import getpass\n",
        "import os\n",
        "\n",
        "# Set your OpenAI API key\n",
        "os.environ[\"OPENAI_API_KEY\"] = getpass.getpass()"
      ],
      "metadata": {
        "id": "tavToGb_MJrc"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Use GPT-4 to rewrite the query"
      ],
      "metadata": {
        "id": "bPzoWQhVAmLt"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from openai import OpenAI\n",
        "import time\n",
        "import json\n",
        "\n",
        "client = OpenAI(api_key=os.environ[\"OPENAI_API_KEY\"])\n",
        "\n",
        "total_time = 0\n",
        "num_iterations = 10\n",
        "response = None\n",
        "\n",
        "for _ in range(num_iterations):\n",
        "    # Get start time\n",
        "    start_time = time.time()\n",
        "    # Call the model\n",
        "    response = client.chat.completions.create(\n",
        "        model='gpt-4-0125-preview',\n",
        "        temperature=0,\n",
        "        messages=[\n",
        "            {\"role\": \"system\", \"content\": prompt},\n",
        "            {\"role\": \"user\", \"content\": query},\n",
        "        ]\n",
        "    )\n",
        "    # Get end time\n",
        "    end_time = time.time()\n",
        "    # Update total execution time (excluding sleep time)\n",
        "    total_time += (end_time - start_time)\n",
        "    # Wait for 1 second before the next iteration\n",
        "    time.sleep(1)\n",
        "\n",
        "# Calculate the average execution time\n",
        "avg_time = total_time / num_iterations\n",
        "\n",
        "print(f\"Took {avg_time:.2f} seconds on average to rewrite the query.\")"
      ],
      "metadata": {
        "id": "Z83h16UuMlMt"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Print the rewritten query\n",
        "rewritten_query = response.choices[0].message.content\n",
        "print(rewritten_query)"
      ],
      "metadata": {
        "id": "8VZMWffzm0-i"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Use GPT-3.5 to rewrite the query"
      ],
      "metadata": {
        "id": "AdrLmbzAAsgX"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "total_time = 0\n",
        "num_iterations = 10\n",
        "response = None\n",
        "\n",
        "for _ in range(num_iterations):\n",
        "    # Get start time\n",
        "    start_time = time.time()\n",
        "    # Call the model\n",
        "    response = client.chat.completions.create(\n",
        "        model='gpt-3.5-turbo-0125',\n",
        "        temperature=0,\n",
        "        messages=[\n",
        "            {\"role\": \"system\", \"content\": prompt},\n",
        "            {\"role\": \"user\", \"content\": query},\n",
        "        ]\n",
        "    )\n",
        "    # Get end time\n",
        "    end_time = time.time()\n",
        "    # Update total execution time (excluding sleep time)\n",
        "    total_time += (end_time - start_time)\n",
        "    # Wait for 1 second before the next iteration\n",
        "    time.sleep(1)\n",
        "\n",
        "# Calculate the average execution time\n",
        "avg_time = total_time / num_iterations\n",
        "\n",
        "print(f\"Took {avg_time:.2f} seconds on average to rewrite the query.\")"
      ],
      "metadata": {
        "id": "AdpynLvNAvww"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Print the rewritten query\n",
        "rewritten_query = response.choices[0].message.content\n",
        "print(rewritten_query)"
      ],
      "metadata": {
        "id": "yFKleKcWD84S"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Use Mistral to rewrite the query"
      ],
      "metadata": {
        "id": "FMHkITrr-ru2"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!pip install mistralai"
      ],
      "metadata": {
        "id": "cYy332j3cMbt"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Set your Mistral API key\n",
        "os.environ[\"MISTRAL_API_KEY\"] = getpass.getpass()"
      ],
      "metadata": {
        "id": "rcPNaNTR4leC"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import json\n",
        "from mistralai.client import MistralClient\n",
        "from mistralai.models.chat_completion import ChatMessage\n",
        "\n",
        "client = MistralClient(api_key=os.environ[\"MISTRAL_API_KEY\"])\n",
        "\n",
        "total_time = 0\n",
        "num_iterations = 10\n",
        "response = None\n",
        "\n",
        "for _ in range(num_iterations):\n",
        "    # Get start time\n",
        "    start_time = time.time()\n",
        "    # Call the model\n",
        "    response = client.chat(\n",
        "        model=\"mistral-medium\",\n",
        "        messages=[\n",
        "            ChatMessage(role=\"system\", content=prompt),\n",
        "            ChatMessage(role=\"user\", content=query)\n",
        "        ]\n",
        "    )\n",
        "    # Get end time\n",
        "    end_time = time.time()\n",
        "    # Update total execution time (excluding sleep time)\n",
        "    total_time += (end_time - start_time)\n",
        "    # Wait for 1 second before the next iteration\n",
        "    time.sleep(1)\n",
        "\n",
        "# Calculate the average execution time\n",
        "avg_time = total_time / num_iterations\n",
        "\n",
        "print(f\"Took {avg_time:.2f} seconds on average to rewrite the query.\")"
      ],
      "metadata": {
        "id": "Z3ZQalMlUUb4"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Print the rewritten query\n",
        "rewritten_query = response.choices[0].message.content\n",
        "print(rewritten_query)"
      ],
      "metadata": {
        "id": "imAL6_eqUtds"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Use Cohere to rewrite the query\n",
        "\n",
        "Cohere's chat endpoint is called with `search_queries_only=True`, so it generates search queries directly from the user message instead of using the custom rewriting prompt above."
      ],
      "metadata": {
        "id": "DrgoAKInw-z8"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!pip install cohere"
      ],
      "metadata": {
        "id": "uzEp-k7Xxfla"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Set your Cohere API key\n",
        "os.environ[\"COHERE_API_KEY\"] = getpass.getpass()"
      ],
      "metadata": {
        "id": "V7X3rjrb4uAX"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import cohere\n",
        "\n",
        "# Get your Cohere API key at: www.cohere.com\n",
        "co = cohere.Client(os.environ[\"COHERE_API_KEY\"])\n",
        "\n",
        "total_time = 0\n",
        "num_iterations = 10\n",
        "response = None\n",
        "\n",
        "for _ in range(num_iterations):\n",
        "    # Get start time\n",
        "    start_time = time.time()\n",
        "    # Call the model\n",
        "    response = co.chat(\n",
        "        message=query,\n",
        "        search_queries_only=True,\n",
        "    )\n",
        "    # Get end time\n",
        "    end_time = time.time()\n",
        "    # Update total execution time (excluding sleep time)\n",
        "    total_time += (end_time - start_time)\n",
        "    # Wait for 1 second before the next iteration\n",
        "    time.sleep(1)\n",
        "\n",
        "# Calculate the average execution time\n",
        "avg_time = total_time / num_iterations\n",
        "\n",
        "print(f\"Took {avg_time:.2f} seconds on average to rewrite the query.\")"
      ],
      "metadata": {
        "id": "Zc-w9vn_xAnY"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Print the search queries that Cohere generated\n",
        "rewritten_queries = [sq.get('text', None) for sq in response.search_queries]\n",
        "print(rewritten_queries)"
      ],
      "metadata": {
        "id": "zdpHKYZXBIx0"
      },
      "execution_count": null,
      "outputs": []
    }
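    ,
    {
      "cell_type": "markdown",
      "source": [
        "# Optional: embed the rewritten query (sketch)\n",
        "\n",
        "A minimal sketch of the step the prompt above describes: embedding a rewritten query before sending it to a vector database for similarity search. The embedding model name (`text-embedding-3-small`) and the choice to embed the first Cohere search query are assumptions for illustration, not part of the original timing comparison."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "# Optional sketch (not part of the timing comparison above): the prompt says the\n",
        "# rewritten query is sent to a vector database for similarity search, so the next\n",
        "# step would be to embed it. The embedding model name below is an assumption.\n",
        "from openai import OpenAI\n",
        "\n",
        "openai_client = OpenAI(api_key=os.environ[\"OPENAI_API_KEY\"])\n",
        "\n",
        "# Use the first Cohere search query if available, otherwise fall back to the original query\n",
        "text_to_embed = rewritten_queries[0] if rewritten_queries else query\n",
        "\n",
        "embedding_response = openai_client.embeddings.create(\n",
        "    model=\"text-embedding-3-small\",  # assumed embedding model\n",
        "    input=text_to_embed,\n",
        ")\n",
        "query_embedding = embedding_response.data[0].embedding\n",
        "print(f\"Embedded a {len(query_embedding)}-dimensional vector for: {text_to_embed}\")"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    }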
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.12"
    },
    "orig_nbformat": 4,
    "colab": {
      "provenance": [],
      "gpuType": "T4",
      "include_colab_link": true
    },
    "accelerator": "GPU"
  },
  "nbformat": 4,
  "nbformat_minor": 0
}