-from typing import TYPE_CHECKING, List, NewType, Protocol, Tuple
+from collections import defaultdict
+from copy import deepcopy
+from typing import TYPE_CHECKING, Dict, List, NewType, Protocol, Tuple

 import interegular
+import torch
 from lark import Lark

 # from outlines.fsm.parsing import PartialLark
@@ -22,6 +25,11 @@ def is_final_state(self, state: FSMState) -> bool:
         """Determine whether the current state of the FSM is a final state."""
         return state == self.final_state

+    def align_prompt_tokens(
+        self, token_ids: torch.Tensor, attention_masks: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        ...
+
     def allowed_token_ids(self, state: FSMState) -> List[int]:
         ...

@@ -37,13 +45,41 @@ class StopAtEosFSM(FSM):

     def __init__(self, tokenizer: "Tokenizer"):
         self.eos_token_id = tokenizer.eos_token_id
-        self.vocabulary = tokenizer.vocabulary.values()
+        self.vocabulary = tokenizer.vocabulary
+        self.tokenizer = tokenizer
+        self.states_to_token_maps = self.create_states_to_tokens_map()
+
+    def create_states_to_tokens_map(self) -> Dict[int, Dict[int, int]]:
+        """Create the states_to_tokens_map. All tokens from the starting state lead
+        back to it, except for the eos_token, which leads to the final state."""
+        return {
+            self.first_state: {
+                token_id: self.first_state
+                if token_id != self.eos_token_id
+                else self.final_state
+                for token_id in self.vocabulary.values()
+            }
+        }
+
+    def align_prompt_tokens(
+        self, token_ids: torch.Tensor, attention_masks: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Update the states_to_token_maps and return the aligned prompt tokens and attention masks."""
+        (
+            token_ids,
+            attention_masks,
+            self.states_to_token_maps,
+        ) = align_tokens_states_to_token_maps(
+            token_ids, attention_masks, self.vocabulary, self.states_to_token_maps
+        )
+        return token_ids, attention_masks

     def allowed_token_ids(self, state: FSMState) -> List[int]:
         """Generate a list of allowed tokens for the next step.

-        When in the initial state we allow every token to be generated.
         In the final state the only allowed token is `stop_token_id`.
+        Otherwise we allow the tokens for the valid transitions out of
+        the current state of the states_to_token_maps.

         Parameters
         ----------
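Note (illustration only, not part of the diff): a minimal sketch of the transition map that create_states_to_tokens_map builds, assuming a hypothetical three-token vocabulary; first_state is 0 and final_state is -1, as stated elsewhere in this file.

vocabulary = {"a": 0, "b": 1, "<eos>": 2}   # toy vocabulary, invented for the example
eos_token_id = 2
first_state, final_state = 0, -1
states_to_token_maps = {
    first_state: {
        token_id: first_state if token_id != eos_token_id else final_state
        for token_id in vocabulary.values()
    }
}
# Every token loops on the start state, except <eos>, which moves to the final state.
assert states_to_token_maps == {0: {0: 0, 1: 0, 2: -1}}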
@@ -57,14 +93,13 @@ def allowed_token_ids(self, state: FSMState) -> List[int]:
         """
         if self.is_final_state(state):
             return [self.eos_token_id]
-        return list(self.vocabulary)
+        return list(self.states_to_token_maps[state].keys())

     def next_state(self, state: FSMState, token_id: int) -> FSMState:
         """Update the state of the FSM.

-        The FSM stays in the initial state `0` unless the specified stop token
-        has been generated or the maximum number of tokens has been reached. In
-        which case the FSM moves to the final state `-1`.
+        The FSM transitions from one state to another through the
+        states_to_token_maps until the final state is reached.

         Parameters
         ----------
@@ -78,14 +113,14 @@ def next_state(self, state: FSMState, token_id: int) -> FSMState:
         The new state of the FSM.

         """
-        if token_id == self.eos_token_id:
+        if self.is_final_state(state):
             return self.final_state

-        return self.first_state
+        return FSMState(self.states_to_token_maps[state][token_id])

     def copy(self) -> "StopAtEosFSM":
         """Create a copy of the FSM."""
-        return self
+        return deepcopy(self)


 class RegexFSM(FSM):
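Note (illustration only, not part of the diff): a hypothetical driver loop showing how a generator could consume the FSM interface above; fsm and sample_from are placeholders, not names from this codebase.

state = FSMState(0)                          # first_state
generated = []
while not fsm.is_final_state(state):
    allowed = fsm.allowed_token_ids(state)   # token ids left unmasked at this step
    token_id = sample_from(allowed)          # placeholder for the actual sampling step
    generated.append(token_id)
    state = fsm.next_state(state, token_id)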
@@ -121,9 +156,22 @@ def create_states_mapping(
         self.states_to_token_maps, self.empty_token_ids = create_states_mapping(
             regex_string, tuple(sorted(tokenizer.vocabulary.items()))
         )
-        self.vocabulary = tokenizer.vocabulary.values()
+        self.vocabulary = tokenizer.vocabulary
         self.eos_token_id = tokenizer.eos_token_id

+    def align_prompt_tokens(
+        self, token_ids: torch.Tensor, attention_masks: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Update the states_to_token_maps and return the aligned prompt tokens and attention masks."""
+        (
+            token_ids,
+            attention_masks,
+            self.states_to_token_maps,
+        ) = align_tokens_states_to_token_maps(
+            token_ids, attention_masks, self.vocabulary, self.states_to_token_maps
+        )
+        return token_ids, attention_masks
+
     def allowed_token_ids(self, state: FSMState) -> List[int]:
         """Generate a list of allowed tokens for the next step.

@@ -184,7 +232,7 @@ def next_state(self, state: FSMState, token_id: int) -> FSMState:

     def copy(self) -> "RegexFSM":
         """Create a copy of the FSM."""
-        return self
+        return deepcopy(self)


 class CFGFSM(FSM):
@@ -218,6 +266,12 @@ def __init__(self, cfg_string: str, tokenizer):
         self.proposal_last: List[int] = []
         self.regex_fsm_last: RegexFSM

+    def align_prompt_tokens(
+        self, token_ids: torch.Tensor, attention_masks: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Not applicable to this type of FSM"""
+        return token_ids, attention_masks
+
     def allowed_token_ids(self, state: FSMState) -> List[int]:
         """Generate a list of allowed tokens for the next step.

@@ -333,3 +387,162 @@ def next_state(self, state: FSMState, token_id: int) -> FSMState:
     def copy(self) -> "CFGFSM":
         """Create a copy of the FSM."""
         return CFGFSM(self.cfg_string, self.tokenizer)
+
+
+def align_tokens_states_to_token_maps(
+    token_ids: torch.Tensor,
+    attention_masks: torch.Tensor,
+    vocabulary: Dict[str, int],
+    states_to_token_maps: Dict[int, Dict[int, int]],
+) -> Tuple[torch.Tensor, torch.Tensor, Dict[int, Dict[int, int]]]:
+    """Apply token alignment to the provided prompt tokens and attention masks given the
+    states_to_token_maps of an FSM. Return the updated tokens/masks as well as the updated
+    states_to_token_maps"""
+    prompt_token_ids = token_ids.tolist()
+    crossing_tokens = find_crossing_tokens(prompt_token_ids, vocabulary)
+    valid_crossing_tokens = get_crossing_tokens_target_states(
+        states_to_token_maps, crossing_tokens, prompt_token_ids, vocabulary
+    )
+    if not valid_crossing_tokens:
+        return token_ids, attention_masks, states_to_token_maps
+    (
+        states_to_token_maps,
+        number_cropped_tokens,
+    ) = add_crossing_tokens_states_to_tokens_map(
+        states_to_token_maps, prompt_token_ids, valid_crossing_tokens
+    )
+    return (
+        token_ids[:-number_cropped_tokens],
+        attention_masks[:-number_cropped_tokens],
+        states_to_token_maps,
+    )
+
+
+def find_crossing_tokens(
+    token_ids: List[int], vocabulary: Dict[str, int]
+) -> Dict[int, List[int]]:
+    """Find the tokens that could replace one or more tokens at the end of token_ids
+    while conserving the same initial text (and extending it by at least one character).
+    Return a dictionary with, for the indexes in the token_ids, the associated crossing tokens.
+    """
+    reversed_vocabulary = {value: key for key, value in vocabulary.items()}
+    len_token_ids = len(token_ids)
+    max_length_token_text = max(len(item) for item in vocabulary.keys())
+    characters_considered = ""
+    crossing_tokens_map = {}
+
+    for index, token_id in enumerate(reversed(token_ids)):
+        characters_considered = reversed_vocabulary[token_id] + characters_considered
+        if len(characters_considered) >= max_length_token_text:
+            break
+        crossing_token_ids = [
+            token_id
+            for text, token_id in vocabulary.items()
+            if text.startswith(characters_considered)
+            and len(text) > len(characters_considered)
+        ]
+        crossing_tokens_map[len_token_ids - index - 1] = crossing_token_ids
+
+    return crossing_tokens_map
+
+
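Note (illustration only, not part of the diff): a toy check of find_crossing_tokens with an invented vocabulary. The prompt "hello wo" ends with the token "wo", and "world" is the only vocabulary entry that starts with a trailing piece of the prompt and extends it.

vocabulary = {"h": 0, "e": 1, "l": 2, "o": 3, " ": 4, "w": 5, "r": 6, "d": 7,
              "wo": 8, "world": 9, "<eos>": 10}
prompt_ids = [0, 1, 2, 2, 3, 4, 8]   # "hello wo", ending with the token "wo"
# Position 6 holds "wo"; "world" (id 9) starts with "wo" and extends it.
assert find_crossing_tokens(prompt_ids, vocabulary) == {6: [9], 5: [], 4: []}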
+def get_crossing_tokens_target_states(
+    states_to_tokens_map: Dict[int, Dict[int, int]],
+    crossing_tokens: Dict[int, List[int]],
+    prompt_token_ids: List[int],
+    vocabulary: Dict[str, int],
+) -> Dict[int, Dict[int, int]]:
+    """For each crossing token associated with an index, check that the characters after the boundary
+    match the states_to_tokens_map and find the state it would lead to. Return a dict with, for each
+    provided index, the associated valid tokens with the state they would lead to.
+    """
+    reversed_vocabulary = {value: key for key, value in vocabulary.items()}
+    prompt_token_texts = [
+        reversed_vocabulary[token_id] for token_id in prompt_token_ids
+    ]
+
+    valid_crossing_tokens: Dict[int, Dict[int, int]] = defaultdict(dict)
+    for pos, tokens in crossing_tokens.items():
+        for token in tokens:
+            is_valid = True
+            characters = reversed_vocabulary[token]
+            characters_before_border = "".join(prompt_token_texts[pos:])
+            characters_after_border = characters[len(characters_before_border):]
+            state = 0
+            for char in characters_after_border:
+                char_token = vocabulary.get(char)
+                try:
+                    state = states_to_tokens_map[state][char_token]  # type: ignore
+                except KeyError:
+                    is_valid = False
+                    break
+            if is_valid:
+                valid_crossing_tokens[pos][token] = state
+
+    return valid_crossing_tokens
+
+
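Note (illustration only, not part of the diff; reuses vocabulary and prompt_ids from the sketch above): an FSM that accepts the continuation "rld" followed by <eos> (start state 0, final state -1). The crossing token "world" spells out exactly "rld" after the boundary, so it is kept along with the state it would reach.

fsm_map = {0: {6: 1}, 1: {2: 2}, 2: {7: 3}, 3: {10: -1}}   # "r" -> "l" -> "d" -> <eos>
crossing = {6: [9], 5: [], 4: []}
assert get_crossing_tokens_target_states(fsm_map, crossing, prompt_ids, vocabulary) == {6: {9: 3}}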
+def add_crossing_tokens_states_to_tokens_map(
+    states_to_tokens_map: Dict[int, Dict[int, int]],
+    prompt_token_ids: List[int],
+    crossing_tokens_map: Dict[int, Dict[int, int]],
+) -> Tuple[Dict[int, Dict[int, int]], int]:
+    """Modify the states_to_tokens_map to account for the crossing tokens. This operation modifies
+    the starting state of the fsm as we would include some characters at the end of the prompt in
+    the states_to_tokens_map.
+    Attention! The starting state of the states_to_tokens_map provided must be 0.
+    Return the updated states_to_tokens_map and the number of cropped tokens/additional states.
+    """
+    if not crossing_tokens_map:
+        return states_to_tokens_map, 0
+    first_crossing_token_pos = min(
+        [key for key, value in crossing_tokens_map.items() if value]
+    )
+    number_additional_states = len(prompt_token_ids) - first_crossing_token_pos
+    highest_state = max(
+        max(states_to_tokens_map.keys()),
+        max(max(items.values()) for items in states_to_tokens_map.values()),
+    )
+
+    for i in range(number_additional_states):
+        # add the token that was originally part of the prompt
+        if i == number_additional_states - 1:
+            states_to_tokens_map[highest_state + 1 + i] = {
+                prompt_token_ids[first_crossing_token_pos + i]: 0
+            }
+        else:
+            states_to_tokens_map[highest_state + 1 + i] = {
+                prompt_token_ids[first_crossing_token_pos + i]: highest_state + 2 + i
+            }
+        # add the crossing tokens
+        crossing_tokens = crossing_tokens_map.get(first_crossing_token_pos + i)
+        if crossing_tokens:
+            for token, target_state in crossing_tokens.items():
+                states_to_tokens_map[highest_state + 1 + i][token] = target_state
+
+    # set the id of our new initial state to 0
+    states_to_tokens_map = swap_state_ids_states_to_tokens_map(
+        states_to_tokens_map, highest_state + 1, 0
+    )
+    return states_to_tokens_map, number_additional_states
+
+
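Note (illustration only, not part of the diff; same toy data as above): the map gains one state for the cropped prompt token "wo" and is then renumbered so that the new start state is 0. The input map is modified in place and also returned.

new_map, cropped = add_crossing_tokens_states_to_tokens_map(fsm_map, prompt_ids, {6: {9: 3}})
assert cropped == 1                    # one prompt token ("wo") is cropped
assert new_map[0] == {8: 4, 9: 3}      # re-emit "wo" (-> state 4) or use "world" (-> state 3)
assert new_map[4] == {6: 1}            # state 4 behaves like the old start state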
+def swap_state_ids_states_to_tokens_map(
+    states_to_tokens_map: Dict[int, Dict[int, int]],
+    first_state_id: int,
+    second_state_id: int,
+) -> Dict[int, Dict[int, int]]:
+    """Swap the ids of two states of the states_to_tokens_map while conserving all transitions"""
+    first_state_transitions = states_to_tokens_map.pop(first_state_id)
+    second_state_transitions = states_to_tokens_map.pop(second_state_id)
+    states_to_tokens_map[first_state_id] = second_state_transitions
+    states_to_tokens_map[second_state_id] = first_state_transitions
+
+    for transitions in states_to_tokens_map.values():
+        for token, target_state_id in list(transitions.items()):
+            if target_state_id == first_state_id:
+                transitions[token] = second_state_id
+            elif target_state_id == second_state_id:
+                transitions[token] = first_state_id
+
+    return states_to_tokens_map
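Note (illustration only, not part of the diff): an end-to-end sketch of align_tokens_states_to_token_maps with the same invented vocabulary, starting from tensors as the FSM classes above do. The trailing "wo" is cropped from the prompt, and the start state of the returned map lets the model either re-emit it or pick the crossing token "world".

import torch

vocabulary = {"h": 0, "e": 1, "l": 2, "o": 3, " ": 4, "w": 5, "r": 6, "d": 7,
              "wo": 8, "world": 9, "<eos>": 10}
prompt = torch.tensor([0, 1, 2, 2, 3, 4, 8])               # "hello wo"
masks = torch.ones_like(prompt)
fsm_map = {0: {6: 1}, 1: {2: 2}, 2: {7: 3}, 3: {10: -1}}   # accepts "rld" then <eos>

new_prompt, new_masks, new_map = align_tokens_states_to_token_maps(
    prompt, masks, vocabulary, fsm_map
)
assert new_prompt.tolist() == [0, 1, 2, 2, 3, 4]   # the trailing "wo" was cropped
assert new_map[0] == {8: 4, 9: 3}                  # start state offers "wo" or "world"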