Jupyter notebook for fine tuning a T5 small model to generate SQL from natural language
{
"cells": [
{
"cell_type": "markdown",
"id": "21833f8d",
"metadata": {},
"source": [
"# Fine-Tuning a SQL Model\n",
"\n",
"### Inspired by https://huggingface.co/cssupport/t5-small-awesome-text-to-sql\n",
"\n",
"### Datasets:\n",
"- https://huggingface.co/datasets/b-mc2/sql-create-context\n",
"- https://huggingface.co/datasets/Clinton/Text-to-sql-v1\n",
"- https://huggingface.co/datasets/knowrohit07/know_sql"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "1f78e14f",
"metadata": {},
"outputs": [],
"source": [
"from datasets import DatasetDict, load_dataset, interleave_datasets, load_from_disk\n",
"from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig, TrainingArguments, Trainer\n",
"import torch\n",
"import time\n",
"import evaluate\n",
"import pandas as pd\n",
"\n",
"import warnings\n",
"warnings.filterwarnings(\"ignore\")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "cd00c140",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.cuda.is_available()"
]
},
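{
"cell_type": "markdown",
"id": "0a1b2c3d",
"metadata": {},
"source": [
"The cells below move everything to 'cuda' unconditionally. A portable sketch (added for illustration, not part of the original run) picks the device at runtime instead:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0a1b2c3e",
"metadata": {},
"outputs": [],
"source": [
"# Fall back to CPU when no GPU is available (training will be far slower there).\n",
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
"print(f\"Using device: {device}\")"
]
},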
{
"cell_type": "code",
"execution_count": 3,
"id": "1971e6c5",
"metadata": {},
"outputs": [],
"source": [
"model_name = 't5-small'\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
"\n",
"original_model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)\n",
"original_model = original_model.to('cuda')"
]
},
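{
"cell_type": "markdown",
"id": "1b2c3d4e",
"metadata": {},
"source": [
"Optional sanity check (added for illustration, not part of the original run): count the model's parameters. t5-small has roughly 60M, which is why full fine-tuning is feasible on a single consumer GPU."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1b2c3d4f",
"metadata": {},
"outputs": [],
"source": [
"# Total and trainable parameter counts for the loaded base model.\n",
"total_params = sum(p.numel() for p in original_model.parameters())\n",
"trainable_params = sum(p.numel() for p in original_model.parameters() if p.requires_grad)\n",
"print(f\"Total parameters: {total_params:,}\")\n",
"print(f\"Trainable parameters: {trainable_params:,}\")"
]
},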
{
"cell_type": "markdown",
"id": "f5e5419b",
"metadata": {},
"source": [
"# Load Datasets"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "ee806dfd",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loaded Merged Dataset\n"
]
},
{
"data": {
"text/plain": [
"DatasetDict({\n",
"    train: Dataset({\n",
"        features: ['question', 'context', 'answer'],\n",
"        num_rows: 118695\n",
"    })\n",
"    test: Dataset({\n",
"        features: ['question', 'context', 'answer'],\n",
"        num_rows: 14835\n",
"    })\n",
"    validation: Dataset({\n",
"        features: ['question', 'context', 'answer'],\n",
"        num_rows: 14838\n",
"    })\n",
"})"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Load the merged dataset from disk if it exists; otherwise build it from the three sources.\n",
"try:\n",
"    dataset = load_from_disk(\"merged_dataset\")\n",
"    print(\"Loaded Merged Dataset\")\n",
"except FileNotFoundError:\n",
"    dataset_scc_train = load_dataset(\"b-mc2/sql-create-context\", split='train[:80%]')\n",
"    dataset_scc_test = load_dataset(\"b-mc2/sql-create-context\", split='train[-20%:-10%]')\n",
"    dataset_scc_val = load_dataset(\"b-mc2/sql-create-context\", split='train[-10%:]')\n",
"\n",
"    dataset_tts_train = load_dataset(\"Clinton/Text-to-sql-v1\", split='train[:80%]')\n",
"    dataset_tts_train = dataset_tts_train.remove_columns(['source', 'text'])\n",
"    dataset_tts_train = dataset_tts_train.rename_columns({'instruction': 'question', 'input': 'context', 'response': 'answer'})\n",
"    dataset_tts_test = load_dataset(\"Clinton/Text-to-sql-v1\", split='train[-20%:-10%]')\n",
"    dataset_tts_test = dataset_tts_test.remove_columns(['source', 'text'])\n",
"    dataset_tts_test = dataset_tts_test.rename_columns({'instruction': 'question', 'input': 'context', 'response': 'answer'})\n",
"    dataset_tts_val = load_dataset(\"Clinton/Text-to-sql-v1\", split='train[-10%:]')\n",
"    dataset_tts_val = dataset_tts_val.remove_columns(['source', 'text'])\n",
"    dataset_tts_val = dataset_tts_val.rename_columns({'instruction': 'question', 'input': 'context', 'response': 'answer'})\n",
"\n",
"    dataset_ks_train = load_dataset(\"knowrohit07/know_sql\", split='validation[:80%]')\n",
"    dataset_ks_test = load_dataset(\"knowrohit07/know_sql\", split='validation[-20%:-10%]')\n",
"    dataset_ks_val = load_dataset(\"knowrohit07/know_sql\", split='validation[-10%:]')\n",
"\n",
"    dataset = DatasetDict({'train': interleave_datasets([dataset_scc_train, dataset_tts_train, dataset_ks_train]),\n",
"                           'test': interleave_datasets([dataset_scc_test, dataset_tts_test, dataset_ks_test]),\n",
"                           'validation': interleave_datasets([dataset_scc_val, dataset_tts_val, dataset_ks_val])})\n",
"\n",
"    dataset.save_to_disk(\"merged_dataset\")\n",
"    print(\"Merged and Saved Dataset\")\n",
"\n",
"dataset"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "89b95075",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'question': 'On what Date did the Away team essendon play?',\n",
" 'context': 'CREATE TABLE table_name_11 (date VARCHAR, away_team VARCHAR)',\n",
" 'answer': 'SELECT date FROM table_name_11 WHERE away_team = \"essendon\"'}"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dataset['test'][0]"
]
},
{
"cell_type": "markdown",
"id": "e8a79425",
"metadata": {},
"source": [
"# Preprocess the Datasets\n",
"\n",
"Convert each example into an explicit instruction prompt for the model.\n",
"\n",
"Then tokenize the prompt-response pairs and pull out their input_ids."
]
},
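{
"cell_type": "markdown",
"id": "2c3d4e5f",
"metadata": {},
"source": [
"To make the instruction format concrete, here is a small sketch (added for illustration; it mirrors the template used in tokenize_function below) that prints the prompt built from one training example."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2c3d4e60",
"metadata": {},
"outputs": [],
"source": [
"# Build the same 'Tables / Question / Answer' prompt used for training, for one example.\n",
"example = dataset['train'][0]\n",
"preview_prompt = f\"Tables:\\n{example['context']}\\n\\nQuestion:\\n{example['question']}\\n\\nAnswer:\\n\"\n",
"print(preview_prompt)"
]
},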
{
"cell_type": "code",
"execution_count": 7,
"id": "ad26693b",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Map:   0%|          | 0/118695 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Map:   0%|          | 0/14835 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Map:   0%|          | 0/14838 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5dd41fae15dd43d5bd3bc5a44bcb2603",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Saving the dataset (0/2 shards):   0%|          | 0/118695 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Saving the dataset (0/1 shards):   0%|          | 0/14835 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Saving the dataset (0/1 shards):   0%|          | 0/14838 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Tokenized and Saved Dataset\n"
]
}
],
"source": [
"def tokenize_function(example):\n",
"    # Wrap each schema/question pair in an explicit instruction prompt.\n",
"    start_prompt = \"Tables:\\n\"\n",
"    middle_prompt = \"\\n\\nQuestion:\\n\"\n",
"    end_prompt = \"\\n\\nAnswer:\\n\"\n",
"\n",
"    data_zip = zip(example['context'], example['question'])\n",
"    prompt = [start_prompt + context + middle_prompt + question + end_prompt for context, question in data_zip]\n",
"    example['input_ids'] = tokenizer(prompt, padding=\"max_length\", truncation=True, return_tensors=\"pt\").input_ids\n",
"    example['labels'] = tokenizer(example['answer'], padding=\"max_length\", truncation=True, return_tensors=\"pt\").input_ids\n",
"\n",
"    return example\n",
"\n",
"# The dataset contains three splits: train, validation, and test.\n",
"# tokenize_function processes every split in batches.\n",
"\n",
"try:\n",
"    tokenized_datasets = load_from_disk(\"tokenized_datasets\")\n",
"    print(\"Loaded Tokenized Dataset\")\n",
"except FileNotFoundError:\n",
"    tokenized_datasets = dataset.map(tokenize_function, batched=True)\n",
"    tokenized_datasets = tokenized_datasets.remove_columns(['question', 'context', 'answer'])\n",
"\n",
"    tokenized_datasets.save_to_disk(\"tokenized_datasets\")\n",
"    print(\"Tokenized and Saved Dataset\")"
]
},
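{
"cell_type": "markdown",
"id": "3d4e5f60",
"metadata": {},
"source": [
"Note that padding=\"max_length\" with truncation=True clips every prompt to the model's 512-token limit, so very large schemas lose part of their context. A small illustrative sketch (added here, not part of the original run) estimates how often that happens on a sample of the test split:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3d4e5f61",
"metadata": {},
"outputs": [],
"source": [
"# Count how many of the first 100 test prompts exceed the tokenizer's limit (512 for t5-small).\n",
"sample = dataset['test'][:100]\n",
"sample_prompts = [f\"Tables:\\n{c}\\n\\nQuestion:\\n{q}\\n\\nAnswer:\\n\"\n",
"                  for c, q in zip(sample['context'], sample['question'])]\n",
"n_truncated = sum(len(tokenizer(p).input_ids) > tokenizer.model_max_length for p in sample_prompts)\n",
"print(f\"{n_truncated}/100 sample prompts exceed {tokenizer.model_max_length} tokens and will be truncated\")"
]
},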
{
"cell_type": "code",
"execution_count": 9,
"id": "fe4bfa16",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"dict_keys(['train', 'test', 'validation'])\n",
"dict_keys(['input_ids', 'labels'])\n",
"[4398, 7, 10, 205, 4386, 6048, 332, 17098, 819, 41]\n",
"[3, 23143, 14196, 2847, 17161, 599, 1935, 61, 21680, 819]\n",
"DatasetDict({\n",
"    train: Dataset({\n",
"        features: ['input_ids', 'labels'],\n",
"        num_rows: 118695\n",
"    })\n",
"    test: Dataset({\n",
"        features: ['input_ids', 'labels'],\n",
"        num_rows: 14835\n",
"    })\n",
"    validation: Dataset({\n",
"        features: ['input_ids', 'labels'],\n",
"        num_rows: 14838\n",
"    })\n",
"})\n"
]
}
],
"source": [
"print(tokenized_datasets.keys())\n",
"print(tokenized_datasets['train'][0].keys())\n",
"print(tokenized_datasets['train'][0]['input_ids'][:10])\n",
"print(tokenized_datasets['train'][0]['labels'][:10])\n",
"print(tokenized_datasets)"
]
},
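{
"cell_type": "markdown",
"id": "4e5f6071",
"metadata": {},
"source": [
"As a quick round-trip check (added for illustration), the label ids can be decoded back to text to confirm they encode the expected SQL:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4e5f6072",
"metadata": {},
"outputs": [],
"source": [
"# skip_special_tokens=True drops the padding tokens appended by padding=\"max_length\".\n",
"print(tokenizer.decode(tokenized_datasets['train'][0]['labels'], skip_special_tokens=True))"
]
},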
{
"cell_type": "code",
"execution_count": 11,
"id": "6efaa5b9",
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Shapes of the datasets:\n",
"Training: (118695, 2)\n",
"Validation: (14838, 2)\n",
"Test: (14835, 2)\n",
"DatasetDict({\n",
"    train: Dataset({\n",
"        features: ['input_ids', 'labels'],\n",
"        num_rows: 118695\n",
"    })\n",
"    test: Dataset({\n",
"        features: ['input_ids', 'labels'],\n",
"        num_rows: 14835\n",
"    })\n",
"    validation: Dataset({\n",
"        features: ['input_ids', 'labels'],\n",
"        num_rows: 14838\n",
"    })\n",
"})\n"
]
}
],
"source": [
"print(\"Shapes of the datasets:\")\n",
"print(f\"Training: {tokenized_datasets['train'].shape}\")\n",
"print(f\"Validation: {tokenized_datasets['validation'].shape}\")\n",
"print(f\"Test: {tokenized_datasets['test'].shape}\")\n",
"\n",
"print(tokenized_datasets)"
]
},
{
"cell_type": "markdown",
"id": "4e52f581",
"metadata": {},
"source": [
"# Test the Base Model with Zero-Shot Inference"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "1c6a5c0f",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"---------------------------------------------------------------------------------------------------\n",
"INPUT PROMPT:\n",
"Tables:\n",
"CREATE TABLE table_name_11 (date VARCHAR, away_team VARCHAR)\n",
"\n",
"Question:\n",
"On what Date did the Away team essendon play?\n",
"\n",
"Answer:\n",
"\n",
"---------------------------------------------------------------------------------------------------\n",
"BASELINE HUMAN ANSWER:\n",
"SELECT date FROM table_name_11 WHERE away_team = \"essendon\"\n",
"\n",
"---------------------------------------------------------------------------------------------------\n",
"MODEL GENERATION - ZERO SHOT:\n",
"Question\n"
]
}
],
"source": [
"index = 0\n",
"\n",
"question = dataset['test'][index]['question']\n",
"context = dataset['test'][index]['context']\n",
"answer = dataset['test'][index]['answer']\n",
"\n",
"prompt = f\"\"\"Tables:\n",
"{context}\n",
"\n",
"Question:\n",
"{question}\n",
"\n",
"Answer:\n",
"\"\"\"\n",
"\n",
"inputs = tokenizer(prompt, return_tensors='pt')\n",
"inputs = inputs.to('cuda')\n",
"\n",
"output = tokenizer.decode(\n",
"    original_model.generate(\n",
"        inputs[\"input_ids\"],\n",
"        max_new_tokens=200,\n",
"    )[0],\n",
"    skip_special_tokens=True\n",
")\n",
"\n",
"dash_line = '-' * 99\n",
"print(dash_line)\n",
"print(f'INPUT PROMPT:\\n{prompt}')\n",
"print(dash_line)\n",
"print(f'BASELINE HUMAN ANSWER:\\n{answer}\\n')\n",
"print(dash_line)\n",
"print(f'MODEL GENERATION - ZERO SHOT:\\n{output}')"
]
},
{
"cell_type": "markdown",
"id": "a22a7f40",
"metadata": {},
"source": [
"Pretty poor: the base model echoes part of the prompt instead of generating SQL."
]
},
{
"cell_type": "markdown",
"id": "c8832fa9",
"metadata": {},
"source": [
"# Perform Full Fine-Tuning"
]
},
{
"cell_type": "markdown",
"id": "24f2d995",
"metadata": {},
"source": [
"### Training run: 2 epochs at learning rate 5e-3\n",
"\n",
"Time taken = 2h 49m 1s on a laptop with a GeForce RTX 3070 GPU\n",
"\n",
"Training loss = 0.023100\n",
"\n",
"Validation loss = 0.013285"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "94988713",
"metadata": {},
"outputs": [],
"source": [
"# Load the fine-tuned model if it exists on disk; otherwise start from the base model and train.\n",
"try:\n",
"    finetuned_model = AutoModelForSeq2SeqLM.from_pretrained(\"finetuned_model_2_epoch\")\n",
"    finetuned_model = finetuned_model.to('cuda')\n",
"    to_train = False\n",
"\n",
"except OSError:\n",
"    to_train = True\n",
"    finetuned_model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)\n",
"    finetuned_model = finetuned_model.to('cuda')\n",
"    tokenizer = AutoTokenizer.from_pretrained(model_name)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "ba6d32dd",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: total: 0 ns\n",
"Wall time: 0 ns\n"
]
}
],
"source": [
"%%time\n",
"\n",
"if to_train:\n",
"    output_dir = f'./sql-training-{str(int(time.time()))}'\n",
"\n",
"    training_args = TrainingArguments(\n",
"        output_dir=output_dir,\n",
"        learning_rate=5e-3,\n",
"        num_train_epochs=2,\n",
"        per_device_train_batch_size=16,  # batch size per device during training\n",
"        per_device_eval_batch_size=16,   # batch size for evaluation\n",
"        weight_decay=0.01,\n",
"        logging_steps=50,\n",
"        evaluation_strategy='steps',     # evaluation strategy to adopt during training\n",
"        eval_steps=500,                  # number of steps between evaluations\n",
"    )\n",
"\n",
"    trainer = Trainer(\n",
"        model=finetuned_model,\n",
"        args=training_args,\n",
"        train_dataset=tokenized_datasets['train'],\n",
"        eval_dataset=tokenized_datasets['validation'],\n",
"    )\n",
"\n",
"    trainer.train()\n",
"\n",
"    finetuned_model.save_pretrained(\"finetuned_model_2_epoch\")"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "4507aa94",
"metadata": {},
"outputs": [],
"source": [
"finetuned_model = AutoModelForSeq2SeqLM.from_pretrained(\"finetuned_model_2_epoch\")\n",
"finetuned_model = finetuned_model.to('cuda')"
]
},
{
"cell_type": "markdown",
"id": "131bc210",
"metadata": {},
"source": [
"# Test the Fine-Tuned Model with Zero-Shot Inference"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "f3fdfcf5",
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"---------------------------------------------------------------------------------------------------\n",
"INPUT PROMPT:\n",
"Tables:\n",
"CREATE TABLE employees (\n",
"    EMPLOYEE_ID decimal(6,0),\n",
"    FIRST_NAME varchar(20),\n",
"    LAST_NAME varchar(25),\n",
"    EMAIL varchar(25),\n",
"    PHONE_NUMBER varchar(20),\n",
"    HIRE_DATE date,\n",
"    JOB_ID varchar(10),\n",
"    SALARY decimal(8,2),\n",
"    COMMISSION_PCT decimal(2,2),\n",
"    MANAGER_ID decimal(6,0),\n",
"    DEPARTMENT_ID decimal(4,0)\n",
")\n",
"\n",
"CREATE TABLE jobs (\n",
"    JOB_ID varchar(10),\n",
"    JOB_TITLE varchar(35),\n",
"    MIN_SALARY decimal(6,0),\n",
"    MAX_SALARY decimal(6,0)\n",
")\n",
"\n",
"CREATE TABLE locations (\n",
"    LOCATION_ID decimal(4,0),\n",
"    STREET_ADDRESS varchar(40),\n",
"    POSTAL_CODE varchar(12),\n",
"    CITY varchar(30),\n",
"    STATE_PROVINCE varchar(25),\n",
"    COUNTRY_ID varchar(2)\n",
")\n",
"\n",
"CREATE TABLE countries (\n",
"    COUNTRY_ID varchar(2),\n",
"    COUNTRY_NAME varchar(40),\n",
"    REGION_ID decimal(10,0)\n",
")\n",
"\n",
"CREATE TABLE job_history (\n",
"    EMPLOYEE_ID decimal(6,0),\n",
"    START_DATE date,\n",
"    END_DATE date,\n",
"    JOB_ID varchar(10),\n",
"    DEPARTMENT_ID decimal(4,0)\n",
")\n",
"\n",
"CREATE TABLE regions (\n",
"    REGION_ID decimal(5,0),\n",
"    REGION_NAME varchar(25)\n",
")\n",
"\n",
"CREATE TABLE departments (\n",
"    DEPARTMENT_ID decimal(4,0),\n",
"    DEPARTMENT_NAME varchar(30),\n",
"    MANAGER_ID decimal(6,0),\n",
"    LOCATION_ID decimal(4,0)\n",
")\n",
"\n",
"Question:\n",
"For those employees who did not have any job in the past, give me the comparison about the amount of job_id over the job_id , and group by attribute job_id, and list from low to high by the JOB_ID please.\n",
"\n",
"Answer:\n",
"\n",
"---------------------------------------------------------------------------------------------------\n",
"BASELINE HUMAN ANSWER:\n",
"SELECT JOB_ID, COUNT(JOB_ID) FROM employees WHERE NOT EMPLOYEE_ID IN (SELECT EMPLOYEE_ID FROM job_history) GROUP BY JOB_ID ORDER BY JOB_ID\n",
"\n",
"---------------------------------------------------------------------------------------------------\n",
"FINE-TUNED MODEL - ZERO SHOT:\n",
"SELECT JOB_ID, COUNT(JOB_ID) FROM employees WHERE NOT EMPLOYEE_ID IN (SELECT EMPLOYEE_ID FROM job_history) GROUP BY JOB_ID ORDER BY JOB_ID\n"
]
}
],
"source": [
"index = 0\n",
"# index = len(dataset['test'])-200\n",
"\n",
"question = dataset['test'][index]['question']\n",
"context = dataset['test'][index]['context']\n",
"answer = dataset['test'][index]['answer']\n",
"\n",
"prompt = f\"\"\"Tables:\n",
"{context}\n",
"\n",
"Question:\n",
"{question}\n",
"\n",
"Answer:\n",
"\"\"\"\n",
"\n",
"inputs = tokenizer(prompt, return_tensors='pt')\n",
"inputs = inputs.to('cuda')\n",
"\n",
"output = tokenizer.decode(\n",
"    finetuned_model.generate(\n",
"        inputs[\"input_ids\"],\n",
"        max_new_tokens=200,\n",
"    )[0],\n",
"    skip_special_tokens=True\n",
")\n",
"\n",
"dash_line = '-' * 99\n",
"print(dash_line)\n",
"print(f'INPUT PROMPT:\\n{prompt}')\n",
"print(dash_line)\n",
"print(f'BASELINE HUMAN ANSWER:\\n{answer}\\n')\n",
"print(dash_line)\n",
"print(f'FINE-TUNED MODEL - ZERO SHOT:\\n{output}')"
]
},
{
"cell_type": "markdown",
"id": "69ec82ff",
"metadata": {},
"source": [
"# Evaluate the Model Quantitatively (with ROUGE Metric)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "8e665b3b",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Token indices sequence length is longer than the specified maximum sequence length for this model (1115 > 512). Running this sequence through the model will result in indexing errors\n"
]
}
],
"source": [
"# Run inference over the test dataset. Only the first 25 examples, due to the time it takes.\n",
"\n",
"questions = dataset['test'][0:25]['question']\n",
"contexts = dataset['test'][0:25]['context']\n",
"human_baseline_answers = dataset['test'][0:25]['answer']\n",
"\n",
"original_model_answers = []\n",
"finetuned_model_answers = []\n",
"\n",
"for idx, question in enumerate(questions):\n",
"\n",
"    prompt = f\"\"\"Tables:\n",
"{contexts[idx]}\n",
"\n",
"Question:\n",
"{question}\n",
"\n",
"Answer:\n",
"\"\"\"\n",
"\n",
"    input_ids = tokenizer(prompt, return_tensors=\"pt\").input_ids\n",
"    input_ids = input_ids.to('cuda')\n",
"\n",
"    human_baseline_text_output = human_baseline_answers[idx]\n",
"\n",
"    original_model_outputs = original_model.generate(input_ids=input_ids, generation_config=GenerationConfig(max_new_tokens=300))\n",
"    original_model_text_output = tokenizer.decode(original_model_outputs[0], skip_special_tokens=True)\n",
"    original_model_answers.append(original_model_text_output)\n",
"\n",
"    finetuned_model_outputs = finetuned_model.generate(input_ids=input_ids, generation_config=GenerationConfig(max_new_tokens=300))\n",
"    finetuned_model_text_output = tokenizer.decode(finetuned_model_outputs[0], skip_special_tokens=True)\n",
"    finetuned_model_answers.append(finetuned_model_text_output)\n",
"\n",
"zipped_summaries = list(zip(human_baseline_answers, original_model_answers, finetuned_model_answers))\n",
"\n",
"df = pd.DataFrame(zipped_summaries, columns=['human_baseline_answers', 'original_model_answers', 'finetuned_model_answers'])\n",
"# df"
]
},
{
"cell_type": "markdown",
"id": "b00766ae",
"metadata": {},
"source": [
"Compute the ROUGE score for this subset of the data."
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "18975f9d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ORIGINAL MODEL:\n",
"{'rouge1': 0.031970284742291306, 'rouge2': 0.005, 'rougeL': 0.03070044347245003, 'rougeLsum': 0.03121247624254732}\n",
"FINE-TUNED MODEL:\n",
"{'rouge1': 0.923359923692127, 'rouge2': 0.8863291968739871, 'rougeL': 0.9176464597549342, 'rougeLsum': 0.9182149521321223}\n"
]
}
],
"source": [
"rouge = evaluate.load('rouge')\n",
"\n",
"original_model_results = rouge.compute(\n",
"    predictions=original_model_answers,\n",
"    references=human_baseline_answers[0:len(original_model_answers)],\n",
"    use_aggregator=True,\n",
"    use_stemmer=True,\n",
")\n",
"print('ORIGINAL MODEL:')\n",
"print(original_model_results)\n",
"\n",
"finetuned_model_results = rouge.compute(\n",
"    predictions=finetuned_model_answers,\n",
"    references=human_baseline_answers[0:len(finetuned_model_answers)],\n",
"    use_aggregator=True,\n",
"    use_stemmer=True,\n",
")\n",
"print('FINE-TUNED MODEL:')\n",
"print(finetuned_model_results)"
]
},
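{
"cell_type": "markdown",
"id": "5f607182",
"metadata": {},
"source": [
"ROUGE measures n-gram overlap, not SQL correctness: two queries can be equivalent yet use different tokens, or share most tokens yet differ in meaning. A stricter, still imperfect sketch (added for illustration) is whitespace-normalized exact match:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5f607183",
"metadata": {},
"outputs": [],
"source": [
"# Lowercase and collapse whitespace before comparing predictions to references.\n",
"def normalize_sql(sql):\n",
"    return ' '.join(sql.lower().split())\n",
"\n",
"exact_match = sum(normalize_sql(p) == normalize_sql(r)\n",
"                  for p, r in zip(finetuned_model_answers, human_baseline_answers)) / len(finetuned_model_answers)\n",
"print(f\"Fine-tuned exact match (normalized): {exact_match:.2%}\")"
]
},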
{
"cell_type": "code",
"execution_count": null,
"id": "fc7ef16d",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
I repeated all the steps here. For the sample schema and the complex sample question it gives the expected answer, but for several other simple questions on the same schema it gets them all wrong:
List all departments and the name of their managers (if any).
Find the average salary of all employees in the 'IT_PROG' job.
Show the history of jobs (job titles) held by a specific employee (e.g., EMPLOYEE_ID = 101).