Created
March 28, 2023 16:08
-
-
Save mponty/ba7d74ded87bba6977ba136a49227203 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import defaultdict, Counter | |
from typing import Tuple, Dict, Any, List | |
from nltk import RegexpTokenizer | |
import re | |
def is_overlap(span: Tuple[int, int], reference_span: Tuple[int, int]) -> bool:
    """Tell whether two character spans share any positions.

    Each span is a pair of offsets; the endpoints are normalised first, so a
    pair given in reverse order is handled the same as the ordered one.
    Spans that merely touch (one ends exactly where the other begins) are
    not considered overlapping. A zero-width span lying inside, or on an
    edge of, a non-empty span is reported as overlapping.

    Args:
        span: First pair of offsets.
        reference_span: Second pair of offsets.

    Returns:
        True if the spans overlap, False otherwise.
    """
    first_lo, first_hi = min(*span), max(*span)
    second_lo, second_hi = min(*reference_span), max(*reference_span)

    # One of the reference span's endpoints falls within the first span.
    if first_lo <= second_lo < first_hi:
        return True
    if first_lo < second_hi <= first_hi:
        return True

    # One of the first span's endpoints falls within the reference span.
    if second_lo <= first_lo < second_hi:
        return True
    return second_lo < first_hi <= second_hi
def heal_fragmented_entities(content: str, entities: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Snap fragmented entity annotations onto whole tokens of ``content``.

    ``content`` is split into tokens, where a token is either a run of word
    characters, ``+``, ``.`` and ``-``, or any other single non-whitespace
    character. Every token touched by at least one input entity produces
    exactly one output entity covering the whole token. When entities with
    different tags touch the same token, the tag carried by the largest
    number of entities wins; ties go to the tag that appears first in
    ``entities``.

    Args:
        content: The text the entity offsets refer to.
        entities: Dicts with at least ``'start'``, ``'end'`` and ``'tag'``
            keys; ``'confidence'`` is optional.

    Returns:
        A list of new entity dicts with ``'tag'``, ``'start'``, ``'end'`` and
        ``'value'`` keys, ordered by position in ``content``. ``'confidence'``
        is added as the mean confidence of the entities that voted for the
        winning tag, but only when every one of them carries a confidence.
        Entities that touch no token are dropped.
    """
    tokenizer_pattern = r'[\w+\.\-]+|[\S+]'
    regtok = RegexpTokenizer(tokenizer_pattern)
    token_spans = list(regtok.span_tokenize(content))

    # token index -> tag -> input entities with that tag touching the token
    relabeling = defaultdict(lambda: defaultdict(list))
    for entity in entities:
        entity_span = (entity['start'], entity['end'])
        for token_idx, span in enumerate(token_spans):
            if is_overlap(entity_span, span):
                relabeling[token_idx][entity['tag']].append(entity)

    new_entities = []
    # Visit tokens in document order so the result does not depend on the
    # order in which the caller happened to list the input entities.
    for token_idx in sorted(relabeling):
        detected_entities = relabeling[token_idx]
        # Majority vote over tags; max() keeps the first tag on a tie.
        tag, old_entities = max(
            detected_entities.items(),
            key=lambda tag_entlist: len(tag_entlist[1]),
        )
        start, end = token_spans[token_idx]
        new_entity = dict(
            tag=tag,
            start=start,
            end=end,
            value=content[start:end],
        )
        # Only report a confidence when it can be averaged over all voters.
        if all('confidence' in e for e in old_entities):
            new_entity['confidence'] = (
                sum(e['confidence'] for e in old_entities) / len(old_entities)
            )
        new_entities.append(new_entity)
    return new_entities
# Example: three fragmentary detections inside one long key are merged into
# a single entity covering the whole token.
example = """
private_key = "eyKSJDH_the_long_fragmented_key_JDNK2SDJ"#<-thisisakey
"""

entities = [
    {'tag': 'KEY', 'start': 16, 'end': 21, 'value': 'eyKSJ', 'confidence': 0.9},
    {'tag': 'KEY', 'start': 40, 'end': 45, 'value': 'ted_k', 'confidence': 0.5},
    {'tag': 'PASSWORD', 'start': 38, 'end': 40, 'value': 'en', 'confidence': 0.7},
]

heal_fragmented_entities(example, entities)
# Result: the two KEY fragments outvote the PASSWORD fragment, and the
# confidence is the mean of the two KEY confidences.
# [{'tag': 'KEY',
#   'start': 16,
#   'end': 56,
#   'value': 'eyKSJDH_the_long_fragmented_key_JDNK2SDJ',
#   'confidence': 0.7}]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment