Last active
February 18, 2018 09:43
-
-
Save blazs/a3ceb065a5b5251579247f7a5fda2d1e to your computer and use it in GitHub Desktop.
A quick-and-dirty trie
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" A simple trie implementation """ | |
class Node:
    """A single trie node.

    Attributes:
        value: payload stored at this node (``None`` by default).
        count: integer counter kept on the node (``0`` by default); this
            class never modifies it itself.
        children: dict mapping a symbol to the child ``Node`` reached by it.
    """

    def __init__(self, value=None, count=0):
        self.children = {}
        self.count = count
        self.value = value

    def __contains__(self, value):
        """Return True when a child is registered under the symbol *value*."""
        return value in self.children

    def __getitem__(self, value):
        """Return the child registered under *value*; raise KeyError if absent."""
        return self.children[value]

    def insert(self, chr, value=None):
        """Return the child under *chr*, creating it with *value* if missing.

        An already existing child is returned untouched; *value* is only
        used when a new child has to be created.
        """
        try:
            return self.children[chr]
        except KeyError:
            created = Node(value)
            self.children[chr] = created
            return created

    def remove(self, chr):
        """Detach the child stored under *chr*; raise KeyError if absent."""
        del self.children[chr]
class IterateLeaves:
    """Re-iterable view over the words stored at or below a trie node.

    A node is considered to terminate a word when its ``value`` attribute is
    not ``None``.  Each yielded word is *prefix* extended with the symbols on
    the path from the start node down to the terminating node.

    Unlike a plain leaf walk, this also reports words that are themselves
    prefixes of longer words, and it never reports a childless node that
    holds no value (for example the root of an empty trie).

    Args:
        node: start node; must expose ``children`` (a dict of symbol to
            child node) and ``value``.
        prefix: the word spelled by the path leading to *node*; symbols are
            appended to it with ``+``.
    """

    def __init__(self, node, prefix):
        self.node = node
        # Seed for the depth-first walk; __iter__ copies it so the object can
        # be iterated more than once.
        self.stack = [(self.node, prefix)]

    def __iter__(self):
        pending = list(self.stack)
        while pending:
            node, word = pending.pop()
            if node.value is not None:
                yield word
            for char, child in node.children.items():
                pending.append((child, word + char))
class Trie:
    """A prefix tree mapping words (sequences of hashable symbols) to values.

    Each node's ``count`` is maintained as the number of stored words in the
    subtree rooted at that node, so ``count(prefix)`` needs only one walk.
    A node whose ``value`` is ``None`` does not terminate a word.
    """

    # Stored instead of ``None`` so that a word inserted without a value is
    # still distinguishable from "no word ends at this node".
    _PRESENT = object()

    def __init__(self):
        self.root = Node(None)

    def _find(self, key):
        """Return the node reached by following *key* from the root, or None."""
        node = self.root
        for symbol in key:
            if symbol not in node:
                return None
            node = node[symbol]
        return node

    def __contains__(self, word):
        """Return True when *word* is stored in the trie."""
        node = self._find(word)
        return node is not None and node.value is not None

    def count(self, prefix):
        """Return the number of stored words that start with *prefix*."""
        node = self._find(prefix)
        return 0 if node is None else node.count

    def insert(self, word, value=None):
        """Store *word* with *value*; return True if the word is new.

        Re-inserting an existing word replaces its value and returns False.
        A word inserted with ``value=None`` is stored as a member whose
        ``get`` result is ``None``.
        """
        return self._insert(self.root, word, value)

    def _insert(self, node, word, value, index=0):
        """Insert ``word[index:]`` below *node*; return True if it was new.

        Implemented iteratively so that the length of *word* is not limited
        by the interpreter's recursion depth.
        """
        visited = [node]
        for symbol in (word[index:] if index else word):
            if symbol not in node:
                node.children[symbol] = Node()
            node = node[symbol]
            visited.append(node)
        is_new = node.value is None
        node.value = self._PRESENT if value is None else value
        if is_new:
            # Every node on the path gained exactly one word in its subtree.
            for seen in visited:
                seen.count += 1
        return is_new

    def get(self, word):
        """Return the value stored for *word*, or None if there is none."""
        node = self._find(word)
        if node is None or node.value is self._PRESENT:
            return None
        return node.value

    def all_words(self, prefix):
        """Return an iterable of stored words that start with *prefix*.

        The words are enumerated by ``IterateLeaves`` starting at the node
        reached by *prefix*; an empty iterator is returned when no such node
        exists.
        """
        node = self._find(prefix)
        if node is None:
            return iter([])
        return IterateLeaves(node, prefix)

    def remove(self, word):
        """Remove *word* from the trie; return True if it was present."""
        return self._remove(self.root, word)

    def _remove(self, node, word, idx=0):
        """Remove ``word[idx:]`` below *node*; return True if it was stored.

        Implemented iteratively so that the length of *word* is not limited
        by the interpreter's recursion depth.
        """
        origin = node
        steps = []  # (parent, symbol, child) for every edge on the path
        for symbol in (word[idx:] if idx else word):
            if symbol not in node:
                return False
            child = node[symbol]
            steps.append((node, symbol, child))
            node = child
        if node.value is None:
            return False
        node.value = None
        origin.count -= 1
        for _, _, child in steps:
            child.count -= 1
        # Counts never increase going down the path, so detaching the
        # topmost child whose subtree is now empty drops every empty node
        # below it as well.
        for parent, symbol, child in steps:
            if child.count == 0:
                parent.remove(symbol)
                break
        return True
if __name__ == '__main__':
    # Demo: index every suffix of a string by its start offset, then look a
    # couple of suffixes up again and print the offset and the suffix.
    suffixes = Trie()
    text = 'Banana'
    for start in range(len(text)):
        suffixes.insert(text[start:], start)
    for query in ('nana', 'ana'):
        start = suffixes.get(query)
        print(start, text[start:])
    # Disabled second demo, kept as an inert string literal: it expects a
    # word-list file next to the script.
    """
    trie = Trie()
    with open('google-10000-english-usa.txt') as f:
        for word in f:
            w = word.strip()
            trie.insert(w, len(w))
    #print('# word with prefix "info":', trie.count('info'))
    for infos in trie.all_words('info'):
        print(infos)
    print(trie.get('information'))
    print(trie.count(''))
    """
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment