+from collections import defaultdict
+from copy import copy, deepcopy
 from dataclasses import dataclass
 from typing import (
     TYPE_CHECKING,
@@ -69,6 +71,9 @@ class Guide(Protocol):

     """

+    start_state: int = 0
+    final_state: int = -1
+
     def get_next_instruction(self, state: int) -> Instruction:
         ...

@@ -82,11 +87,39 @@ def copy(self) -> "Guide":
         ...


-class StopAtEOSGuide(Guide):
-    """Guide to generate tokens until the EOS token has been generated."""
+class TokenHealerMixin:
+    """Mixin that adds the token alignment feature to a Guide"""

-    final_state = 1
-    start_state = 0
+    states_to_token_maps: Dict[int, Dict[int, int]]
+    tokenizer: "Tokenizer"
+
+    def align_prompt_tokens(self, prompt: str) -> str:
+        """Update the states_to_token_maps and return the aligned prompt"""
+        token_ids, _ = self.tokenizer.encode(prompt)
+        (
+            aligned_token_ids,
+            aligned_states_to_token_maps,
+        ) = align_tokens_states_to_token_maps(
+            token_ids.tolist()[0],
+            self.tokenizer.vocabulary,
+            deepcopy(self.states_to_token_maps),
+        )
+        aligned_prompt = self.tokenizer.decode([aligned_token_ids])[0]
+        # some models do not accept an empty string as a prompt
+        # if token alignment would remove all tokens, do not apply it
+        if not aligned_prompt:
+            return prompt
+        self.states_to_token_maps = aligned_states_to_token_maps
+        if hasattr(self, "_cache_state_to_token_tensor"):
+            self._cache_state_to_token_tensor()
+        # remove leading whitespace if added by the tokenizer
+        if aligned_prompt[0] == " " and prompt[0] != " ":
+            return aligned_prompt[1:]
+        return aligned_prompt
+
+
+class StopAtEOSGuide(Guide, TokenHealerMixin):
+    """Guide to generate tokens until the EOS token has been generated."""

     def __init__(self, tokenizer: "Tokenizer"):
         """Initialize the generation guide.
@@ -95,25 +128,37 @@ def __init__(self, tokenizer: "Tokenizer"):
             The logit generator used to generate the next token.

         """
-        self.eos_token_id = tokenizer.eos_token_id
-        self.vocabulary = tokenizer.vocabulary.values()
+        self.tokenizer = tokenizer
+        self.states_to_token_maps = self.create_states_to_tokens_map()
+
+    def create_states_to_tokens_map(self) -> Dict[int, Dict[int, int]]:
+        """Create the states_to_tokens_map. All tokens lead to the starting
+        state, except for the eos_token that leads to the final state."""
+        return {
+            self.start_state: {
+                token_id: self.start_state
+                if token_id != self.tokenizer.eos_token_id
+                else self.final_state
+                for token_id in self.tokenizer.vocabulary.values()
+            }
+        }

     def get_next_instruction(self, state: int) -> Instruction:
         if self.is_final_state(state):
-            return Write([self.eos_token_id])
-        return Generate(None)
+            return Write([self.tokenizer.eos_token_id])
+        return Generate(list(self.states_to_token_maps[state].keys()))

     def get_next_state(self, state: int, token_id: int) -> int:
-        if token_id == self.eos_token_id or state == self.final_state:
+        if self.is_final_state(state):
             return self.final_state

-        return self.start_state
+        return self.states_to_token_maps[state][token_id]

     def is_final_state(self, state: int):
         return state == self.final_state

     def copy(self):
-        return self
+        return copy(self)


 @cache()
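
Note (not part of the diff): a self-contained sketch of the map built by create_states_to_tokens_map above, using a toy three-token vocabulary in which id 2 plays the role of the EOS token.

    start_state, final_state, eos_token_id = 0, -1, 2
    vocabulary = {"a": 0, "b": 1, "<eos>": 2}
    states_to_token_maps = {
        start_state: {
            token_id: start_state if token_id != eos_token_id else final_state
            for token_id in vocabulary.values()
        }
    }
    # every token loops back to the start state, except EOS which ends generation
    assert states_to_token_maps == {0: {0: 0, 1: 0, 2: -1}}
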
@@ -171,20 +216,20 @@ def create_states_mapping(
     return states_to_token_maps, empty_token_ids, regex_fsm.finals


-class RegexGuide(Guide):
+class RegexGuide(Guide, TokenHealerMixin):
     """Guide to generate text in the language of a regular expression."""

-    initial_state = 0
+    states_to_token_mask: Dict[int, torch.Tensor]

     def __init__(self, regex_string: str, tokenizer: "Tokenizer"):
+        self.tokenizer = tokenizer
         (
             self.states_to_token_maps,
             self.empty_token_ids,
             fsm_finals,
         ) = create_states_mapping(regex_string, tokenizer)
-        self.eos_token_id = tokenizer.eos_token_id
-        self.final_states = fsm_finals | {-1}
         self._cache_state_to_token_tensor()
+        self.final_states = fsm_finals | {self.final_state}

     def get_next_instruction(self, state: int) -> Instruction:
         """Return the next instruction for guided generation.
@@ -211,7 +256,7 @@ def get_next_instruction(self, state: int) -> Instruction:
         """
         next_tokens_mask = self.states_to_token_mask.get(state)
         if next_tokens_mask is None:
-            return Write(torch.tensor([self.eos_token_id]))
+            return Write(torch.tensor([self.tokenizer.eos_token_id]))

         return Generate(next_tokens_mask)

@@ -233,13 +278,16 @@ def get_next_state(self, state: int, token_id: int) -> int:
         The new state of the guide.

         """
-        if token_id == self.eos_token_id or state not in self.states_to_token_maps:
-            return -1
+        if (
+            token_id == self.tokenizer.eos_token_id
+            or state not in self.states_to_token_maps
+        ):
+            return self.final_state

         last_token_to_end_state = self.states_to_token_maps[state]
         next_state = last_token_to_end_state.get(token_id)
         if next_state is None:
-            next_state = -1
+            next_state = self.final_state

         return next_state

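
Note (not part of the diff): the fallback behaviour of RegexGuide.get_next_state above, restated as a small self-contained sketch with made-up ids (no tokenizer or torch needed).

    final_state = -1
    eos_token_id = 9
    states_to_token_maps = {0: {5: 1}, 1: {6: 2}}

    def next_state(state, token_id):
        # EOS, an unknown state, or a disallowed token all fall through to the sentinel
        if token_id == eos_token_id or state not in states_to_token_maps:
            return final_state
        return states_to_token_maps[state].get(token_id, final_state)

    assert next_state(0, 5) == 1    # valid transition
    assert next_state(1, 7) == -1   # token not allowed in this state
    assert next_state(0, 9) == -1   # EOS always terminates
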
@@ -278,11 +326,11 @@ def create_states_mapping_from_interegular_fsm(
             from_interegular_instance.states_to_token_maps,
             from_interegular_instance.empty_token_ids,
         ) = create_states_mapping_from_interegular_fsm(interegular_fsm)
-        from_interegular_instance.eos_token_id = tokenizer.eos_token_id
+        from_interegular_instance.tokenizer = tokenizer
         from_interegular_instance._cache_state_to_token_tensor()
         return from_interegular_instance

-    def _cache_state_to_token_tensor(self):
+    def _cache_state_to_token_tensor(self) -> None:
         """
         cache state -> token int tensor
         this increases performance of mask construction substantially
@@ -297,7 +345,7 @@ def is_final_state(self, state: int) -> bool:
         return state in self.final_states

     def copy(self):
-        return self
+        return copy(self)


 class CFGGuide(Guide):
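
Note (not part of the diff): copy(self) replaces return self because prompt alignment rebinds states_to_token_maps on the instance, so each sequence needs its own guide object; a shallow copy suffices since alignment replaces the dict rather than mutating it in place. A toy illustration:

    from copy import copy

    class Toy:
        def __init__(self):
            self.states_to_token_maps = {0: {1: 0}}

    a = Toy()
    b = copy(a)
    b.states_to_token_maps = {0: {1: 0, 2: -1}}   # rebinding, as align_prompt_tokens does
    assert a.states_to_token_maps == {0: {1: 0}}  # the original guide is unaffected
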
@@ -331,9 +379,6 @@ def __init__(self, cfg_string: str, tokenizer):
         self.proposal_last: List[int] = []
         self.regex_fsm_last: RegexGuide

-        self.start_state = 0
-        self.final_state = -1
-
     def get_next_instruction(self, state: int) -> Instruction:
         """Generate an instruction for the next step.

@@ -475,3 +520,163 @@ def is_final_state(self, state: int) -> bool:
     def copy(self) -> "CFGGuide":
         """Create a copy of the FSM."""
         return CFGGuide(self.cfg_string, self.tokenizer)
+
+
+def align_tokens_states_to_token_maps(
+    token_ids: List[int],
+    vocabulary: Dict[str, int],
+    states_to_token_maps: Dict[int, Dict[int, int]],
+) -> Tuple[List[int], Dict[int, Dict[int, int]]]:
+    """Apply token alignment to the provided prompt token_ids given the
+    states_to_token_maps of an FSM. Return the cropped token_ids as well as the updated
+    states_to_token_maps. You can find an explanation from Guidance on why token healing
+    is necessary here:
+    https://github.com/guidance-ai/guidance/blob/main/notebooks/tutorials/token_healing.ipynb
+    """
+    crossing_tokens = find_crossing_tokens(token_ids, vocabulary)
+    valid_crossing_tokens = get_crossing_tokens_target_states(
+        states_to_token_maps, crossing_tokens, token_ids, vocabulary
+    )
+    if not valid_crossing_tokens:
+        return token_ids, states_to_token_maps
+    (
+        states_to_token_maps,
+        number_cropped_tokens,
+    ) = add_crossing_tokens_states_to_tokens_map(
+        states_to_token_maps, token_ids, valid_crossing_tokens
+    )
+    return (
+        token_ids[:-number_cropped_tokens],
+        states_to_token_maps,
+    )
+
+
+def find_crossing_tokens(
+    token_ids: List[int], vocabulary: Dict[str, int]
+) -> Dict[int, List[int]]:
+    """Find the tokens that could replace one or more tokens at the end of token_ids
+    while preserving the same initial text (and extending it by at least one character).
+    Return a dictionary mapping each index in token_ids that has matches to its crossing tokens.
+    """
+    reversed_vocabulary = {value: key for key, value in vocabulary.items()}
+    len_token_ids = len(token_ids)
+    max_length_token_text = max(len(item) for item in vocabulary.keys())
+    characters_considered = ""
+    crossing_tokens_map = {}
+
+    for index, token_id in enumerate(reversed(token_ids)):
+        characters_considered = reversed_vocabulary[token_id] + characters_considered
+        if len(characters_considered) >= max_length_token_text:
+            break
+        crossing_token_ids = [
+            token_id
+            for text, token_id in vocabulary.items()
+            if text.startswith(characters_considered)
+            and len(text) > len(characters_considered)
+        ]
+        if crossing_token_ids:
+            crossing_tokens_map[len_token_ids - index - 1] = crossing_token_ids
+
+    return crossing_tokens_map
+
+
+def get_crossing_tokens_target_states(
+    states_to_tokens_map: Dict[int, Dict[int, int]],
+    crossing_tokens: Dict[int, List[int]],
+    prompt_token_ids: List[int],
+    vocabulary: Dict[str, int],
+) -> Dict[int, Dict[int, int]]:
+    """For each crossing token associated with an index, check that the characters after the boundary
+    match the states_to_tokens_map and find the state it would lead to. Return a dict with, for each
+    provided index, the associated valid tokens and the state they would lead to.
+    """
+    reversed_vocabulary = {value: key for key, value in vocabulary.items()}
+    prompt_token_texts = [
+        reversed_vocabulary[token_id] for token_id in prompt_token_ids
+    ]
+
+    valid_crossing_tokens: Dict[int, Dict[int, int]] = defaultdict(dict)
+    for pos, tokens in crossing_tokens.items():
+        for token in tokens:
+            is_valid = True
+            characters = reversed_vocabulary[token]
+            characters_before_border = "".join(prompt_token_texts[pos:])
+            characters_after_border = characters[len(characters_before_border) :]
+            state = 0
+            for char in characters_after_border:
+                char_token = vocabulary.get(char)
+                try:
+                    state = states_to_tokens_map[state][char_token]  # type: ignore
+                except KeyError:
+                    is_valid = False
+                    break
+            if is_valid:
+                valid_crossing_tokens[pos][token] = state
+
+    return valid_crossing_tokens
+
+
+def add_crossing_tokens_states_to_tokens_map(
+    states_to_tokens_map: Dict[int, Dict[int, int]],
+    prompt_token_ids: List[int],
+    crossing_tokens_map: Dict[int, Dict[int, int]],
+) -> Tuple[Dict[int, Dict[int, int]], int]:
+    """Modify the states_to_tokens_map to account for the crossing tokens. This operation modifies
+    the starting state of the FSM as we would include some characters at the end of the prompt in
+    the states_to_tokens_map.
+    Attention! The starting state of the states_to_tokens_map provided must be 0.
+    Return the updated states_to_tokens_map and the number of cropped tokens/additional states.
+    """
+    if not crossing_tokens_map:
+        return states_to_tokens_map, 0
+    first_crossing_token_pos = min(
+        [key for key, value in crossing_tokens_map.items() if value]
+    )
+    number_additional_states = len(prompt_token_ids) - first_crossing_token_pos
+    highest_state = max(
+        max(states_to_tokens_map.keys()),
+        max(max(items.values()) for items in states_to_tokens_map.values()),
+    )
+
+    for i in range(number_additional_states):
+        # add the token that was originally part of the prompt
+        if i == number_additional_states - 1:
+            states_to_tokens_map[highest_state + 1 + i] = {
+                prompt_token_ids[first_crossing_token_pos + i]: 0
+            }
+        else:
+            states_to_tokens_map[highest_state + 1 + i] = {
+                prompt_token_ids[first_crossing_token_pos + i]: highest_state + 2 + i
+            }
+        # add the crossing tokens
+        crossing_tokens = crossing_tokens_map.get(first_crossing_token_pos + i)
+        if crossing_tokens:
+            for token, target_state in crossing_tokens.items():
+                states_to_tokens_map[highest_state + 1 + i][token] = target_state
+
+    # set the id of our new initial state to 0
+    states_to_tokens_map = swap_state_ids_states_to_tokens_map(
+        states_to_tokens_map, highest_state + 1, 0
+    )
+    return states_to_tokens_map, number_additional_states
+
+
+def swap_state_ids_states_to_tokens_map(
+    states_to_tokens_map: Dict[int, Dict[int, int]],
+    first_state_id: int,
+    second_state_id: int,
+) -> Dict[int, Dict[int, int]]:
+    """Swap the ids of two states of the states_to_tokens_map while preserving all transitions"""
+    first_state_transitions = states_to_tokens_map.pop(first_state_id)
+    second_state_transitions = states_to_tokens_map.pop(second_state_id)
+    states_to_tokens_map[first_state_id] = second_state_transitions
+    states_to_tokens_map[second_state_id] = first_state_transitions
+
+    for transitions in states_to_tokens_map.values():
+        for token, target_state_id in list(transitions.items()):
+            if target_state_id == first_state_id:
+                transitions[token] = second_state_id
+            elif target_state_id == second_state_id:
+                transitions[token] = first_state_id
+
+    return states_to_tokens_map
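
Note (not part of the diff): a worked toy example of the alignment pipeline defined above. It assumes the helper functions are in scope (for example by importing the patched module); the vocabulary and FSM are made up for illustration.

    # Toy vocabulary and FSM: the prompt "hello wo" ends mid-word and the FSM
    # originally accepts the continuation "rld" (r -> l -> d).
    vocabulary = {"hello": 0, " wo": 1, "r": 2, "l": 3, "d": 4, " world": 5}
    token_ids = [0, 1]  # "hello", " wo"
    states_to_token_maps = {0: {2: 1}, 1: {3: 2}, 2: {4: 3}}

    crossing = find_crossing_tokens(token_ids, vocabulary)
    assert crossing == {1: [5]}  # " world" extends the trailing " wo"

    targets = get_crossing_tokens_target_states(
        states_to_token_maps, crossing, token_ids, vocabulary
    )
    assert targets == {1: {5: 3}}  # the remaining "rld" walks the FSM to state 3

    aligned_ids, aligned_map = align_tokens_states_to_token_maps(
        token_ids, vocabulary, states_to_token_maps
    )
    assert aligned_ids == [0]  # the trailing " wo" token is cropped from the prompt
    # The new start state 0 either re-accepts " wo" (leading to the old start
    # state, renumbered 4) or jumps straight to state 3 with " world".
    assert aligned_map == {0: {1: 4, 5: 3}, 1: {3: 2}, 2: {4: 3}, 4: {2: 1}}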