Extract and Compare Embeddings
from __future__ import annotations

import re
import unicodedata
from statistics import mean
from typing import (TYPE_CHECKING, Any, Callable, List, Tuple, TypedDict,
                    Union, cast)

import numpy as np
from bs4 import BeautifulSoup
from nltk import tokenize  # type: ignore; requires nltk.download("punkt")
from nltk.tokenize import word_tokenize  # type: ignore
from sentence_transformers import SentenceTransformer  # type: ignore
from sentence_transformers.util import cos_sim  # type: ignore
from torch import Tensor
from typing_extensions import NotRequired

Embeddings = List[Tensor]

class EmbedAlternative(TypedDict):
    # shape inferred from the usage in calculate_complex_similarity below
    name: List[str]
    embeddings: Embeddings


class EmbedSource(TypedDict):
    id: int
    name: str
    name_embeddings: List[Any]
    name_chunks: List[str]
    multiplier: NotRequired[float]
    description: NotRequired[str]
    description_embeddings: NotRequired[List[Any] | None]
    description_chunks: NotRequired[List[str] | None]
    keywords: NotRequired[List[str] | None]
    keywords_reg: NotRequired[re.Pattern[str] | None]
    alternatives: NotRequired[List[EmbedAlternative]]
    mean: Any

# you have to pass a Python dictionary (EmbedSource);
# passing plain text instead would require modifying this code
if TYPE_CHECKING:
    from data.notifier import WebsocketNotifier

# import pandas as pd

# calculate similarity
# model = SentenceTransformer('paraphrase-MiniLM-L6-v2')
# model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")

def clean_text(text: str):
    # strip markup, whitespace, e-mail addresses, and URLs before tokenising
    soup = BeautifulSoup(text, features="lxml")
    text = soup.get_text()
    text = re.sub("\n", " ", text)
    text = re.sub(r"\\n", " ", text)  # literal "\n" sequences in the source text
    text = re.sub("\t", " ", text)
    text = re.sub(r"\S+@\S+", " ", text)  # e-mail addresses
    text = re.sub(r"http[s]?://\S+", " ", text)  # URLs
    text = unicodedata.normalize("NFKD", text)
    text = re.sub(r"(\xe9|\362)", "", text)  # stray accented characters (é, ò)
    return text

def split_large_sentence(text: str):
    # break a long sentence into lists of at most 128 tokens
    text_tokens: List[str] = word_tokenize(text)
    return [text_tokens[i: i + 128] for i in range(0, len(text_tokens), 128)]


def join_chunks(current_chunks: List[str]):
    # re-join tokens and remove the space word_tokenize puts before punctuation
    text = " ".join(current_chunks)
    text = re.sub(r" ([\.,;'\":\?\!])", r"\g<1>", text)
    return text

def create_chunks(text: str, chunk_size: int = 128):
    # greedily pack sentences into chunks of at most ~chunk_size tokens
    sentences: List[str] = tokenize.sent_tokenize(clean_text(text))
    combined: List[str] = []
    current_chunks: List[str] = []
    for sentence in sentences:
        chunks = split_large_sentence(sentence)
        if len(chunks) > 1:
            for chunk in chunks:
                if len(current_chunks) + len(chunk) > chunk_size:
                    combined.append(join_chunks(current_chunks))
                    current_chunks = chunk
                else:
                    current_chunks.extend(chunk)
        else:
            if len(current_chunks) + len(chunks[0]) > chunk_size:
                combined.append(join_chunks(current_chunks))
                current_chunks = chunks[0]
            else:
                current_chunks.extend(chunks[0])
    if len(current_chunks) > 0:
        combined.append(join_chunks(current_chunks))
    return combined
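
# Example (illustrative) of the chunking behaviour on a short text:
#
#   create_chunks("One. Two. Three.")
#   # -> ['One. Two. Three.']  (all sentences fit into a single 128-token chunk)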

def get_embeddings(text: str, round_to: int | None = 4) -> Tuple[Embeddings, List[str]]:
    """
    Split the text into ~128-token chunks and embed each chunk.

    :param text: input text
    :param round_to: decimal places to round each embedding to, or None to skip rounding
    :return: a tuple of (one embedding per chunk, the chunk texts)
    """
    combined = create_chunks(text, chunk_size=128)
    result: Embeddings = cast(Embeddings, [
        model.encode(x).astype(np.double) if round_to is None  # type: ignore
        else np.round(model.encode(x).astype(np.double), round_to)  # type: ignore
        for x in combined
    ])
    return (result, combined)
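
# Example (illustrative): embedding a short text.
#
#   embeddings, chunks = get_embeddings("First sentence. Second sentence.")
#   # len(embeddings) == len(chunks): one embedding per ~128-token chunk;
#   # each embedding has 768 dimensions for all-mpnet-base-v2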

def get_embeddings_fast(text: str) -> Any:
    # single mean-pooled embedding for the whole text
    embeddings, _ = get_embeddings(text)
    return np.mean(embeddings, axis=0)

class EmbedResult(TypedDict):
    id: int
    name: str
    max: float
    description: float
    description_length: int
    title: float
    keywords: List[str]
    keyword_count: int
    values: List[Any]


# boost factor applied per matched-keyword count
multipliers = {0: 1, 1: 1.4, 2: 1.5, 3: 1.6, 4: 1.7}


def keyword_multiplier(count: int) -> float:
    return multipliers.get(count, 2)

def combine(title_rank: float, description_rank: Union[float, None], keyword_count: int) -> float:
    # weighted blend of title and description similarity, scaled by keyword matches
    if description_rank is None:
        return title_rank if keyword_count == 0 else np.clip(0.5 * keyword_multiplier(keyword_count), 0, 1)
    return np.clip((0.4 * description_rank + 0.6 * title_rank) * keyword_multiplier(keyword_count), 0, 1)
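
# Worked example: combine(title_rank=0.8, description_rank=0.6, keyword_count=0)
#   -> clip((0.4 * 0.6 + 0.6 * 0.8) * 1, 0, 1) = 0.72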

class SimilarityResult(TypedDict):
    value: float
    source: str
    target: str


def compare_embeddings(
    in_embeddings: Embeddings, in_texts: List[str], for_embeddings: Embeddings, for_texts: List[str]
):
    # for every source chunk, find the best-matching target chunk,
    # then average those best scores
    combinations: List[SimilarityResult] = []
    for i in range(len(in_embeddings)):
        for_combinations: List[SimilarityResult] = []
        for j in range(len(for_embeddings)):
            for_combinations.append(
                {
                    "value": cos_sim(in_embeddings[i], for_embeddings[j]).numpy()[0][0],
                    "source": in_texts[i],
                    "target": for_texts[j],
                }
            )
        if len(for_combinations) > 0:
            best_match = max(for_combinations, key=lambda x: x["value"])
            combinations.append(best_match)
    if len(combinations) > 0:
        # mean of the best per-chunk similarities
        return mean([x["value"] for x in combinations])
    return 0
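
# Example (illustrative): chunk-level comparison of two texts.
#
#   a_emb, a_chunks = get_embeddings("Python programming and data analysis.")
#   b_emb, b_chunks = get_embeddings("Software development in Python.")
#   score = compare_embeddings(a_emb, a_chunks, b_emb, b_chunks)
#   # score is the mean of the best per-chunk cosine similarities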

def calculate_complex_similarity(
    in_this: List[EmbedSource],
    description_embeddings: Embeddings,
    description_texts: List[str],
    description_text: str,
    threshold: Union[float, None] = 0.5,
    title_embeddings: Union[Embeddings, None] = None,
    title_texts: Union[List[str], None] = None,
    title_text: Union[str, None] = None,
    checker: Union[Callable[[int, float, List[EmbedResult]], bool], None] = None,
    allow_multipliers: bool = True,
    notifier: WebsocketNotifier | None = None
) -> List[EmbedResult]:
    in_scores: List[EmbedResult] = []
    # score every source against the target description (and optional title)
    for i, in_value in enumerate(in_this):
        # count keyword occurrences in the raw description text
        if "keywords_reg" in in_value and in_value["keywords_reg"] is not None:
            keyword_matches = in_value["keywords_reg"].findall(description_text)
        else:
            keyword_matches = []
        # chunk-level similarity between the source description and the target text
        if (
            "description_chunks" in in_value
            and in_value["description_chunks"] is not None
            and len(in_value["description_chunks"]) > 0
            and "description_embeddings" in in_value
            and in_value["description_embeddings"] is not None
        ):
            description_similarity = compare_embeddings(
                in_value["description_embeddings"],
                in_value["description_chunks"],
                description_embeddings,
                description_texts,
            )
        else:
            description_similarity = None
        name_similarity = compare_embeddings(
            in_value["name_embeddings"],
            in_value["name_chunks"],
            title_embeddings or description_embeddings,
            title_texts or description_texts,
        )
        # alternative names: keep the best-scoring one
        if "alternatives" in in_value and len(in_value["alternatives"]) > 0:
            for alternative in in_value["alternatives"]:
                alternative_similarity = compare_embeddings(
                    alternative["embeddings"],
                    alternative["name"],
                    title_embeddings or description_embeddings,
                    title_texts or description_texts,
                )
                if alternative_similarity > name_similarity:
                    name_similarity = alternative_similarity
        combined = combine(
            name_similarity, description_similarity, 0
        )  # keyword boost disabled; was len(keyword_matches)
        maximised = combined * (
            1.5 if allow_multipliers and "multiplier" in in_value and in_value["multiplier"] > 1 else 1
        )
        if (threshold is not None and maximised > threshold) or (
            checker is not None and checker(in_value["id"], maximised, in_scores)
        ):
            in_scores.append(
                {
                    "id": in_value["id"],
                    "name": in_value["name"] if "name" in in_value else " ".join(in_value["name_chunks"]),
                    "max": maximised,
                    "description": description_similarity or 0,
                    "title": name_similarity,
                    "keyword_count": len(keyword_matches),
                    "keywords": keyword_matches,
                    "values": [],
                    "description_length": len(in_value["description"]) if "description" in in_value else 0,
                }
            )
        if notifier is not None:
            notifier.notify_progress(
                i, len(in_this), " ".join(in_value["name_chunks"]) if "name_chunks" in in_value else "")
    if notifier is not None:
        notifier.notify_progress(
            len(in_this), len(in_this), f"Extracted {len(in_scores)} skills")
    return sorted(in_scores, reverse=True, key=lambda x: x["max"])
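
if __name__ == "__main__":
    # Minimal end-to-end sketch with made-up sample data (the skill name,
    # description, job ad, and threshold below are illustrative only).
    # Real callers are expected to precompute and store the EmbedSource embeddings.
    skill_name = "machine learning"
    skill_description = "Building and evaluating predictive models from data."

    name_embeddings, name_chunks = get_embeddings(skill_name)
    desc_embeddings, desc_chunks = get_embeddings(skill_description)

    source: EmbedSource = {
        "id": 1,
        "name": skill_name,
        "name_embeddings": name_embeddings,
        "name_chunks": name_chunks,
        "description": skill_description,
        "description_embeddings": desc_embeddings,
        "description_chunks": desc_chunks,
        "mean": None,
    }

    job_ad = "We train neural networks and statistical models on large datasets."
    ad_embeddings, ad_chunks = get_embeddings(job_ad)

    results = calculate_complex_similarity(
        [source],
        description_embeddings=ad_embeddings,
        description_texts=ad_chunks,
        description_text=job_ad,
        threshold=0.3,  # tune for your data
    )
    for result in results:
        print(result["name"], round(result["max"], 3))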