@@ -1,7 +1,6 @@
 from copy import deepcopy
-from typing import TYPE_CHECKING, List, NewType, Protocol
+from typing import TYPE_CHECKING, Dict, List, NewType, Protocol, Tuple

-import cloudpickle
 import interegular
 from lark import Lark

@@ -15,6 +14,9 @@


 class FSM(Protocol):
+    def align_prompt_tokens(self, prompt: str) -> str:
+        ...
+
     def allowed_token_ids(self, state: FSMState) -> List[int]:
         ...

@@ -39,8 +41,23 @@ class StopAtTokenFSM(FSM):

     def __init__(self, tokenizer: "Tokenizer", stop_token_id: int):
         self.stop_token_id = stop_token_id
-        self.vocabulary = tokenizer.vocabulary.values()
-        self.final_states = {1}
+        self.tokenizer = tokenizer
+        self.vocabulary = tokenizer.vocabulary
+        self.final_states = {2}
+        self.valid_alignment_tokens: List[int] = []
+
+    def align_prompt_tokens(self, prompt: str) -> str:
+        """Remove the last token from the prompt and set the value of self.valid_alignment_tokens"""
+        token_ids, _ = self.tokenizer.encode(prompt)
+        last_token_id = int(token_ids[0][-1])
+        last_token_text = self.tokenizer.decode([last_token_id])[0]
+        # select the tokens that start with the text removed from the prompt
+        self.valid_alignment_tokens = [
+            token
+            for text, token in self.vocabulary.items()
+            if text.startswith(last_token_text)
+        ]
+        return prompt[: -len(last_token_text)]

     def allowed_token_ids(self, state: FSMState) -> List[int]:
         """Generate a list of allowed tokens for the next step.
@@ -59,7 +76,9 @@ def allowed_token_ids(self, state: FSMState) -> List[int]:

         """
         if state == 0:
-            return list(self.vocabulary)
+            return self.valid_alignment_tokens
+        elif state == 1:
+            return list(self.vocabulary.values())
         else:
             return [self.stop_token_id]

@@ -83,17 +102,17 @@ def next_state(self, state: FSMState, token_id: int) -> FSMState:

         """
         if token_id == self.stop_token_id:
-            return FSMState(1)
+            return FSMState(2)

-        return FSMState(0)
+        return FSMState(1)

     def is_final_state(self, state: FSMState) -> bool:
         """Determine whether the current state of the FSM is a final state."""
         return state in self.final_states

     def copy(self) -> "StopAtTokenFSM":
         """Create a copy of the FSM."""
-        return self
+        return deepcopy(self)


 class RegexFSM(FSM):
@@ -122,41 +141,61 @@ def __init__(self, regex_string: str, tokenizer: "Tokenizer"):
             -1
         }  # Include the EOS token in final states
         self.tokenizer = tokenizer
-        self.vocabulary = tokenizer.vocabulary.values()
+        self.vocabulary = tokenizer.vocabulary
         self.end_token_id = tokenizer.eos_token_id

     def align_prompt_tokens(self, prompt: str) -> str:
         """Remove the last token from the prompt and update the states_to_token_maps accordingly"""
         token_ids, _ = self.tokenizer.encode(prompt)
         last_token_id = int(token_ids[0][-1])
         last_token_text = self.tokenizer.decode([last_token_id])[0]
-        vocabulary = {
-            self.tokenizer.decode([token_id])[0]: token_id
-            for token_id in range(len(self.vocabulary))
-        }
-        starting_state_tokens = {
-            self.tokenizer.decode([token_id])[0]: self.states_to_token_maps[0][token_id]
-            for token_id in self.states_to_token_maps[0]
-        }
-        # select the tokens that start with the text removed from the prompt and whose text after the
-        # initial prompt corresponds to that of one of the allowed tokens of the starting state
-        possible_tokens = {
-            vocabulary[token_text]: starting_state_tokens[token_text[len(last_token_text):]]
-            for token_text in vocabulary
-            if (
-                token_text.startswith(last_token_text)
-                and starting_state_tokens.get(token_text[len(last_token_text):])
-            )
+        last_token_length = len(last_token_text)
+        # select the tokens that start with the text removed from the prompt
+        crossing_tokens = {
+            token: text
+            for text, token in self.vocabulary.items()
+            if text.startswith(last_token_text)
         }
+        # keep only the tokens whose text after the boundary matches the fsm
+        valid_tokens_states = self.find_valid_crossing_tokens(
+            crossing_tokens, last_token_length
+        )
         # update the states_to_token_maps in the following manner:
         # the value of the starting state is assigned to a new state, the starting state is now the
-        # possible_tokens found above + the last_token we removed (that leads to the new state)
-        additional_state_id = max(list(self.states_to_token_maps.keys()) + list(self.final_states)) + 1
+        # valid_tokens_states found above
+        additional_state_id = (
+            max(list(self.states_to_token_maps.keys()) + list(self.final_states)) + 1
+        )
         self.states_to_token_maps[additional_state_id] = self.states_to_token_maps[0]
-        self.states_to_token_maps[0] = {**possible_tokens, last_token_id: additional_state_id}
-
-        return prompt[:-len(last_token_text)]
-
+        self.states_to_token_maps[0] = {}
+        for token, state in valid_tokens_states:
+            if state == 0:
+                self.states_to_token_maps[0][token] = additional_state_id
+            else:
+                self.states_to_token_maps[0][token] = state
+        return prompt[: -len(last_token_text)]
+
+    def find_valid_crossing_tokens(
+        self, crossing_tokens: Dict[int, str], last_token_length: int
+    ) -> List[Tuple[int, int]]:
+        """For each crossing token, check that the characters after the boundary match the FSM
+        and find the state it would lead to. Return the valid tokens with the associated state
+        """
+        valid_tokens = []
+        for token, text in crossing_tokens.items():
+            is_valid = True
+            crossing_text = text[last_token_length:]
+            state = 0
+            for char in crossing_text:
+                char_token = self.vocabulary.get(char)
+                try:
+                    state = self.states_to_token_maps[state][char_token]  # type: ignore
+                except KeyError:
+                    is_valid = False
+                    break
+            if is_valid:
+                valid_tokens.append((token, state))
+        return valid_tokens

     def allowed_token_ids(self, state: FSMState) -> List[int]:
         """Generate a list of allowed tokens for the next step.
@@ -222,12 +261,7 @@ def is_final_state(self, state: FSMState) -> bool:

     def copy(self) -> "RegexFSM":
         """Create a copy of the FSM."""
-        # temporary solution to the problem of unpickleable dict_values
-        self.vocabulary = cloudpickle.dumps(self.vocabulary)
-        copy = deepcopy(self)
-        self.vocabulary = cloudpickle.loads(self.vocabulary)
-        copy.vocabulary = cloudpickle.loads(copy.vocabulary)
-        return copy
+        return deepcopy(self)


 class CFGFSM(FSM):
@@ -257,6 +291,10 @@ def __init__(self, cfg_string: str, tokenizer: "Tokenizer"):
         self.done = False
         self.regex_fsm: RegexFSM

+    def align_prompt_tokens(self, prompt: str) -> str:
+        """Not implemented for CFGFSM"""
+        return prompt
+
     def _set_next_regex_fsm(self) -> None:
         """Use the CFG incremental parser to set the next regex FSM.

@@ -278,7 +316,6 @@ def _set_next_regex_fsm(self) -> None:
            self.allow_eos = True
            options.add("")
            assert len(options) > 1
-
        regex_string = r"(" + r"|".join([r"(" + x + r")" for x in options]) + r")"
        self.regex_fsm = RegexFSM(regex_string, self.tokenizer)
        self.reset_state = True
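
For context, the change implements token alignment at the prompt boundary: the last prompt token may be a strict prefix of a longer vocabulary token, and starting generation exactly at the boundary would rule those longer tokens out. The diff trims the last prompt token and makes the FSM's new start state (state 0) accept only tokens that extend the trimmed text. Below is a minimal sketch of the flow, assuming the `StopAtTokenFSM` and `FSMState` from this diff are importable; the toy tokenizer is invented for illustration and is not the library's `Tokenizer`.

```python
# Sketch only: assumes StopAtTokenFSM and FSMState from the patched module
# are in scope. The tokenizer below is a made-up stand-in.
from typing import List, Tuple


class ToyTokenizer:
    vocabulary = {"The": 0, " answer": 1, " is": 2, " is:": 3, "<eos>": 4}
    eos_token_id = 4

    def encode(self, text: str) -> Tuple[List[List[int]], None]:
        # Greedy longest-match segmentation, enough for this demo.
        ids: List[int] = []
        while text:
            match = max((t for t in self.vocabulary if text.startswith(t)), key=len)
            ids.append(self.vocabulary[match])
            text = text[len(match):]
        return [ids], None

    def decode(self, ids: List[int]) -> List[str]:
        inverse = {v: k for k, v in self.vocabulary.items()}
        return ["".join(inverse[i] for i in ids)]


tokenizer = ToyTokenizer()
fsm = StopAtTokenFSM(tokenizer, stop_token_id=tokenizer.eos_token_id)

# "The answer is" ends with the token " is", a strict prefix of " is:".
# Alignment trims " is" and restricts the first step to tokens extending it.
aligned = fsm.align_prompt_tokens("The answer is")
print(aligned)                               # "The answer"
print(fsm.allowed_token_ids(FSMState(0)))    # [2, 3], i.e. " is" and " is:"

state = fsm.next_state(FSMState(0), 3)       # cross the boundary with " is:"
print(fsm.allowed_token_ids(state))          # state 1: the full vocabulary
print(fsm.is_final_state(fsm.next_state(state, tokenizer.eos_token_id)))  # True
```

This also shows why the state numbering shifted: state 0 is now the alignment step, state 1 is unconstrained generation, and state 2 (the new final state) is reached on the stop token.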
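The `RegexFSM` path is subtler: a token crossing the boundary is kept only if its post-boundary characters can be consumed by the regex FSM one character at a time, which is what `find_valid_crossing_tokens` checks. The following self-contained toy (vocabulary and transition map both hand-built for illustration, not taken from the diff) mirrors that walk; like the diff, it resolves each post-boundary character through the vocabulary, so it implicitly relies on every such character existing as a single-character token.

```python
from typing import Dict, List, Tuple

# Invented vocabulary (text -> token id) and a hand-built transition map for
# the regex r"\d+": states 0 and 1 both accept a digit and land in state 1.
vocabulary: Dict[str, int] = {"4": 1, "2": 2, " 4": 3, " 42": 4, " x": 5}
states_to_token_maps: Dict[int, Dict[int, int]] = {0: {1: 1, 2: 1}, 1: {1: 1, 2: 1}}

last_token_text = " "  # text trimmed from the prompt boundary
crossing_tokens = {
    token: text
    for text, token in vocabulary.items()
    if text.startswith(last_token_text)
}  # {3: " 4", 4: " 42", 5: " x"}

valid: List[Tuple[int, int]] = []
for token, text in crossing_tokens.items():
    state, ok = 0, True
    for char in text[len(last_token_text):]:
        char_token = vocabulary.get(char)
        # Mirror the KeyError handling in the diff: an unknown character or a
        # missing transition invalidates the whole crossing token.
        if char_token is None or char_token not in states_to_token_maps[state]:
            ok = False
            break
        state = states_to_token_maps[state][char_token]
    if ok:
        valid.append((token, state))

print(valid)  # [(3, 1), (4, 1)]: " 4" and " 42" match the regex; " x" does not
```

The surviving tokens are then wired into a rebuilt start state, with tokens that consume no regex characters (state 0) redirected to the saved copy of the original start state.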