@palewire
Created October 15, 2024 16:28
"New School" LLM Classifier
{
"cells": [
{
"cell_type": "markdown",
"id": "ebd8543c-2473-4059-8236-7c2bb2284a35",
"metadata": {},
"source": [
"# \"New School\" LLM Classifier\n",
"\n",
"An example of how you can classify headlines today by prompting a large language model instead of training your own classifier."
]
},
{
"cell_type": "markdown",
"id": "2b8c6d7f-6754-42b8-8a22-9704ad7d89f8",
"metadata": {},
"source": [
"## Import Python tools"
]
},
{
"cell_type": "code",
"execution_count": 177,
"id": "d565bfbf-e93c-40dd-8eda-81dffa622b36",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import time\n",
"import json\n",
"\n",
"from retry import retry\n",
"import pandas as pd\n",
"\n",
"# If we use Anthropic's API, all we need is this one import!\n",
"from anthropic import Anthropic"
]
},
{
"cell_type": "markdown",
"id": "8ece6418-9347-4752-84b7-855cf4a662c5",
"metadata": {},
"source": [
"We'll read in our sample again, but this time it's only for testing. No training code will be necessary."
]
},
{
"cell_type": "code",
"execution_count": 152,
"id": "fa1fd023-f4bd-4029-853b-199242ac0a51",
"metadata": {},
"outputs": [],
"source": [
"training_df = pd.read_csv(\"./sample.csv\")"
]
},
{
"cell_type": "code",
"execution_count": 153,
"id": "3a63a27a-9c04-4ceb-8d52-13565f721f94",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>app</th>\n",
" <th>headline</th>\n",
" <th>is_politics</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>NYTimes</td>\n",
" <td>‘Shortcuts Everywhere’: Quality issues have pl...</td>\n",
" <td>n</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>NYTimes</td>\n",
" <td>‘We're Going to Do Our Job’: Speaker Mike John...</td>\n",
" <td>y</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>CBS News</td>\n",
" <td>\"60 Minutes\" reports: Here's what the rivalry ...</td>\n",
" <td>n</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>CBS News</td>\n",
" <td>\"Face the Nation\": DHS Secretary Alejandro May...</td>\n",
" <td>y</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>CBS News</td>\n",
" <td>\"History is watching\": President Biden blames ...</td>\n",
" <td>y</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" app headline is_politics\n",
"0 NYTimes ‘Shortcuts Everywhere’: Quality issues have pl... n\n",
"1 NYTimes ‘We're Going to Do Our Job’: Speaker Mike John... y\n",
"2 CBS News \"60 Minutes\" reports: Here's what the rivalry ... n\n",
"3 CBS News \"Face the Nation\": DHS Secretary Alejandro May... y\n",
"4 CBS News \"History is watching\": President Biden blames ... y"
]
},
"execution_count": 153,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"training_df.head()"
]
},
{
"cell_type": "markdown",
"id": "44d24116-b95c-4f35-99d5-dc05a47559a4",
"metadata": {},
"source": [
"Now we connect to the Anthropic client, which requires registering with their site, putting down a credit card and getting an API key."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "76d2b2f2-6196-4c70-95d4-70cf24e9d6ee",
"metadata": {},
"outputs": [],
"source": [
"# Create the Anthropic client, reading the API key from an environment variable\n",
"client = Anthropic(\n",
"    api_key=os.getenv(\"ANTHROPIC_API_KEY\"),\n",
")"
]
},
{
"cell_type": "markdown",
"id": "19bb9583-0a98-4e2c-8700-ffdbd8dceaa6",
"metadata": {},
"source": [
"Let's write a \"system prompt\" that will instruct the LLM on how we want it to behave."
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "688be377-e358-4e93-bc99-89a7cef919cb",
"metadata": {},
"outputs": [],
"source": [
"system = \"\"\"\n",
"You are a text classifier. Your job is to identify news headlines that are about politics.\n",
"\n",
"You should read the provided headline and determine, yes or no, if it explicitly discusses politics.\n",
"\n",
"In cases where the headline discusses politics, you should return 'y'.\n",
"\n",
"In cases where the headline does not discuss politics, you should return 'n'.\n",
"\"\"\""
]
},
{
"cell_type": "markdown",
"id": "01701e52-4191-413c-84b3-4cf68a5a025b",
"metadata": {},
"source": [
"Write a request to Anthropic to run the classification against a random headline"
]
},
{
"cell_type": "code",
"execution_count": 48,
"id": "ecfd241a-dd34-4d74-b746-941016e2d72d",
"metadata": {},
"outputs": [],
"source": [
"random_headline = training_df.sample(1).iloc[0].headline"
]
},
{
"cell_type": "code",
"execution_count": 49,
"id": "185fd5dc-4086-4454-92a5-410aee85a8ac",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'Stocks making the biggest moves premarket: Goldman Sachs, Snap One, Salesforce and more'"
]
},
"execution_count": 49,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"random_headline"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "669c8967-df24-44d3-98e0-3d05b98e7ed0",
"metadata": {},
"outputs": [],
"source": [
"answer = client.messages.create(\n",
"    # Send it your message, in the same way you would shitpost into ChatGPT\n",
"    messages=[\n",
"        {\"role\": \"user\", \"content\": random_headline},\n",
"    ],\n",
"    # Include the system prompt too\n",
"    system=system,\n",
"    model=\"claude-3-haiku-20240307\",  # Set the model to use\n",
"    max_tokens=1024,  # The maximum number of tokens the model may generate in its response\n",
"    temperature=0,  # Set this to 0 to minimize randomness, so the same headline gets a consistent answer\n",
")"
]
},
{
"cell_type": "markdown",
"id": "354d38ee-87e4-475e-b6d8-ea142e09ffdf",
"metadata": {},
"source": [
"Print the answer"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "d4f49ec0-c0a3-4013-8375-b6068071d75b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'y'"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"answer.content[0].text"
]
},
{
"cell_type": "markdown",
"id": "8272fe04-c754-4a4c-a48f-de6ebb6d47a6",
"metadata": {},
"source": [
"Now rewrite that into a function so you can loop over each row"
]
},
{
"cell_type": "code",
"execution_count": 52,
"id": "1c338013-eb30-41ff-8bb3-5cf0a25e8037",
"metadata": {},
"outputs": [],
"source": [
"def classify_headline(headline: str):\n",
"    # Do the same request\n",
"    answer = client.messages.create(\n",
"        messages=[\n",
"            {\"role\": \"user\", \"content\": headline},\n",
"        ],\n",
"        system=system,\n",
"        model=\"claude-3-haiku-20240307\",\n",
"        max_tokens=1024,\n",
"        temperature=0,\n",
"    )\n",
"    # Let it sleep between each row, to avoid pissing off the API\n",
"    time.sleep(1.5)\n",
"\n",
"    # Pull the text out of the response\n",
"    result = answer.content[0].text\n",
"\n",
"    # If the answer isn't 'y' or 'n', return None\n",
"    if result not in ['y', 'n']:\n",
"        return None\n",
"\n",
"    return result"
]
},
{
"cell_type": "markdown",
"id": "5d1aae7e-3170-4d02-8d95-11555cd8446d",
"metadata": {},
"source": [
"Loop through your dataframe and run it against each row."
]
},
{
"cell_type": "code",
"execution_count": 54,
"id": "a5a62070-1dff-49df-8b4b-945aa195861f",
"metadata": {},
"outputs": [],
"source": [
"training_df['anthropic'] = training_df['headline'].apply(classify_headline)"
]
},
{
"cell_type": "markdown",
"id": "0bab7bef-3c4d-440f-93bf-51aaec8029bd",
"metadata": {},
"source": [
"But that takes forever, and it costs money each time we hit the API. So we should batch the headlines instead."
]
},
{
"cell_type": "markdown",
"id": "24e5ad1d-61a0-4670-a978-b730b7cb31df",
"metadata": {},
"source": [
"We start by revising the system prompt with extra instructions that describe the batch input and the JSON output we expect."
]
},
{
"cell_type": "code",
"execution_count": 186,
"id": "5a9b5615-23dc-4375-beae-ecf1f75dd832",
"metadata": {},
"outputs": [],
"source": [
"system = \"\"\"\n",
"You are a text classifier. Your job is to identify news headlines that are about politics. You should read the provided headline and determine, yes or no, if it discusses politics or political figures.\n",
"\n",
"You will be provided a list of headlines separated by new lines. You should return the result for each headline as a JSON list.\n",
"input: 7p ET Judge Napolitano: talks Trump’s Supreme Court win on Rob Schmitt, tune in, start FREE trial: NewsmaxPlus.com\\nStocks making the biggest moves premarket: Goldman Sachs, Snap One, Salesforce and more\n",
"output: [\"y\", \"n\"]\n",
"\n",
"Do not provide anything other than the JSON response. Do not provide any additional text. You must provide a 'y' or 'n' for every headline. If you're given 20 headlines, you should return a JSON list with 20 answers. No exceptions.\n",
"\n",
"In cases where the headline discusses politics, you should return 'y'. In cases where the headline does not discuss politics, you should return 'n'. A headline should return y when it centers on a head of state or other elected official, like a president, prime minister, senator or royal.\n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": 187,
"id": "98cc57bd-90f8-47fe-8eab-3ed1115a92f3",
"metadata": {},
"outputs": [],
"source": [
"@retry(tries=3, delay=1)\n",
"def classify_headline_batch(headline_list: list[str]) -> list[dict]:\n",
"    # Do the same request\n",
"    answer = client.messages.create(\n",
"        messages=[\n",
"            # Give an example\n",
"            {\"role\": \"user\", \"content\": \"7p ET Judge Napolitano: talks Trump’s Supreme Court win on Rob Schmitt, tune in, start FREE trial: NewsmaxPlus.com\\nStocks making the biggest moves premarket: Goldman Sachs, Snap One, Salesforce and more\"},\n",
"            {\"role\": \"assistant\", \"content\": json.dumps([\"y\", \"n\"])},\n",
"            # Notice how I joined them together into a newline-delimited string\n",
"            {\"role\": \"user\", \"content\": \"\\n\".join(headline_list)},\n",
"        ],\n",
"        system=system,\n",
"        model=\"claude-3-5-sonnet-20240620\",\n",
"        max_tokens=1024,\n",
"        temperature=0,\n",
"    )\n",
"\n",
"    # Parse the model's answer\n",
"    text = answer.content[0].text\n",
"    try:\n",
"        answer_list = json.loads(text)\n",
"    except Exception as e:\n",
"        print(\"COULD NOT PARSE RESULT\")\n",
"        print(text)\n",
"        raise e\n",
"\n",
"    # Make sure there is exactly one answer per headline\n",
"    try:\n",
"        assert len(answer_list) == len(headline_list)\n",
"    except AssertionError as e:\n",
"        print(f\"{len(headline_list)} headlines\")\n",
"        print(f\"{len(answer_list)} answers\")\n",
"        raise e\n",
"\n",
"    # Combine the headlines and answers into a list of dictionaries\n",
"    result_list = []\n",
"    for headline, answer in zip(headline_list, answer_list):\n",
"        if answer not in [\"y\", \"n\"]:\n",
"            answer = None\n",
"        result_list.append(\n",
"            {\"headline\": headline, \"is_politics\": answer}\n",
"        )\n",
"\n",
"    # Let it sleep between each batch, to avoid pissing off the API\n",
"    time.sleep(1.5)\n",
"\n",
"    # Return the result\n",
"    return result_list"
]
},
{
"cell_type": "code",
"execution_count": 188,
"id": "235c2937-2814-43e2-aff5-a1f24f8a6fb1",
"metadata": {},
"outputs": [],
"source": [
"headline_list = training_df.headline.tolist()"
]
},
{
"cell_type": "code",
"execution_count": 189,
"id": "40096789-25c3-42c0-9ad9-ff05c558707c",
"metadata": {},
"outputs": [],
"source": [
"for i in range(0, len(headline_list), 20):\n",
"    # Get the batch of headlines\n",
"    batch = headline_list[i : i + 20]\n",
"\n",
"    # Get the classifications for the batch\n",
"    result_list = classify_headline_batch(batch)\n",
"\n",
"    # Loop through the batch and add each classification to training_df\n",
"    for result in result_list:\n",
"        training_df.loc[\n",
"            training_df[\"headline\"] == result[\"headline\"], \"anthropic\"\n",
"        ] = result[\"is_politics\"]"
]
},
{
"cell_type": "markdown",
"id": "4b400cad-f492-4222-aeee-30fab612b712",
"metadata": {
"collapsed": true,
"jupyter": {
"outputs_hidden": true
}
},
"source": [
"Now check how right it was, compared against our human classifications"
]
},
{
"cell_type": "code",
"execution_count": 190,
"id": "ebda5a28-a2f6-4ba4-8834-8c953dedadc4",
"metadata": {},
"outputs": [],
"source": [
"training_df['correct'] = training_df['is_politics'] == training_df['anthropic']"
]
},
{
"cell_type": "markdown",
"id": "91761fdc-8397-445d-be6d-fcbb5266b9f3",
"metadata": {},
"source": [
"Now let's look at our results using the same scikit-learn classification report as before."
]
},
{
"cell_type": "code",
"execution_count": 191,
"id": "a7c65896-2d85-4ed3-8110-b824fde9838d",
"metadata": {},
"outputs": [],
"source": [
"from sklearn import metrics"
]
},
{
"cell_type": "code",
"execution_count": 192,
"id": "6f19ca9a-c089-4f30-96b8-803f1bf5d10d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" precision recall f1-score support\n",
"\n",
" n 0.91 0.81 0.86 100\n",
" y 0.83 0.92 0.87 100\n",
"\n",
" accuracy 0.86 200\n",
" macro avg 0.87 0.86 0.86 200\n",
"weighted avg 0.87 0.86 0.86 200\n",
"\n"
]
}
],
"source": [
"print(\n",
" metrics.classification_report(\n",
" training_df[\"is_politics\"],\n",
" training_df[\"anthropic\"],\n",
" )\n",
")"
]
},
{
"cell_type": "markdown",
"id": "969bb4d7-8424-48ff-a14c-bdfeaf76f984",
"metadata": {},
"source": [
"Better results. And there's even more you can do to improve it."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}