Example of simple Text to SQL based on source CSV file using SQLite and OpenAI (through langchain)
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Text to SQL from CSV\n",
"\n",
"This is a sort notebook of how to role your own, simple text-to-sql using CSV file.\n",
"\n",
"The data file used for this example is Kaggle's [Wine Review](https://www.kaggle.com/datasets/zynicide/wine-reviews?resource=download) dataset.\n",
"\n",
"This notebook uses OpenAI api through the [openai python library](https://github.com/openai/openai-python) but the general principle should hold for other LLMs.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup\n",
"\n",
"Install dependencies and import them and set your OpenAI API key\n"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"%%capture\n",
"%pip install pandas openai tabulate langchain langchain_openai"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from sqlite3 import Connection\n",
"import pandas as pd"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load your CSV data\n"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>country</th>\n",
" <th>description</th>\n",
" <th>designation</th>\n",
" <th>points</th>\n",
" <th>price</th>\n",
" <th>province</th>\n",
" <th>region_1</th>\n",
" <th>region_2</th>\n",
" <th>variety</th>\n",
" <th>winery</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>US</td>\n",
" <td>This tremendous 100% varietal wine hails from ...</td>\n",
" <td>Martha's Vineyard</td>\n",
" <td>96</td>\n",
" <td>235.0</td>\n",
" <td>California</td>\n",
" <td>Napa Valley</td>\n",
" <td>Napa</td>\n",
" <td>Cabernet Sauvignon</td>\n",
" <td>Heitz</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>Spain</td>\n",
" <td>Ripe aromas of fig, blackberry and cassis are ...</td>\n",
" <td>Carodorum Selección Especial Reserva</td>\n",
" <td>96</td>\n",
" <td>110.0</td>\n",
" <td>Northern Spain</td>\n",
" <td>Toro</td>\n",
" <td>NaN</td>\n",
" <td>Tinta de Toro</td>\n",
" <td>Bodega Carmen Rodríguez</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>US</td>\n",
" <td>Mac Watson honors the memory of a wine once ma...</td>\n",
" <td>Special Selected Late Harvest</td>\n",
" <td>96</td>\n",
" <td>90.0</td>\n",
" <td>California</td>\n",
" <td>Knights Valley</td>\n",
" <td>Sonoma</td>\n",
" <td>Sauvignon Blanc</td>\n",
" <td>Macauley</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" country description \\\n",
"0 US This tremendous 100% varietal wine hails from ... \n",
"1 Spain Ripe aromas of fig, blackberry and cassis are ... \n",
"2 US Mac Watson honors the memory of a wine once ma... \n",
"\n",
" designation points price province \\\n",
"0 Martha's Vineyard 96 235.0 California \n",
"1 Carodorum Selección Especial Reserva 96 110.0 Northern Spain \n",
"2 Special Selected Late Harvest 96 90.0 California \n",
"\n",
" region_1 region_2 variety winery \n",
"0 Napa Valley Napa Cabernet Sauvignon Heitz \n",
"1 Toro NaN Tinta de Toro Bodega Carmen Rodríguez \n",
"2 Knights Valley Sonoma Sauvignon Blanc Macauley "
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"csv_path = \"./winemag-data_first150k.csv\" # change the path as needed\n",
"df = pd.read_csv(csv_path, index_col=0) # this dataset has an index column, you might not need it\n",
"df.head(3)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Create an SQL database from the CSV file\n"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"150930"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"con = Connection(\"my-csv-data.sqlite\")\n",
"table_name = \"wine_data\"\n",
"index_label = \"idx\"\n",
"df.to_sql(table_name, con, index_label=index_label, if_exists=\"replace\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Helper methods for execution of SQL (to get results)\n"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"import tabulate\n",
"\n",
"def strip_markdown(txt: str) -> str:\n",
" return txt.replace(\"```sql\", \"\").replace(\"```\", \"\")\n",
"\n",
"def run_sql(sql):\n",
" cur = con.execute(strip_markdown(sql))\n",
" for line in tabulate.tabulate(cur.fetchall(), headers=[col[0] for col in cur.description]).split(\"\\n\"):\n",
" print(line)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"country count(*)\n",
"--------- ----------\n",
"US 62397\n",
"Italy 23478\n",
"France 21098\n",
"Spain 8268\n",
"Chile 5816\n",
"Argentina 5631\n",
"Portugal 5322\n",
"Australia 4957\n",
" max(points)\n",
"-------------\n",
" 100\n"
]
}
],
"source": [
"# example query\n",
"run_sql(f\"SELECT country, count(*) from {table_name} group by country order by 2 desc limit 8\")\n",
"\n",
"run_sql(f\"SELECT max(points) from {table_name}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Query using LLM\n",
"\n",
"This section describe how to make a simple text-to-sql on top of the database created.\n",
"\n",
"This includes:\n",
"\n",
"- Define some helper methods\n",
"- Define the relevant prompts and chains to query the table\n",
"- Examples.\n",
"\n",
"To work with OpenAI API, you will need to have an API key, then set it in your environment\n",
"\n",
"```bash\n",
"export OPENAI_API_KEY=sk-...\n",
"```\n",
"\n",
"add a cell to do this:\n",
"\n",
"```python\n",
"import os\n",
"os.environ[\"OPENAI_API_KEY\"] = \"sk-...\"\n",
"```\n",
"\n",
"or pass it directly to the creation of `ChatOpenAI` inside `get_sql_chain`:\n",
"\n",
"```python\n",
"llm = ChatOpenAI(model_name=model_name, temperature=temprature, api_key=\"sk-...)\n",
"```\n",
"\n",
"**This is very simple / naive example. There are more robust solution that one should consider:**\n",
"\n",
"- Langchain [CSV agent](https://github.com/langchain-ai/langchain/tree/master/templates/csv-agent/) or [DataFrame Agent](https://python.langchain.com/v0.1/docs/integrations/toolkits/pandas/)\n",
"- [Pandas ai](https://pandas-ai.com/)\n",
"- Tools like [Vanna.ai](https://vanna.ai/)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Helper methods to get schema information and example data\n"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"def get_ddl(con: Connection, table_name: str) -> str:\n",
" \"\"\"Get the DDL of specified table\"\"\"\n",
" sql = f\"SELECT sql FROM sqlite_schema WHERE name='{table_name}';\"\n",
" data = con.execute(sql).fetchall()\n",
" return data[0][0]\n",
"\n",
"def get_sample_data(con: Connection, table_name: str, num_items: int = 3) -> str:\n",
" \"\"\"Get sample data from the specified table\"\"\"\n",
" sql = f\"\"\"SELECT * FROM {table_name} limit {num_items}\"\"\"\n",
" cur = con.execute(sql)\n",
"\n",
" headers = [d[0] for d in cur.description]\n",
" values = cur.fetchall()\n",
"\n",
" table_data = tabulate.tabulate(values, headers=headers)\n",
"\n",
" return table_data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Define template, chains\n"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [],
"source": [
"ddl = get_ddl(con, table_name)\n",
"table_data = get_sample_data(con, table_name, 4)\n",
"\n",
"_system_prompt_template = f\"\"\"You are an SQL expert.\n",
"\n",
"You are requested to help users investigate the {table_name} table concerning with \n",
"wine information and ranking.\n",
"\n",
"The database used is SQLite so make sure to use the relevant syntax.\n",
"\n",
"The table structure is as follows:\n",
"\n",
"{ddl}\n",
"\n",
"Following as a few rows of data from the table:\n",
"\n",
"{table_data}\n",
"\n",
"Answer the user's queries with the relevant SQL query. Only provide an SQL query\n",
"without any additional text of information. \n",
"\n",
"For all filters, use case-insensitive non-exact 'like' filter. For example, if \n",
"asked about wines from italy, the WHERE clause should be \"WHERE country like '%italy%'\"\n",
"\n",
"For questions that asks a number of items per time unit or country,\n",
"consider using window analytic functions (e.g. ROW_NUMBER, RANK, LEAD, LAG, etc).\n",
"\n",
"For example, if asked for the top scoring wine in a few countries, use the ROW_NUMBER \n",
"window function.\n",
"\n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"from langchain_core.messages import BaseMessage\n",
"from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder\n",
"from langchain_openai import ChatOpenAI"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [],
"source": [
"def get_sql_chain():\n",
" \n",
" sql_req_prompt = ChatPromptTemplate.from_messages(\n",
" [\n",
" (\"system\", _system_prompt_template.format(ddl=ddl, table_data=table_data)),\n",
" MessagesPlaceholder(\"messages\"),\n",
" (\"user\", \"Current question: {question}\"),\n",
" ]\n",
" )\n",
" model_name = \"gpt-4o\"\n",
" temprature = 0\n",
" llm = ChatOpenAI(model_name=model_name, temperature=temprature)\n",
"\n",
" sql_chain = sql_req_prompt | llm\n",
"\n",
" return sql_chain\n",
"\n",
"chain = get_sql_chain()\n"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [],
"source": [
"def run_chain_and_get_data(query: str, history: list[BaseMessage]):\n",
" res = chain.invoke(\n",
" input={\n",
" \"question\": query,\n",
" \"messages\": history,\n",
" }\n",
" )\n",
"\n",
" sql = strip_markdown(res.content)\n",
"\n",
" print(sql)\n",
"\n",
" run_sql(sql)\n",
"\n",
" return sql"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"SELECT *\n",
"FROM (\n",
" SELECT *,\n",
" ROW_NUMBER() OVER (PARTITION BY country ORDER BY points DESC) as rn\n",
" FROM wine_data\n",
" WHERE country LIKE '%italy%' OR country LIKE '%spain%'\n",
")\n",
"WHERE rn = 1;\n",
"\n",
" idx country description designation points price province region_1 region_2 variety winery rn\n",
"----- --------- ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- ------------------ -------- ------- -------------- ---------------- ---------- ---------- --------------------- ----\n",
"24151 Italy A perfect wine from a classic vintage, the 2007 Masseto (100% Merlot from a 17-acre vineyard of the same name) opens with an unabashedly opulent bouquet of delicious blackberry, cherry, chocolate, vanilla, exotic spice and cinnamon. Masseto excels both in terms of quality of fruit and winemaking and delivers plush, velvety tannins and an extra long, supple finish. It will make a special and valuable collection to your cellar. Masseto 100 460 Tuscany Toscana Merlot Tenuta dell'Ornellaia 1\n",
"10538 Spain Luscious prune and blackberry aromas come with complex notes of graphite, toast and tobacco. The palate on this 80-case, high-end Tinto Fino is about as full as they come while still maintaining impeccable balance. Roasted berry, cassis and salty leather flavors are delicious but also challenging, while baked berry flavors, raw power and ripe tannins define the finish. Drink this outstanding Spanish wine from 2014 through 2024. Clon de la Familia 98 450 Northern Spain Ribera del Duero Tinto Fino Emilio Moro 1\n"
]
}
],
"source": [
"query = \"Show me the top scoring wine for italy and spain\"\n",
"history = [] # you can ask follow-up question etc by populating history\n",
"\n",
"sql = run_chain_and_get_data(query, history)\n"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"SELECT country, points\n",
"FROM (\n",
" SELECT country, points,\n",
" ROW_NUMBER() OVER (PARTITION BY country ORDER BY points DESC) as rn\n",
" FROM wine_data\n",
" WHERE country LIKE '%us%' OR country LIKE '%france%' OR country LIKE '%chile%'\n",
")\n",
"WHERE rn = 1;\n",
"\n",
"country points\n",
"--------- --------\n",
"Australia 100\n",
"Austria 98\n",
"Chile 95\n",
"Cyprus 89\n",
"France 100\n",
"US 100\n",
"US-France 88\n"
]
},
{
"data": {
"text/plain": [
"\"\\nSELECT country, points\\nFROM (\\n SELECT country, points,\\n ROW_NUMBER() OVER (PARTITION BY country ORDER BY points DESC) as rn\\n FROM wine_data\\n WHERE country LIKE '%us%' OR country LIKE '%france%' OR country LIKE '%chile%'\\n)\\nWHERE rn = 1;\\n\""
]
},
"execution_count": 45,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain_core.messages import HumanMessage, AIMessage\n",
"history = [\n",
" HumanMessage(query),\n",
" AIMessage(sql)\n",
"]\n",
"query = \"Do the same but show me only points and country for US, france and Chile\"\n",
"\n",
"run_chain_and_get_data(query, history)\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
@tomercagan
Author

tomercagan commented Sep 3, 2024

Note that there is a mistake in the second result: besides "US" it also fetches "Australia", "Austria", "Cyprus" and "US-France", because the loose WHERE criterion (LIKE '%us%') matches any country name that contains "us".
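
A minimal sketch of the problem and one possible way around it, using a throwaway in-memory table (the rows below are made up for illustration and are not the Kaggle data). Comparing `lower(country)` against the lowercased search term is just one option, shown here as an assumption; it is not something the notebook's prompt currently asks the model to do.

```python
import sqlite3

# Throwaway in-memory table with a handful of illustrative rows
con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE wine_data (country TEXT, points INTEGER)")
con.executemany(
    "INSERT INTO wine_data VALUES (?, ?)",
    [("US", 100), ("Austria", 98), ("Australia", 100),
     ("Cyprus", 89), ("France", 100), ("Chile", 95)],
)

# Loose filter, as requested by the system prompt: matches every name containing "us"
loose = [row[0] for row in con.execute(
    "SELECT DISTINCT country FROM wine_data "
    "WHERE country LIKE '%us%' ORDER BY country"
)]

# Case-insensitive exact match: only the country that was actually asked for
exact = [row[0] for row in con.execute(
    "SELECT DISTINCT country FROM wine_data "
    "WHERE lower(country) = lower(?) ORDER BY country",
    ("us",),
)]

print(loose)  # ['Australia', 'Austria', 'Cyprus', 'US']
print(exact)  # ['US']
```

The trade-off is that an exact match no longer tolerates partial or misspelled country names, so which filter to ask the model for depends on how the questions are phrased.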
