Created
May 28, 2019 18:31
-
-
Save artsiomkaltovich/cb8e8fe7561cbe2b955695ddded4bae5 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import enum | |
import bisect | |
from collections import deque | |
from typing import Union, List | |
class SuffixTreeBuildingMethod(enum.Enum): | |
naive = enum.auto() | |
ukkonen = enum.auto() | |
class _TreeElem: | |
def __init__(self, data=None, index=None, children_type=list): | |
self.node = data | |
if children_type == list: | |
self.children = [_TreeElem(index)] if index else [] | |
elif children_type == dict: | |
self.children = {index: _TreeElem(index)} if index else {} | |
def add(self, suffix): | |
current = self | |
while True: | |
pos = bisect.bisect(current.children, suffix) | |
if not self._is_prefix_presented(current, pos, suffix): | |
pos = pos - 1 | |
if self._is_prefix_presented(current, pos, suffix): | |
general = current.children[pos].extract_general_prefix(suffix) | |
if general < len(current.children[pos].node): # added suffix and current vertex has common prefix | |
self._split_and_insert_common_prefix(current, general, pos) | |
if general < len(suffix.node): # vertex is a prefix case | |
suffix.node = suffix.node[general:] | |
else: | |
# vertex is already presented, we have to add just index | |
bisect.insort(current.children[pos].children, suffix.children[0]) | |
break | |
current = current.children[pos] | |
else: | |
bisect.insort(current.children, suffix) | |
break | |
def _is_prefix_presented(self, current, pos, suffix): | |
return pos > -1 and pos < len(current.children) \ | |
and type(current.children[pos].node) == str and suffix.node.startswith(current.children[pos].node[0]) | |
def _split_and_insert_common_prefix(self, current, general, pos): | |
rest = _TreeElem(current.children[pos].node[general:]) | |
rest.children = current.children[pos].children | |
current.children[pos].node = current.children[pos].node[:general] | |
current.children[pos].children = [rest] | |
def __str__(self): | |
children = ", ".join(str(child) if type(child) == list else str(self.children[child]) | |
for child in self.children) | |
return f"Elem(node={self.node}, children={children})" | |
def __eq__(self, other): | |
return self.node == other.node | |
def __lt__(self, other): | |
return (type(self.node) == int, self.node) < (type(other.node) == int, other.node) | |
def __gt__(self, other): | |
return (type(self.node) == int, self.node) > (type(other.node) == int, other.node) | |
def extract_general_prefix(self, other): | |
for general, (a, b) in enumerate(zip(self.node, other.node)): | |
if a != b: | |
break | |
else: | |
general += 1 | |
return general | |
class SuffixTree: | |
def __init__(self, collection: Union[str, List[str]], method=SuffixTreeBuildingMethod.naive): | |
self._n_string = 0 | |
if method == SuffixTreeBuildingMethod.naive: | |
self._naive(collection) | |
elif method == SuffixTreeBuildingMethod.ukkonen: | |
self._ukkonen(collection) | |
else: | |
raise ValueError("Wrong method, please choose one from SuffixTreeBuildingMethod") | |
def _naive(self, collection): | |
collection = [collection] if isinstance(collection, str) else collection | |
self._tree = _TreeElem() | |
for elem in collection: | |
self.add(elem) | |
def _ukkonen(self, collection): | |
collection = [collection] if isinstance(collection, str) else collection | |
tree = _TreeElem(children_type=dict) | |
for elem in collection: | |
for index, symbol in enumerate(elem): | |
if symbol not in tree.children: | |
tree.children[symbol] = _TreeElem((index, None), children_type=dict) | |
else: | |
self._tree = tree | |
def add(self, string): | |
self._n_string += 1 | |
cur = "" | |
for symbol in reversed(string): | |
cur = symbol + cur | |
self._tree.add(_TreeElem(cur, self._n_string)) | |
def make_common_tree(self): | |
self._n_string = 1 | |
self._tree.children = [child for child in self._tree.children | |
if not (len(child.children) == 1 and type(child.children[0].node) == int)] | |
stack = [(self._tree, child, child.node, False) for child in self._tree.children] | |
longests = [] | |
while stack: | |
parent, current, string, visited = stack[-1] | |
stack[-1] = parent, current, string, True | |
if len(current.children) == 1 and type(current.children[0].node) == int: | |
self._remove_uncommon_vertex(current, parent) | |
else: | |
longests = self._current_longests(longests, string, visited) | |
if visited or len(current.children) == 1: | |
stack.pop() | |
self._remove_endings(current) | |
else: | |
stack.extend((current, child, string + child.node, False) | |
for child in current.children if type(child.node) != int) | |
return longests | |
def _remove_uncommon_vertex(self, current, parent): | |
parent.children.pop(parent.children.index(current)) | |
child = current.children[0] | |
pos = bisect.bisect_left(parent.children, child) | |
if not (0 <= pos < len(parent.children) and parent.children[pos] == child): | |
bisect.insort(parent.children, child) | |
def _current_longests(self, longests, string, visited): | |
if visited: | |
longest_len = len(longests[0]) if longests else 0 | |
if len(string) > longest_len: | |
longests = [string] | |
elif len(string) == longest_len: | |
longests.append(string) | |
return longests | |
def _remove_endings(self, current): | |
for index in range(-1, -len(current.children), -1): | |
node = current.children[index].node | |
if type(node) == int and node != 1: | |
current.children.pop() | |
def to_graphviz(self): | |
number = 1 | |
deq = deque([(self._tree, number)]) | |
nodes = "" | |
edges = "" | |
while deq: | |
current, parent_number = deq.popleft() | |
nodes += str(parent_number) + f' [label="{current.node}"]\n' | |
iterable = current.children.values() if type(current.children) == dict else current.children | |
for current_index, child in enumerate(iterable, 1): | |
number += 1 | |
deq.append((child, number)) | |
edges += str(parent_number) + "--" + str(number) + "\n" | |
return "strict graph G {\n" + nodes + edges + "}" | |
def __str__(self): | |
return str(self._tree) | |
def __repr__(self): | |
return repr(self._tree) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment