Function Calling with OpenHermes 2.5 using Python functions and Pydantic classes
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# openhermes-functions\n",
"\n",
"Demonstrates how to implement function calling at inference time using the \"OpenHermes-2.5-Mistral-7B\" checkpoint\n",
"\n",
"Source: https://github.com/abacaj/openhermes-function-calling/blob/main/openhermes-functions.ipynb (or https://nbsanity.com/static/f491f7e30f8e9d70dfc72acf9d841afc/openhermes-functions.html)"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import gc, inspect, json, re\n",
"import xml.etree.ElementTree as ET\n",
"from functools import partial\n",
"from typing import get_type_hints\n",
"\n",
"import transformers\n",
"import torch\n",
"\n",
"from langchain.chains.openai_functions import convert_to_openai_function\n",
"from langchain.utils.openai_functions import convert_pydantic_to_openai_function\n",
"from langchain.pydantic_v1 import BaseModel, Field, validator"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"model_name = \"teknium/OpenHermes-2.5-Mistral-7B\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Utility Methods"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def load_model(model_name: str):\n",
"    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)\n",
"\n",
"    with torch.device(\"cuda:0\"):\n",
"        model = transformers.AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).eval()\n",
"\n",
"    return tokenizer, model"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3410456ddeed476c83bd9c9ed9972615",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"tokenizer, model = load_model(model_name)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"\n",
"def delete_model(*args):\n",
"    for var in args:\n",
"        if var in globals():\n",
"            del globals()[var]\n",
"\n",
"    gc.collect()\n",
"    torch.cuda.empty_cache()\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"delete_model(\"model\", \"tokenizer\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Function Calling"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### A. Using Python Functions"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Class/Function Examples"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"class Article:\n",
"    pass\n",
"\n",
"class Weather:\n",
"    pass\n",
"\n",
"class Directions:\n",
"    pass\n",
"\n",
"def calculate_mortgage_payment(loan_amount: int, interest_rate: float, loan_term: int) -> float:\n",
"    \"\"\"Get the monthly mortgage payment given an interest rate percentage.\"\"\"\n",
"\n",
"    # TODO: you must implement this to actually call it later\n",
"    pass\n",
"\n",
"def get_article_details(title: str, authors: list[str], short_summary: str, date_published: str, tags: list[str]) -> Article:\n",
"    '''Get article details from unstructured article text.\n",
"date_published: formatted as \"MM/DD/YYYY\"'''\n",
"\n",
"    # TODO: you must implement this to actually call it later\n",
"    pass\n",
"\n",
"def get_weather(zip_code: str) -> Weather:\n",
"    \"\"\"Get the current weather given a zip code.\"\"\"\n",
"\n",
"    # TODO: you must implement this to actually call it later\n",
"    pass\n",
"\n",
"def get_directions(start: str, destination: str) -> Directions:\n",
"    \"\"\"Get directions from Google Directions API.\n",
"start: start address as a string including zipcode (if any)\n",
"destination: end address as a string including zipcode (if any)\"\"\"\n",
"\n",
"    # TODO: you must implement this to actually call it later\n",
"    pass"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Serialization Methods"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"def get_type_name(t):\n",
"    name = str(t)\n",
"    if \"list\" in name or \"dict\" in name:\n",
"        return name\n",
"    else:\n",
"        return t.__name__"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Article\n",
"get_weather\n"
]
}
],
"source": [
"print(get_type_name(Article))\n",
"print(get_type_name(get_weather))"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"def serialize_function_to_json(func):\n",
"    signature = inspect.signature(func)\n",
"    type_hints = get_type_hints(func)\n",
"\n",
"    function_info = {\n",
"        \"name\": func.__name__,\n",
"        \"description\": func.__doc__,\n",
"        \"parameters\": {\n",
"            \"type\": \"object\",\n",
"            \"properties\": {}\n",
"        },\n",
"        \"returns\": type_hints.get('return', 'void').__name__\n",
"    }\n",
"\n",
"    for name, _ in signature.parameters.items():\n",
"        param_type = get_type_name(type_hints.get(name, type(None)))\n",
"        function_info[\"parameters\"][\"properties\"][name] = {\"type\": param_type}\n",
"\n",
"    return json.dumps(function_info, indent=2)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{\n",
"  \"name\": \"get_article_details\",\n",
"  \"description\": \"Get article details from unstructured article text.\\ndate_published: formatted as \\\"MM/DD/YYYY\\\"\",\n",
"  \"parameters\": {\n",
"    \"type\": \"object\",\n",
"    \"properties\": {\n",
"      \"title\": {\n",
"        \"type\": \"str\"\n",
"      },\n",
"      \"authors\": {\n",
"        \"type\": \"list[str]\"\n",
"      },\n",
"      \"short_summary\": {\n",
"        \"type\": \"str\"\n",
"      },\n",
"      \"date_published\": {\n",
"        \"type\": \"str\"\n",
"      },\n",
"      \"tags\": {\n",
"        \"type\": \"list[str]\"\n",
"      }\n",
"    }\n",
"  },\n",
"  \"returns\": \"Article\"\n",
"}\n",
"{\n",
"  \"name\": \"get_weather\",\n",
"  \"description\": \"Get the current weather given a zip code.\",\n",
"  \"parameters\": {\n",
"    \"type\": \"object\",\n",
"    \"properties\": {\n",
"      \"zip_code\": {\n",
"        \"type\": \"str\"\n",
"      }\n",
"    }\n",
"  },\n",
"  \"returns\": \"Weather\"\n",
"}\n"
]
}
],
"source": [
"print(serialize_function_to_json(get_article_details))\n",
"print(serialize_function_to_json(get_weather))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### B. Using Pydantic"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Pydantic Examples"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"class Joke(BaseModel):\n",
"    \"\"\"Get a joke that includes the setup and punchline\"\"\"\n",
"    setup: str = Field(description=\"question to set up a joke\")\n",
"    punchline: str = Field(description=\"answer to resolve the joke\")\n",
"\n",
"    # You can add custom validation logic easily with Pydantic.\n",
"    @validator(\"setup\")\n",
"    def question_ends_with_question_mark(cls, field):\n",
"        if field[-1] != \"?\":\n",
"            raise ValueError(\"Badly formed question!\")\n",
"        return field\n",
"\n",
"class Actor(BaseModel):\n",
"    \"\"\"Get the filmography for a given actor\"\"\"\n",
"    name: str = Field(description=\"name of an actor\")\n",
"    film_names: list[str] = Field(description=\"list of names of films they starred in\")"
]
},
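{
"cell_type": "markdown",
"metadata": {},
"source": [
"*Added example (not part of the original gist):* a quick sanity check of the custom `setup` validator above. A setup that does not end with a question mark should be rejected, while a well-formed joke validates and prints normally."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Added example (not in the original notebook): exercise the Joke validator.\n",
"try:\n",
"    Joke(setup=\"Tell me something funny\", punchline=\"...\")\n",
"except Exception as e:\n",
"    print(f\"Validation failed as expected: {e}\")\n",
"\n",
"print(Joke(setup=\"Why did the chicken cross the road?\", punchline=\"To get to the other side.\"))"
]
},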
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Serialization Methods"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'name': 'Joke',\n",
" 'description': 'Get a joke that includes the setup and punchline',\n",
" 'parameters': {'title': 'Joke',\n",
"  'description': 'Get a joke that includes the setup and punchline',\n",
"  'type': 'object',\n",
"  'properties': {'setup': {'title': 'Setup',\n",
"    'description': 'question to set up a joke',\n",
"    'type': 'string'},\n",
"   'punchline': {'title': 'Punchline',\n",
"    'description': 'answer to resolve the joke',\n",
"    'type': 'string'}},\n",
"  'required': ['setup', 'punchline']}}"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"convert_pydantic_to_openai_function(Joke)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Inference"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"def extract_function_calls(completion):\n",
"    completion = completion.strip()\n",
"    pattern = r\"(<multiplefunctions>(.*?)</multiplefunctions>)\"\n",
"    match = re.search(pattern, completion, re.DOTALL)\n",
"    if not match:\n",
"        return None\n",
"\n",
"    multiplefn = match.group(1)\n",
"    root = ET.fromstring(multiplefn)\n",
"    functions = root.findall(\"functioncall\")\n",
"    return [json.loads(fn.text) for fn in functions]"
]
},
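{
"cell_type": "markdown",
"metadata": {},
"source": [
"*Added example (not part of the original gist):* a quick check of `extract_function_calls` on a hand-written completion in the `<multiplefunctions>` format that the system prompt built in `generate_hermes` below asks the model to produce. The function names and arguments here are just sample values."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Added example (not in the original notebook): verify the parser on a\n",
"# hand-written completion that follows the expected <multiplefunctions> format.\n",
"sample_completion = \"\"\"<multiplefunctions>\n",
"    <functioncall> {\"name\": \"get_weather\", \"arguments\": {\"zip_code\": \"10001\"}} </functioncall>\n",
"    <functioncall> {\"name\": \"get_directions\", \"arguments\": {\"start\": \"Hoboken, NJ 07030\", \"destination\": \"New York, NY 10001\"}} </functioncall>\n",
"</multiplefunctions>\"\"\"\n",
"\n",
"extract_function_calls(sample_completion)"
]
},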
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"def generate_hermes(prompt, model, tokenizer, generation_config_overrides={}):\n",
"    fn = \"\"\"{\"name\": \"function_name\", \"arguments\": {\"arg_1\": \"value_1\", \"arg_2\": value_2, ...}}\"\"\"\n",
"    prompt = f\"\"\"<|im_start|>system\n",
"You are a helpful assistant with access to the following functions:\n",
"\n",
"{serialize_function_to_json(get_weather)}\n",
"\n",
"{serialize_function_to_json(calculate_mortgage_payment)}\n",
"\n",
"{serialize_function_to_json(get_directions)}\n",
"\n",
"{serialize_function_to_json(get_article_details)}\n",
"\n",
"{convert_pydantic_to_openai_function(Joke)}\n",
"\n",
"{convert_pydantic_to_openai_function(Actor)}\n",
"\n",
"To use these functions respond with:\n",
"<multiplefunctions>\n",
"    <functioncall> {fn} </functioncall>\n",
"    <functioncall> {fn} </functioncall>\n",
"    ...\n",
"</multiplefunctions>\n",
"\n",
"Edge cases you must handle:\n",
"- If there are no functions that match the user request, you will respond politely that you cannot help.<|im_end|>\n",
"<|im_start|>user\n",
"{prompt}<|im_end|>\n",
"<|im_start|>assistant\"\"\"\n",
"\n",
"    generation_config = model.generation_config\n",
"    generation_config.update(\n",
"        **{\n",
"            **{\n",
"                \"use_cache\": True,\n",
"                \"do_sample\": True,\n",
"                \"temperature\": 0.2,\n",
"                \"top_p\": 1.0,\n",
"                \"top_k\": 0,\n",
"                \"max_new_tokens\": 512,\n",
"                \"eos_token_id\": tokenizer.eos_token_id,\n",
"                \"pad_token_id\": tokenizer.eos_token_id,\n",
"            },\n",
"            **generation_config_overrides,\n",
"        }\n",
"    )\n",
"\n",
"    model = model.eval()\n",
"    inputs = tokenizer(prompt, return_tensors=\"pt\").to(model.device)\n",
"    n_tokens = inputs.input_ids.numel()\n",
"\n",
"    with torch.inference_mode():\n",
"        generated_tokens = model.generate(**inputs, generation_config=generation_config)\n",
"\n",
"    return tokenizer.decode(\n",
"        generated_tokens.squeeze()[n_tokens:], skip_special_tokens=False\n",
"    )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Tests"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "43b892ccbe294393a0757af37bcae8b6",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"tokenizer, model = load_model(model_name=model_name)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[{'name': 'Joke'}]\n",
"====================================================================================================\n",
"[{'name': 'Actor', 'arguments': {'name': 'Leo Decaprio', 'film_names': []}}]\n",
"====================================================================================================\n",
"I'm sorry, I don't understand your request. Could you please provide more information or rephrase your question?<|im_end|>\n",
"====================================================================================================\n",
"[{'name': 'calculate_mortgage_payment', 'arguments': {'loan_amount': 200000, 'interest_rate': 0.04, 'loan_term': 30}}]\n",
"====================================================================================================\n",
"I'm sorry, I don't have a function to get exchange rates. However, you can try using a search engine or a financial website to find the current exchange rate for USD to EUR.<|im_end|>\n",
"====================================================================================================\n",
"CPU times: user 7.58 s, sys: 710 ms, total: 8.29 s\n",
"Wall time: 8.28 s\n"
]
}
],
"source": [
"%%time\n",
"\n",
"generation_func = partial(generate_hermes, model=model, tokenizer=tokenizer)\n",
"\n",
"prompts = [\n",
"    \"Tell me a joke\",\n",
"    \"What movies did Leo Decaprio appear in?\",\n",
"    \"What's the weather in 10001?\",\n",
"    \"Determine the monthly mortgage payment for a loan amount of $200,000, an interest rate of 4%, and a loan term of 30 years.\",\n",
"    \"What's the current exchange rate for USD to EUR?\"\n",
"]\n",
"\n",
"for prompt in prompts:\n",
"    completion = generation_func(prompt)\n",
"    functions = extract_function_calls(completion)\n",
"\n",
"    if functions:\n",
"        print(functions)\n",
"    else:\n",
"        print(completion.strip())\n",
"    print(\"=\"*100)\n",
"\n",
"delete_model(\"generation_func\")"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[{'name': 'Actor', 'arguments': {'name': 'Leonardo DiCaprio', 'film_names': ['Titanic']}}, {'name': 'Joke', 'arguments': {'setup': 'Why did the Titanic sink?', 'punchline': 'Because Leo DiCaprio was on board!'}}]\n",
"====================================================================================================\n",
"[{'name': 'get_weather', 'arguments': {'zip_code': '92024'}}]\n",
"====================================================================================================\n",
"[{'name': 'get_weather', 'arguments': {'zip_code': '05751'}}, {'name': 'get_weather', 'arguments': {'zip_code': '07030'}}, {'name': 'get_directions', 'arguments': {'start': 'Hoboken, NJ 07030', 'destination': 'Killington, VT 05751'}}]\n",
"====================================================================================================\n",
"I'm sorry, but I don't have the functionality to provide exchange rates. Could you please ask me something else?<|im_end|>\n",
"====================================================================================================\n",
"CPU times: user 9.48 s, sys: 231 ms, total: 9.71 s\n",
"Wall time: 9.71 s\n"
]
}
],
"source": [
"%%time\n",
"\n",
"generation_func = partial(generate_hermes, model=model, tokenizer=tokenizer)\n",
"\n",
"prompts = [\n",
"    \"Tell me a joke about one of the movies Leo Decaprio appears in\",\n",
"    \"What's the weather in Encinitas California (92024)?\",\n",
"    \"I'm planning a trip to Killington, Vermont (05751) from Hoboken, NJ (07030). Can you get me weather for both locations and directions?\",\n",
"    \"What's the current exchange rate for USD to EUR?\"\n",
"]\n",
"\n",
"for prompt in prompts:\n",
"    completion = generation_func(prompt)\n",
"    functions = extract_function_calls(completion)\n",
"\n",
"    if functions:\n",
"        print(functions)\n",
"    else:\n",
"        print(completion.strip())\n",
"    print(\"=\"*100)\n",
"\n",
"delete_model(\"generation_func\")"
]
},
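{
"cell_type": "markdown",
"metadata": {},
"source": [
"*Added example (not part of the original gist):* a minimal sketch of how the extracted calls could be routed back to the callables defined above. It assumes the stub functions (`get_weather`, `calculate_mortgage_payment`, etc.) have been given real implementations; as written they return `None`. The Pydantic models work as-is, since instantiating them validates the arguments."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Minimal dispatch sketch (not in the original notebook). Maps the \"name\" field\n",
"# returned by extract_function_calls to the matching callable and applies the\n",
"# \"arguments\" dict as keyword arguments.\n",
"available_functions = {\n",
"    \"get_weather\": get_weather,\n",
"    \"calculate_mortgage_payment\": calculate_mortgage_payment,\n",
"    \"get_directions\": get_directions,\n",
"    \"get_article_details\": get_article_details,\n",
"    # Pydantic models double as extractors: calling them validates the arguments.\n",
"    \"Joke\": Joke,\n",
"    \"Actor\": Actor,\n",
"}\n",
"\n",
"def dispatch_function_calls(function_calls):\n",
"    results = []\n",
"    for call in function_calls or []:\n",
"        fn = available_functions.get(call[\"name\"])\n",
"        if fn is None:\n",
"            results.append(f\"Unknown function: {call['name']}\")\n",
"            continue\n",
"        results.append(fn(**call.get(\"arguments\", {})))\n",
"    return results\n",
"\n",
"# Example with a hand-written call in the same shape extract_function_calls returns.\n",
"dispatch_function_calls([{'name': 'Actor', 'arguments': {'name': 'Leonardo DiCaprio', 'film_names': ['Titanic']}}])"
]
},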
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Cleanup"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"delete_model(\"model\", \"tokenizer\", \"generation_func\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "llms",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
} |