forked from clovaai/subword-qac
-
Notifications
You must be signed in to change notification settings - Fork 0
/
trie.py
73 lines (62 loc) · 2.11 KB
/
trie.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
"""
Copyright (c) 2019-present NAVER Corp.
MIT License
"""
class Node:
    """One trie vertex: outgoing edges keyed by symbol plus a terminal count."""

    def __init__(self):
        # How many inserted strings end exactly at this node (Trie.add and
        # Trie.add_multiple increment it).
        self.count = 0
        # Maps the next character to the child Node reached through it.
        self.children = {}

    def child(self, char, create=True):
        """Return the child reached via ``char``.

        If no such child exists, a new empty Node is stored and returned when
        ``create`` is true; otherwise None is returned and this node is left
        unchanged.
        """
        existing = self.children.get(char)
        if existing is not None:
            return existing
        if not create:
            return None
        fresh = self.children[char] = Node()
        return fresh
class Trie(object):
    """Character trie with frequency counts and most-popular-completion lookup.

    Each node's ``count`` is the number of times the exact string spelled by
    the root-to-node path was inserted; ``get_mpc`` ranks the completions of a
    prefix by that count.
    """

    def __init__(self, data=None):
        """Create the trie, optionally bulk-loading the strings in ``data``.

        Args:
            data: optional iterable of strings. Duplicates are allowed; each
                occurrence adds one to that string's count. When given and
                truthy, a sorted copy is kept in ``self.data``.
        """
        self.root = Node()
        # Total number of nodes, root included; kept up to date by _child().
        self.n_nodes = 1
        if data:
            self.data = sorted(data)
            # Measure the sorted copy: ``data`` may be a one-shot iterator
            # (e.g. a generator) that has no len() and is now exhausted.
            self.add_multiple(self.root, 0, 0, len(self.data))

    def _child(self, node, char, create=True):
        """Return ``node.child(char, create)``, counting any node it creates."""
        if create and char not in node.children:
            self.n_nodes += 1
        return node.child(char, create)

    @staticmethod
    def _invalidate(node):
        """Drop any cached most-popular-completion results stored on ``node``."""
        state = vars(node)
        state.pop('mpc', None)
        state.pop('_mpc_cache', None)

    def find(self, string, create=True):
        """Return the node spelled by ``string``.

        Missing nodes along the path are created when ``create`` is true;
        otherwise None is returned as soon as the path breaks off.
        """
        node = self.root
        for x in string:
            node = self._child(node, x, create)
            if node is None:
                return None
        return node

    def add(self, string):
        """Insert one occurrence of ``string``.

        Every node on the root-to-``string`` path has ``string`` in its
        subtree, so any completion list cached on those nodes may now be
        wrong; those caches are discarded. Nodes off the path are unaffected.
        """
        node = self.root
        self._invalidate(node)
        for x in string:
            node = self._child(node, x)
            self._invalidate(node)
        node.count += 1

    def add_multiple(self, node, depth, l, r):
        """Bulk-insert the sorted strings ``self.data[l:r]`` below ``node``.

        Called from ``__init__`` on a fresh trie. All strings in the range
        share their first ``depth`` characters (the path to ``node``). Because
        ``self.data`` is sorted, strings ending exactly at ``node`` come first,
        followed by contiguous runs sharing the character at index ``depth``;
        each run is recursed into under the matching child.
        """
        data = self.data
        i = l
        # A string equal to the shared prefix sorts before all its extensions.
        while i < r and len(data[i]) == depth:
            i += 1
        node.count += i - l
        while i < r:
            c = data[i][depth]
            j = i + 1
            while j < r and data[j][depth] == c:
                j += 1
            self.add_multiple(self._child(node, c), depth + 1, i, j)
            i = j

    def traverse(self, node, prefix, n_candidates, min_freq):
        """Compute ``node.mpc`` for the subtree rooted at ``node``.

        ``node.mpc`` becomes a list of ``(count, string)`` pairs, largest
        first, holding at most ``n_candidates`` strings from the subtree whose
        count is at least ``min_freq``. ``prefix`` must be the string spelled
        by the path to ``node``.

        Results are memoised per ``(n_candidates, min_freq)`` in
        ``node._mpc_cache``. A later call with different parameters is
        therefore not served a list computed for other ones. ``add`` clears
        the memo along the path it modifies. The memo gains one entry per
        distinct parameter pair queried.
        """
        key = (n_candidates, min_freq)
        cache = getattr(node, '_mpc_cache', None)
        if cache is None:
            cache = node._mpc_cache = {}
        cached = cache.get(key)
        if cached is not None:
            node.mpc = cached
            return
        mpc = [(node.count, prefix)] if node.count >= min_freq else []
        for c, child in node.children.items():
            self.traverse(child, prefix + c, n_candidates, min_freq)
            # The overall top-n lies within the union of each child's top-n,
            # so merging the children's truncated lists is sufficient.
            mpc.extend(child.mpc)
        node.mpc = cache[key] = sorted(mpc, reverse=True)[:n_candidates]

    def get_mpc(self, prefix, n_candidates=10, min_freq=1):
        """Return up to ``n_candidates`` most frequent completions of ``prefix``.

        Completions are stored strings starting with ``prefix`` (including
        ``prefix`` itself) whose count is at least ``min_freq``. They are
        ordered by count descending, with ties broken by string descending.
        Returns [] when the trie has no path for ``prefix``.
        """
        node = self.find(prefix, False)
        if node is None:
            return []
        self.traverse(node, prefix, n_candidates, min_freq)
        return [completion for _, completion in node.mpc]