Last active
July 22, 2021 15:42
-
-
Save ssaavedra/69b5b9865f2f2e9f421951bed5417f0a to your computer and use it in GitHub Desktop.
DiffMatcher is a tool to check an input against a database of known variants.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
''' | |
DiffMatcher is a tool to check an input against a database of known variants. | |
For example, to determine (and resort to fuzzy-matching if required) whether | |
some unsanitized fields of user-input match a preset of known-good entries. | |
Author: Santiago Saavedra <[email protected]> | |
SPDX-License-Identifier: CC0-1.0 | |
''' | |
from functools import lru_cache | |
from logging import getLogger | |
from typing import Callable, Dict, Generic, Iterable, Optional, Tuple, TypeVar | |
import jellyfish | |
logger = getLogger("eu.ssaavedra.diffmatcher")
# Type of the key that splits the catalogue into partitions (e.g. a CCAA code
# when matching provinces).
PartitionKey = TypeVar("PartitionKey")
# Type of the names (and their variants) that are indexed and fuzzy-matched.
# The original declared the duplicated constraint list (str, str), which is
# meaningless; a bound expresses the intended "string-like" restriction.
ReverseIndex = TypeVar("ReverseIndex", bound=str)
# Type of the value each name resolves to.
Value = TypeVar("Value")
class DiffMatcher(Generic[PartitionKey, ReverseIndex, Value]):
    """Resolve (possibly misspelled) names to known values.

    The matcher is built from ``(partition_key, value, name)`` triples. Every
    name is expanded into its variants with ``name_variants_fn``. Each variant
    is stored in a reverse index that maps ``variant -> value``. A second index
    groups the same entries by partition key, so a caller that knows the
    partition can search fewer candidates.

    An exact match is returned directly. Otherwise the candidate with the
    highest ``word_similarity_fn`` score is returned.
    """

    def __init__(
        self,
        src: Iterable[Tuple[PartitionKey, Value, ReverseIndex]],
        word_similarity_fn: Callable[[ReverseIndex, ReverseIndex], float] = lru_cache(maxsize=1024)(
            jellyfish.jaro_similarity
        ),
        name_variants_fn: Callable[[ReverseIndex], Iterable[ReverseIndex]] = lambda x: [x],
    ):
        """Build the reverse indexes from ``src``.

        Args:
            src: Iterable of ``(partition_key, value, name)`` triples.
                Example for provinces: ``(01, 04, Almeria)``, because
                CCAA[01] = Andalucia. ``src`` is iterated exactly once, so a
                one-shot iterator or generator is accepted.
            word_similarity_fn: Scores two names; higher means more similar.
                Only used for fuzzy lookups. The default (a memoized
                ``jellyfish.jaro_similarity``) is created once, when the
                class is defined, so its cache is shared by all instances.
            name_variants_fn: Returns every variant of a name that should
                resolve to the same value. Defaults to the name alone.

        When several entries produce the same variant, the last one wins.
        """
        self.word_similarity_fn = word_similarity_fn
        self._name_variants = name_variants_fn
        # Materialize once. A bare generator would be exhausted by the first
        # index below, and ``src`` itself may be a one-shot iterator.
        self.cached_names_all = list(self._explode_all_names(src))

        # Global reverse index: name variant -> value.
        self.cached_names_revindex = {
            item_name: item_code for _, item_code, item_name in self.cached_names_all
        }

        # The same mapping grouped by partition key, built in a single pass.
        partitioned: Dict[PartitionKey, Dict[ReverseIndex, Value]] = {}
        for partition_code, item_code, item_name in self.cached_names_all:
            partitioned.setdefault(partition_code, {})[item_name] = item_code
        self.cached_names_revindex_partitioned = partitioned

    def _explode_all_names(
        self, src: Iterable[Tuple[PartitionKey, Value, ReverseIndex]]
    ) -> Iterable[Tuple[PartitionKey, Value, ReverseIndex]]:
        """Yield ``(partition_key, value, variant)`` for every variant of every name.

        Only the name (the third element) is expanded through
        ``self._name_variants``. The partition key and value are copied
        unchanged.
        """
        for partition_code, item_code, item_name in src:
            for item_name_variant in self._name_variants(item_name):
                yield partition_code, item_code, item_name_variant

    def get_cached_names(
        self, partition_code: Optional[PartitionKey]
    ) -> Dict[ReverseIndex, Value]:
        """Return the reverse index to search for ``partition_code``.

        A known partition key narrows the search space to that partition's
        index. The key is looked up as given first; an ``int`` key is also
        tried as its ``str`` form. Any other key (including ``None``, unless it
        is itself a partition key) falls back to the global index.
        """
        partitioned = self.cached_names_revindex_partitioned
        if partition_code in partitioned:
            return partitioned[partition_code]
        # Callers may pass a numeric key for partitions loaded as strings.
        if isinstance(partition_code, int) and str(partition_code) in partitioned:
            return partitioned[str(partition_code)]
        return self.cached_names_revindex

    def get_item(
        self, item_name: ReverseIndex, partition_code: Optional[PartitionKey] = None
    ) -> Value:
        """Return the value for ``item_name``, fuzzy-matching if needed.

        An exact match inside the selected partition is preferred, because the
        same variant may map to different values in different partitions. A
        global exact match comes next. Otherwise a similarity search runs over
        the partition's (or the global) index.

        Raises:
            KeyError: if there is nothing to match against.
        """
        partition_names = self.get_cached_names(partition_code)
        if item_name in partition_names:
            return partition_names[item_name]
        if item_name in self.cached_names_revindex:
            return self.cached_names_revindex[item_name]
        # Go with edit-distances
        logger.warning(
            "Item name: %s not found exactly. Searching via edit-distance.", item_name
        )
        return self._getitem_fuzzy(item_name, partition_code=partition_code)

    def __getitem__(self, value) -> Value:
        """``matcher[name]`` is ``matcher.get_item(name)`` with no partition."""
        return self.get_item(value)

    def _getitem_fuzzy(
        self, item_name: ReverseIndex, partition_code: Optional[PartitionKey] = None
    ) -> Value:
        """Return the value whose known name scores highest against ``item_name``.

        Ties keep the first candidate seen.

        Raises:
            KeyError: if the selected index is empty (previously this surfaced
                as an ``UnboundLocalError``).
        """
        candidates = self.get_cached_names(partition_code)
        if not candidates:
            raise KeyError(item_name)
        max_similarity = None
        max_similarity_item_code = None
        for known_item_name, item_code in candidates.items():
            similarity = self.word_similarity_fn(item_name, known_item_name)
            # Test for None rather than truthiness: 0.0 is a legitimate score.
            if max_similarity is None or similarity > max_similarity:
                max_similarity = similarity
                max_similarity_item_code = item_code
        logger.warning(
            "Found item code %s with Jaro similarity of %s",
            max_similarity_item_code,
            max_similarity,
        )
        return max_similarity_item_code
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment