Created
December 23, 2024 19:41
-
-
Save richwhitjr/74d36b4e6e329c240c167efaa6b28d0c to your computer and use it in GitHub Desktop.
Join two keyed Python lists using an LLM for a fuzzy inner join.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Join two lists of items using a language model.""" | |
import logging | |
from typing import Any, Dict, Optional, Tuple, List | |
import numpy as np | |
from numpy.typing import ArrayLike | |
import openai | |
import tqdm | |
logger = logging.getLogger(__name__) | |
class BaseAI:
    """Interface for the language-model backends used by the join helpers.

    The base implementation is a deliberate no-op: both hooks return
    ``None`` so callers can treat an unconfigured model as "no answer".
    """

    def __call__(self, system_message, user_message) -> Optional[str]:
        """Return the model's text reply for the given messages, or None."""
        return None

    def embed(self, text: str) -> Optional[ArrayLike]:
        """Return an embedding vector for ``text``, or None."""
        return None
class OpenAI(BaseAI):
    """``BaseAI`` backend backed by the OpenAI API (chat + embeddings)."""

    def __init__(self):
        # Credentials come from the environment (e.g. OPENAI_API_KEY).
        self.client = openai.OpenAI()

    def __call__(self, system_message, user_message) -> Optional[str]:
        """Send one system + one user message to gpt-4o and return its reply."""
        messages = [
            {"role": "system", "content": system_message},
            {
                "role": "user",
                "content": [{"type": "text", "text": user_message}],
            },
        ]
        response = self.client.chat.completions.create(
            model="gpt-4o",
            messages=messages,
        )
        answer = response.choices[0].message.content
        logger.debug(answer)
        return answer

    def embed(self, text: str) -> Optional[ArrayLike]:
        """Embed ``text`` with text-embedding-ada-002 as a numpy vector."""
        result = self.client.embeddings.create(
            input=[text], model="text-embedding-ada-002"
        )
        return np.array(result.data[0].embedding)
_DEFAULT_PROMPT = """ | |
YOU MUST ONLY RETURN TWO ANSWERS AS A SINGLE WORD: FALSE if the two words do not match, TRUE if they do match. THIS | |
IS VERY IMPORTANT FOR HUMANITY. You are given to words left and right. Using the CONTEXT below as a guide, determine | |
if the two words refer to the same thing and thus are a close sematic match. If they are a close semantic match, return | |
TRUE. If they are not a close semantic match, return FALSE. If you are unsure, return FALSE. | |
CONTEXT: {context} | |
""" | |
def embed(items, llm):
    """Attach an embedding vector to the key of every (key, value) item.

    Args:
        items: List of (key, value) pairs; the key text is what gets embedded.
        llm: Backend exposing ``.embed(text)``.

    Returns:
        List of ((key, vector), value) pairs, in input order, with a
        progress bar shown while embedding.
    """
    return [((key, llm.embed(key)), value) for key, value in tqdm.tqdm(items)]
def _cosine(v1, v2): | |
return np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2)) | |
# A keyed item: (key text used for matching, arbitrary payload value).
ITEM_TYPE = Tuple[str, Any]
# One side of a join: a list of keyed items.
VALUE_TYPE = List[ITEM_TYPE]
# Cogrouped values: (items from the left side, items from the right side).
COGROUP_VALUE_TYPE = Tuple[VALUE_TYPE, VALUE_TYPE]
def inner_join(
    left: List[ITEM_TYPE],
    right: List[ITEM_TYPE],
    context: str = "",
    llm: Optional[BaseAI] = None,
    emb_distance: float = 0.6,
) -> List[Tuple[ITEM_TYPE, ITEM_TYPE]]:
    """Inner join two keyed lists using embeddings plus an LLM yes/no check.

    Candidate pairs are pre-filtered by cosine similarity of their key
    embeddings; only pairs above ``emb_distance`` are sent to ``matcher``
    for confirmation.  O(len(left) * len(right)) pair comparisons.

    Args:
        left: List of items to join.
        right: List of items to join.
        context: Context for the join.
        llm: Language model to use for the join (defaults to OpenAI).
        emb_distance: Minimum cosine similarity for a candidate match.

    Returns:
        List of tuples of matched items.
    """
    if llm is None:
        llm = OpenAI()

    left_embedded = embed(left, llm)
    right_embedded = embed(right, llm)

    # Confirmed matches, grouped under a human-readable pair key.
    matches: Dict[str, Tuple[List[Tuple[str, Any]], List[Tuple[str, Any]]]] = {}
    for (left_key, left_vec), left_val in tqdm.tqdm(left_embedded):
        for (right_key, right_vec), right_val in right_embedded:
            pair_key = f"{left_key} and {right_key}"
            similarity = _cosine(left_vec, right_vec)
            logger.debug("Match key: %s, Cosine similarity: %f", pair_key, similarity)
            if similarity <= emb_distance:
                continue  # embedding pre-filter: skip the expensive LLM call
            if matcher((left_key, left_val), (right_key, right_val), context, llm):
                if pair_key not in matches:
                    matches[pair_key] = ([], [])
                matches[pair_key][0].append((left_key, left_val))
                matches[pair_key][1].append((right_key, right_val))

    # Expand each confirmed group into the cross product of its two sides.
    return [
        (left_item, right_item)
        for left_group, right_group in matches.values()
        for left_item in left_group
        for right_item in right_group
    ]
def matcher(
    left: Tuple[str, Any],
    right: Tuple[str, Any],
    context: str = "",
    llm: Optional[BaseAI] = None,
) -> bool:
    """Match two items using a language model.

    The model is prompted to answer with the single word TRUE or FALSE.

    Args:
        left: Item to match; only the key (``left[0]``) is sent to the model.
        right: Item to match; only the key (``right[0]``) is sent to the model.
        context: Context for the match, interpolated into the prompt.
        llm: Language model to use for the match (defaults to OpenAI).

    Returns:
        True only when the model explicitly answers TRUE; False when it
        answers FALSE, gives no answer, or replies off-format.  (The
        original returned True for any reply not containing "FALSE",
        so an off-format answer counted as a match — now fails closed,
        matching the prompt's "If you are unsure, return FALSE".)
    """
    if llm is None:
        llm = OpenAI()
    prompt = _DEFAULT_PROMPT.format(context=context)
    answer = llm(prompt, f"left = {str(left[0])}, right={str(right[0])}")
    if answer is None:
        return False
    if "FALSE" in answer:
        return False
    # Fail closed: a reply that never says TRUE is not a match.
    return "TRUE" in answer
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment