-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathradix_tree.py
120 lines (99 loc) · 3.7 KB
/
radix_tree.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
# This radix tree implementation is based on the one in SGLang (https://arxiv.org/abs/2312.07104).
from collections import defaultdict
class TreeNode:
    """A single node of the radix tree.

    A freshly built node is detached: it has no parent, no value, and an
    empty ``children`` mapping whose default factory is ``TreeNode``.
    """

    def __init__(self):
        # Edge label -> child node.  Because this is a defaultdict, looking
        # up a label that is absent creates (and stores) an empty node.
        self.children = defaultdict(TreeNode)
        # Both stay unset until the owner links the node into a tree.
        self.parent = self.value = None
def match(key, seq):
    """Return how many leading elements ``key`` and ``seq`` have in common.

    Elements are compared pairwise from the start; counting stops at the
    first unequal pair or when the shorter input runs out.
    """
    matched = 0
    for position, (expected, actual) in enumerate(zip(key, seq), start=1):
        if expected != actual:
            break
        # Pair number `position` (1-based) agreed, so the shared prefix is
        # at least that long.
        matched = position
    return matched
class RadixCache:
    """Radix (compressed prefix) tree keyed by token sequences.

    Each edge is labelled with a non-empty slice of an inserted key, and the
    node it leads to stores the corresponding slice of the inserted value.
    Keys that share a prefix share the nodes that cover that prefix.
    Children of one node never share a first element, so at most one child
    can match any given key.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Drop every entry and start over with an empty root node."""
        self.root_node = TreeNode()
        self.root_node.value = []

    def match_prefix(self, key):
        """Return the stored values along the longest cached prefix of ``key``.

        Args:
            key: Sliceable sequence to look up.

        Returns:
            A list with one entry per matched node, in root-to-leaf order.
            Each entry is that node's value; the entries are not
            concatenated.  An empty list means nothing matched.

        Note:
            Lookup is not read-only.  When the match ends partway along an
            edge, that node is split so the matched part gets its own node.
        """
        value = []
        # One-element list so the helper can report the deepest matched node
        # by assignment; it is tracked but not part of the return value.
        last_node = [self.root_node]
        self._match_prefix_helper(self.root_node, key, value, last_node)
        return value

    def insert(self, key, value=None):
        """Insert ``key`` and return how much of it was already cached.

        Args:
            key: Sliceable sequence.  Lists are converted to tuples, because
                edge labels are used as dictionary keys and must be hashable.
            value: Sequence sliced at the same positions as ``key``.
                Defaults to a list copy of ``key``.

        Returns:
            The number of leading elements of ``key`` that were already
            present in the tree before this call.
        """
        if isinstance(key, list):
            key = tuple(key)
        if value is None:
            value = list(key)
        return self._insert_helper(self.root_node, key, value)

    def pretty_print(self):
        """Print one line per edge, followed by the total stored size."""
        self._print_helper(self.root_node, 0)
        print(f"#tokens: {self.total_size()}")

    def total_size(self):
        """Return the summed length of the values stored on all nodes."""
        return self._total_size_helper(self.root_node)

    ##### Internal Helper Functions #####

    def _match_prefix_helper(self, node, key, value, last_node):
        """Walk down from ``node`` while ``key`` keeps matching an edge.

        Matched node values are appended to ``value`` and the deepest
        matched node is written to ``last_node[0]``.
        """
        for c_key, child in node.children.items():
            prefix_len = match(c_key, key)
            if prefix_len != 0:
                if prefix_len < len(c_key):
                    # Only part of the edge matched: split it so the matched
                    # part is a node of its own, then stop descending.
                    new_node = self._split_node(c_key, child, prefix_len)
                    value.append(new_node.value)
                    last_node[0] = new_node
                else:
                    # The whole edge matched: continue with the rest of key.
                    value.append(child.value)
                    last_node[0] = child
                    self._match_prefix_helper(
                        child, key[prefix_len:], value, last_node
                    )
                # The children mapping may have been modified by the split,
                # so the iteration must not continue past this point.
                break

    def _split_node(self, key, child, split_len):
        """Split the edge ``key`` leading to ``child`` after ``split_len`` items.

        A new intermediate node takes over ``key[:split_len]`` and the first
        ``split_len`` items of the value; ``child`` keeps the remainder and
        becomes the new node's only child.  Callers in this class always pass
        ``0 < split_len < len(key)``.

        Returns:
            The newly created intermediate node.
        """
        # Resulting shape: parent -> new_node -> child
        new_node = TreeNode()
        # Fill the mapping TreeNode created instead of replacing it, so every
        # node keeps the same container type for its children.
        new_node.children[key[split_len:]] = child
        new_node.parent = child.parent
        new_node.value = child.value[:split_len]
        child.parent = new_node
        child.value = child.value[split_len:]
        new_node.parent.children[key[:split_len]] = new_node
        del new_node.parent.children[key]
        return new_node

    def _insert_helper(self, node, key, value):
        """Insert ``key``/``value`` below ``node``.

        Returns:
            The number of leading elements of ``key`` already present below
            ``node``.
        """
        for c_key, child in node.children.items():
            prefix_len = match(c_key, key)

            if prefix_len == len(c_key):
                if prefix_len == len(key):
                    # The key ends exactly on this node: nothing to add.
                    return prefix_len
                # The whole edge matched: insert the remainder further down.
                key = key[prefix_len:]
                value = value[prefix_len:]
                return prefix_len + self._insert_helper(child, key, value)

            if prefix_len:
                # Partial edge match: split the edge, then insert whatever is
                # left of the key beneath the new intermediate node.
                new_node = self._split_node(c_key, child, prefix_len)
                return prefix_len + self._insert_helper(
                    new_node, key[prefix_len:], value[prefix_len:]
                )

        # No child shares a prefix with the key: attach what is left (if
        # anything) as a new leaf.
        if len(key):
            new_node = TreeNode()
            new_node.parent = node
            new_node.value = value
            node.children[key] = new_node
        return 0

    def _print_helper(self, node, indent):
        """Print the edges below ``node``, indenting two spaces per level."""
        for key, child in node.children.items():
            # Only the first 100 elements of a label are shown.
            print(" " * indent, len(key), key[:100])
            self._print_helper(child, indent=indent + 2)

    def _total_size_helper(self, node):
        """Return the summed value length of ``node`` and its descendants."""
        x = len(node.value)
        for child in node.children.values():
            x += self._total_size_helper(child)
        return x
if __name__ == "__main__":
    # Small demo: build a tree from three keys, dump its structure, then
    # look up a key that shares a prefix with the stored ones.
    tree = RadixCache()
    for demo_key in ((1, 2, 3), (2, 5, 1), (1, 2, 7, 9)):
        tree.insert(demo_key)
    tree.pretty_print()
    lookup_result = tree.match_prefix((1, 2, 3, 8, 2))
    print(lookup_result)