@richwhitjr
Created December 23, 2024 19:41
Join two keyed Python lists using an LLM for a fuzzy inner join.
"""Join two lists of items using a language model."""
import logging
from typing import Any, Dict, Optional, Tuple, List
import numpy as np
from numpy.typing import ArrayLike
import openai
import tqdm
logger = logging.getLogger(__name__)

class BaseAI:
    """Interface for a backend that provides both chat completion and embeddings."""

    def __call__(self, system_message: str, user_message: str) -> Optional[str]:
        return None

    def embed(self, text: str) -> Optional[ArrayLike]:
        return None


class OpenAI(BaseAI):
    def __init__(self):
        self.client = openai.OpenAI()

    def __call__(self, system_message: str, user_message: str) -> Optional[str]:
        response = self.client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {"role": "system", "content": system_message},
                {
                    "role": "user",
                    "content": [{"type": "text", "text": user_message}],
                },
            ],
        )
        answer = response.choices[0].message.content
        logger.debug(answer)
        return answer

    def embed(self, text: str) -> Optional[ArrayLike]:
        return np.array(
            self.client.embeddings.create(input=[text], model="text-embedding-ada-002")
            .data[0]
            .embedding
        )

_DEFAULT_PROMPT = """
YOU MUST RETURN ONLY ONE OF TWO ANSWERS, AS A SINGLE WORD: FALSE if the two words do not match, TRUE if they do match. THIS
IS VERY IMPORTANT FOR HUMANITY. You are given two words, left and right. Using the CONTEXT below as a guide, determine
if the two words refer to the same thing and thus are a close semantic match. If they are a close semantic match, return
TRUE. If they are not a close semantic match, return FALSE. If you are unsure, return FALSE.

CONTEXT: {context}
"""

def embed(items, llm):
    """Embed the key of each (key, value) pair, carrying the value alongside."""
    embeddings = []
    for key, value in tqdm.tqdm(items):
        vec = llm.embed(key)
        embeddings.append(((key, vec), value))
    return embeddings

def _cosine(v1, v2):
    """Cosine similarity between two vectors (1.0 means identical direction)."""
    return np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))


ITEM_TYPE = Tuple[str, Any]
VALUE_TYPE = List[ITEM_TYPE]
COGROUP_VALUE_TYPE = Tuple[VALUE_TYPE, VALUE_TYPE]
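# For illustration (an assumption, not from the original gist): a valid
# ITEM_TYPE is ("NYC", {"population": 8_800_000}), a string key plus an
# arbitrary payload that is carried through the join untouched.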

def inner_join(
    left: List[ITEM_TYPE],
    right: List[ITEM_TYPE],
    context: str = "",
    llm: Optional[BaseAI] = None,
    emb_distance: float = 0.6,
) -> List[Tuple[ITEM_TYPE, ITEM_TYPE]]:
    """Inner join two lists of items using a language model.

    Args:
        left: List of items to join.
        right: List of items to join.
        context: Context for the join.
        llm: Language model to use for the join.
        emb_distance: Minimum cosine similarity between two keys' embeddings
            before the LLM is asked to confirm the match.

    Returns:
        List of tuples of matched items.
    """
    if llm is None:
        llm = OpenAI()

    left_embeddings = embed(left, llm)
    right_embeddings = embed(right, llm)

    matches: Dict[str, Tuple[List[Tuple[str, Any]], List[Tuple[str, Any]]]] = {}
    for (lk, le), lv in tqdm.tqdm(left_embeddings):
        for (rk, re), rv in right_embeddings:
            match_key = f"{lk} and {rk}"
            dist = _cosine(le, re)
            logger.debug("Match key: %s, Cosine similarity: %f", match_key, dist)
            # Cheap embedding filter first; only candidate pairs above the
            # similarity threshold are sent to the (slower) LLM matcher.
            if dist > emb_distance:
                match = matcher((lk, lv), (rk, rv), context, llm)
                if match and match_key not in matches:
                    matches[match_key] = ([], [])
                if match:
                    matches[match_key][0].append((lk, lv))
                    matches[match_key][1].append((rk, rv))

    # Emit the cross product of matched left and right items per key pair.
    return_items: List[Tuple[ITEM_TYPE, ITEM_TYPE]] = []
    for _, (l, r) in matches.items():
        for a in l:
            for b in r:
                return_items.append((a, b))
    return return_items

def matcher(
    left: Tuple[str, Any],
    right: Tuple[str, Any],
    context: str = "",
    llm: Optional[BaseAI] = None,
) -> bool:
    """Match two items using a language model.

    Args:
        left: Item to match.
        right: Item to match.
        context: Context for the match.
        llm: Language model to use for the match.

    Returns:
        True if the items match, False otherwise.
    """
    if llm is None:
        llm = OpenAI()
    prompt = _DEFAULT_PROMPT.format(context=context)
    answer = llm(prompt, f"left = {left[0]}, right = {right[0]}")
    # Treat a missing answer or an explicit FALSE as a non-match; any other
    # reply is counted as TRUE.
    if answer is None:
        return False
    if "FALSE" in answer:
        return False
    return True
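
# A minimal usage sketch (not part of the original gist): the sample lists,
# context string, and expected pairing below are illustrative assumptions.
# Running it requires OPENAI_API_KEY to be set, since inner_join defaults to
# the OpenAI() backend above.
if __name__ == "__main__":
    cities_a = [("NYC", 1), ("SF", 2)]
    cities_b = [("New York City", "east"), ("San Francisco", "west")]
    # Each result pairs one (key, value) item from the left list with a
    # semantically matching item from the right list, e.g.
    # (("NYC", 1), ("New York City", "east")).
    for left_item, right_item in inner_join(cities_a, cities_b, context="US cities"):
        print(left_item, right_item)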