
Commit 1434550

Incorporate Trie into fsm index calculation

committed Jun 6, 2024
1 parent ed44a47 · commit 1434550

File tree: 2 files changed, +267 −14

outlines/fsm/regex.py (+26 −14)
outlines/fsm/vocab_trie.py (+241, new file)

Diff for: outlines/fsm/regex.py (+26 −14)

@@ -28,6 +28,8 @@
 from numba.typed.typedobjectutils import _nonoptional
 from tqdm import tqdm
 
+from outlines.fsm.vocab_trie import VocabTrie
+
 if TYPE_CHECKING:
     from outlines.models.tokenizer import Tokenizer
 
@@ -649,29 +651,38 @@ def state_scan_tokens(
     fsm_initial: int,
     fsm_finals: Set[int],
     vocabulary: List[Tuple[str, Sequence[int]]],
-    vocabulary_transition_keys: List[Sequence[int]],
+    vocab_trie: VocabTrie,
     start_state: int,
 ) -> Set[Tuple[int, int]]:
     res = set()
 
-    for (token, token_ids), token_transition_keys in zip(
-        vocabulary, vocabulary_transition_keys
-    ):
+    # Initialize the stack with tokens having no prefixes
+    stack = numba.typed.List()
+    for token_transitions_seq in vocab_trie.get_children():
+        stack.append(token_transitions_seq)
+
+    # Process the tokens using the stack
+    while len(stack) > 0:
+        token_transition_seq = stack.pop()
         state_seq = _walk_fsm(
             fsm_transitions,
             fsm_initial,
             fsm_finals,
-            token_transition_keys,
+            token_transition_seq,
             start_state,
             False,
         )
 
-        if state_seq is not None and len(state_seq) < len(token_transition_keys):
+        if state_seq is not None and len(state_seq) < len(token_transition_seq):
             continue
 
-        for token_id in token_ids:
+        for token_id in vocab_trie.get_token_ids(token_transition_seq):
             res.add((token_id, state_seq[-1]))
 
+        # Add successors to the stack
+        for new_token in vocab_trie.get_children(token_transition_seq):
+            stack.append(new_token)
+
     return res
 
 
@@ -740,18 +751,19 @@ def create_fsm_index_end_to_end(
     seen: Set[int] = set()
     next_states = {fsm_info.initial}
 
+    vocabulary_transitions = get_vocabulary_transition_keys(
+        fsm_info.alphabet_symbol_mapping,
+        fsm_info.alphabet_anything_value,
+        vocabulary,
+    )
+    vocab_trie = VocabTrie(vocabulary_transitions, vocabulary)
+
     pbar = tqdm(
         total=len(set(fsm_info.transitions.values()))
         + 1,  # all transitions plus initial
         desc="Compiling FSM index for all state transitions",
     )
 
-    vocabulary_transition_keys = get_vocabulary_transition_keys(
-        fsm_info.alphabet_symbol_mapping,
-        fsm_info.alphabet_anything_value,
-        vocabulary,
-    )
-
    while next_states:
         start_state = next_states.pop()
 
@@ -762,7 +774,7 @@
             fsm_info.initial,
             fsm_info.finals,
             vocabulary,
-            vocabulary_transition_keys,
+            vocab_trie,
             start_state,
         )
Diff for: outlines/fsm/vocab_trie.py (new file, +241)

@@ -0,0 +1,241 @@
+import operator
+from typing import List, Optional, Sequence, Tuple
+
+import numpy as np
+from numba import njit, typed, types
+from numba.cpython.hashing import (
+    _Py_uhash_t,
+    _PyHASH_XXPRIME_1,
+    _PyHASH_XXPRIME_2,
+    _PyHASH_XXPRIME_5,
+    _PyHASH_XXROTATE,
+    process_return,
+)
+from numba.experimental import jitclass, structref
+from numba.extending import overload
+from numba.typed import Dict
+
+###########################
+# Dict With Int[:] Key Impl
+###########################
+
+
+# Register type
+@structref.register
+class IntArrayDictType(types.StructRef):
+    """
+    Represents a dictionary using int64[:] as keys,
+    intended for byte-level FSM representation with int64[:] transitions.
+    """
+
+    def preprocess_fields(self, fields):
+        return tuple(
+            (name, typ.dtype if isinstance(typ, types.TypeRef) else typ)
+            for name, typ in fields
+        )
+
+
+class IntArrayDict(structref.StructRefProxy):
+    """Python proxy"""
+
+    @property
+    def wrapped_dict(self):
+        return IntArrayDict_get_wrapped_dict(self)  # noqa: F821
+
+
+structref.define_proxy(IntArrayDict, IntArrayDictType, ["wrapped_dict"])
+
+
+@njit
+def hash_key(key):
+    """
+    XXH64 hash for int64[:] keys,
+    adapted from https://github.com/numba/numba/blob/556545/numba/cpython/hashing.py
+    """
+    acc = _PyHASH_XXPRIME_5
+    for i in range(key.shape[0]):
+        x = key[i]
+        lane = hash(x)
+        if lane == _Py_uhash_t(-1):
+            return -1
+        acc += lane * _PyHASH_XXPRIME_2
+        acc = _PyHASH_XXROTATE(acc)
+        acc *= _PyHASH_XXPRIME_1
+
+    acc += key.shape[0] ^ (_PyHASH_XXPRIME_5 ^ _Py_uhash_t(3527539))
+
+    if acc == _Py_uhash_t(-1):
+        return process_return(1546275796)
+
+    return process_return(acc)
+
+
+@overload(IntArrayDict)
+def custom_int_array_dict_constructor(value_type):
+    if isinstance(value_type, types.Type):
+
+        def impl(value_type):
+            wrapped_dictionary = Dict.empty(types.intp, value_type)
+            return IntArrayDict(wrapped_dictionary)
+
+        return impl
+
+
+@overload(operator.getitem)
+def ol_int_array_dict_getitem(inst, key):
+    if isinstance(inst, IntArrayDictType):
+
+        def impl(inst, key):
+            return inst.wrapped_dict[hash_key(key)]
+
+        return impl
+
+
+@overload(operator.setitem)
+def ol_int_array_dict_setitem(inst, key, value):
+    if isinstance(inst, IntArrayDictType):
+
+        def impl(inst, key, value):
+            inst.wrapped_dict[hash_key(key)] = value
+
+        return impl
+
+
+@overload(operator.contains)
+def ol_int_array_dict_contains(inst, key):
+    if isinstance(inst, IntArrayDictType):
+
+        def impl(inst, key):
+            return hash_key(key) in inst.wrapped_dict
+
+        return impl
+
+
+#################
+# Vocab Trie Impl
+#################
+
+nb_int64_array_type = types.int64[:]
+
+# use intp keys as that is the hash type,
+# but the true key type is nb_int64_array_type
+IntArrayToIntType = IntArrayDictType(
+    (("wrapped_dict", types.DictType(types.intp, types.int64)),)
+)
+IntArrayToIntArrayType = IntArrayDictType(
+    (("wrapped_dict", types.DictType(types.intp, nb_int64_array_type)),)
+)
+
+
+@jitclass(
+    [
+        ("token_to_token_key", IntArrayToIntType),
+        ("token_key_to_token", types.DictType(types.int64, nb_int64_array_type)),
+        (
+            "token_key_to_child_token_keys",
+            types.DictType(types.int64, nb_int64_array_type),
+        ),
+        ("token_to_token_ids", IntArrayToIntArrayType),
+    ],
+)
+class VocabTrie:
+    """
+    VocabTrie: class for efficient traversal of the vocabulary.
+
+    Bidirectional mapping between trie node ID and nb_int64_array_type token:
+    - token_to_token_key: Dict[nb_int64_array_type, int]
+    - token_key_to_token: Dict[int, nb_int64_array_type]
+
+    Allows retrieval of children in the trie:
+    - token_key_to_child_token_keys: Dict[int, int64[:]]
+
+    Allows retrieval of token_ids for a given token:
+    - token_to_token_ids: Dict[nb_int64_array_type, int64[:]]
+
+    Trie structure:
+    Only members of the vocabulary are included as nodes, no intermediates.
+    Structured to guarantee that recursive calls to get_children()
+    will return every token once, and only once.
+    Given a vocabulary of ["a", "ab", "abc", "ac", "ace", "apple"],
+    the children of "a" are "ab", "ac", "apple".
+    "abc" and "ace" are excluded because they have intermediate parents
+    in the vocabulary.
+    """
+
+    def __init__(
+        self,
+        all_token_transitions: List[Sequence[int]],
+        vocabulary: List[Tuple[str, Sequence[int]]],
+    ):
+        self.token_to_token_key = IntArrayDict(
+            typed.Dict.empty(types.intp, types.int64)
+        )
+        self.token_key_to_token = typed.Dict.empty(
+            key_type=types.int64, value_type=nb_int64_array_type
+        )
+        self.token_key_to_child_token_keys = typed.Dict.empty(
+            key_type=types.int64, value_type=nb_int64_array_type
+        )
+        self.token_to_token_ids = IntArrayDict(
+            typed.Dict.empty(types.intp, nb_int64_array_type)
+        )
+
+        self._insert(all_token_transitions, vocabulary)
+
+    def _insert(
+        self,
+        all_token_transitions: List[Sequence[int]],
+        vocabulary: List[Tuple[str, Sequence[int]]],
+    ) -> None:
+        # Initialize an empty array for the root token key to store child token keys
+        self.token_key_to_child_token_keys[-1] = np.empty((0,), types.int64)
+
+        # It's necessary to insert shorter transition sequences (prefixes) first
+        sorted_idx_transition_seq = sorted(
+            enumerate(all_token_transitions), key=lambda x: len(x[1])
+        )
+
+        for idx, token_transitions in sorted_idx_transition_seq:
+            token_ids = vocabulary[idx][1]
+            if token_transitions not in self.token_to_token_key:
+                # create bimapping between token and token_key (the token's trie node key)
+                self.token_to_token_key[token_transitions] = idx
+                self.token_key_to_token[idx] = token_transitions
+
+                # find parent token key
+                parent_token_key = -1  # root token
+                for i in range(len(token_transitions) - 1, -1, -1):
+                    prefix_token = token_transitions[:i]
+
+                    if prefix_token in self.token_to_token_key:
+                        parent_token_key = self.token_to_token_key[prefix_token]
+                        break
+
+                # map parent token to current token
+                self.token_key_to_child_token_keys[parent_token_key] = np.append(
+                    self.token_key_to_child_token_keys[parent_token_key],
+                    np.array([idx]),
+                )
+
+                # map current token to empty list of children
+                self.token_key_to_child_token_keys[idx] = np.empty((0,), types.int64)
+
+                # set the current token's token ids
+                self.token_to_token_ids[token_transitions] = token_ids
+
+            else:
+                # if it already exists, append to the current token's token ids
+                self.token_to_token_ids[token_transitions] = np.append(
+                    self.token_to_token_ids[token_transitions], token_ids
+                )
+
+    def get_children(self, token_transitions: Optional[Sequence[int]] = None):
+        """
+        Get the transition sequences of all children of the given token.
+        If token_transitions is None, get the root's children.
+        """
+        if token_transitions is None:
+            token_key = -1
+        else:
+            token_key = self.token_to_token_key[token_transitions]
+
+        child_token_keys = self.token_key_to_child_token_keys[token_key]
+
+        return [self.token_key_to_token[token_key] for token_key in child_token_keys]
+
+    def get_token_ids(self, token):
+        return self.token_to_token_ids[token]
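Two pieces interlock in this new file. The IntArrayDict half exists because numba's typed.Dict cannot key on arrays, so the structref stores each value under a scalar hash of the int64[:] key. A plain-Python sketch of that idea (ArrayKeyDict and its tobytes-based hash are illustrative stand-ins for the structref and the jitted XXH64 hash_key; like the original, it trusts the hash, so colliding keys would silently alias):

import numpy as np

class ArrayKeyDict:
    def __init__(self):
        self._d = {}

    @staticmethod
    def _key(arr: np.ndarray) -> int:
        # stand-in for the jitted XXH64 hash_key over the array's elements
        return hash(arr.tobytes())

    def __setitem__(self, arr, value):
        self._d[self._key(arr)] = value

    def __getitem__(self, arr):
        return self._d[self._key(arr)]

    def __contains__(self, arr):
        return self._key(arr) in self._d

d = ArrayKeyDict()
d[np.array([1, 2, 3], dtype=np.int64)] = "token"
print(d[np.array([1, 2, 3], dtype=np.int64)])  # token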

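The trie half hangs off the parent-selection rule in _insert: a token's parent is its longest strict prefix that is itself in the vocabulary, or the root if none exists, which is why inserting shorter sequences first matters. A string-based sketch of that rule reproducing the docstring's example (illustrative only; the real keys are the int64 transition arrays produced by get_vocabulary_transition_keys):

def parent(token: str, vocab: set) -> str:
    # longest strict prefix of `token` present in the vocabulary;
    # "" plays the role of the root key (-1 in the jitclass)
    for i in range(len(token) - 1, -1, -1):
        if token[:i] in vocab:
            return token[:i]
    return ""

vocab = {"a", "ab", "abc", "ac", "ace", "apple"}
for t in sorted(vocab, key=lambda t: (len(t), t)):  # shorter prefixes first
    print(repr(t), "->", repr(parent(t, vocab)))
# 'a' -> '', 'ab' -> 'a', 'ac' -> 'a', 'abc' -> 'ab', 'ace' -> 'ac', 'apple' -> 'a'

As the docstring states, the children of "a" come out as "ab", "ac", and "apple", while "abc" and "ace" attach to their intermediate parents "ab" and "ac".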