@@ -1,7 +1,10 @@
+from collections import defaultdict
+from copy import deepcopy
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, List, Optional, Protocol, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Protocol, Tuple, Union
 
 import interegular
+import torch
 from lark import Lark
 
 from outlines import grammars
@@ -67,14 +70,17 @@ def get_next_state(self, state: int, token_id: int) -> int:
     def is_final_state(self, state: int) -> bool:
         ...
 
+    def align_prompt_tokens(self, prompt: str) -> str:
+        ...
+
     def copy(self) -> "Guide":
         ...
 
 
 class StopAtEOSGuide(Guide):
     """Guide to generate tokens until the EOS token has been generated."""
 
-    final_state = 1
+    final_state = -1
     start_state = 0
 
     def __init__(self, tokenizer: "Tokenizer"):
@@ -85,24 +91,49 @@ def __init__(self, tokenizer: "Tokenizer"):
 
         """
         self.eos_token_id = tokenizer.eos_token_id
-        self.vocabulary = tokenizer.vocabulary.values()
+        self.tokenizer = tokenizer
+        self.states_to_token_maps = self.create_states_to_tokens_map()
+
+    def create_states_to_tokens_map(self) -> Dict[int, Dict[int, int]]:
+        """Create the states_to_tokens_map. All tokens lead to the starting
+        state, except for the eos_token that leads to the final state."""
+        return {
+            self.start_state: {
+                token_id: self.start_state
+                if token_id != self.eos_token_id
+                else self.final_state
+                for token_id in self.tokenizer.vocabulary.values()
+            }
+        }
+
+    def align_prompt_tokens(self, prompt: str) -> str:
+        """Update the states_to_token_maps and return the aligned prompt"""
+        token_ids, _ = self.tokenizer.encode(prompt)
+        (
+            aligned_token_ids,
+            self.states_to_token_maps,
+        ) = align_tokens_states_to_token_maps(
+            token_ids[0], self.tokenizer.vocabulary, self.states_to_token_maps
+        )
+        decoded_aligned_token_ids = self.tokenizer.decode(aligned_token_ids)
+        return "".join(decoded_aligned_token_ids)
 
     def get_next_instruction(self, state: int) -> Instruction:
         if self.is_final_state(state):
             return Write([self.eos_token_id])
-        return Generate(None)
+        return Generate(list(self.states_to_token_maps[state].keys()))
 
     def get_next_state(self, state: int, token_id: int) -> int:
-        if token_id == self.eos_token_id or state == self.final_state:
+        if self.is_final_state(state):
             return self.final_state
 
-        return self.start_state
+        return self.states_to_token_maps[state][token_id]
 
     def is_final_state(self, state: int):
         return state == self.final_state
 
     def copy(self):
-        return self
+        return deepcopy(self)
 
 
 @cache()
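For intuition, a minimal sketch of the map built by create_states_to_tokens_map, assuming a toy tokenizer with vocabulary {"a": 0, "b": 1, "eos": 2} (names and ids are illustrative, not from the diff):

# StopAtEOSGuide: every token loops back to the start state except the EOS token (id 2),
# which leads to the final state.
states_to_token_maps = {
    0: {0: 0, 1: 0, 2: -1},  # start_state == 0, final_state == -1
}
# get_next_state(0, 1) -> 0 (keep generating); get_next_state(0, 2) -> -1 (stop).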
@@ -143,9 +174,22 @@ def __init__(self, regex_string: str, tokenizer):
             self.empty_token_ids,
             fsm_finals,
         ) = create_states_mapping(regex_string, tokenizer)
+        self.tokenizer = tokenizer
         self.eos_token_id = tokenizer.eos_token_id
         self.final_states = fsm_finals | {-1}
 
+    def align_prompt_tokens(self, prompt: str) -> str:
+        """Update the states_to_token_maps and return the aligned prompt"""
+        token_ids, _ = self.tokenizer.encode(prompt)
+        (
+            aligned_token_ids,
+            self.states_to_token_maps,
+        ) = align_tokens_states_to_token_maps(
+            token_ids[0], self.tokenizer.vocabulary, self.states_to_token_maps
+        )
+        decoded_aligned_token_ids = self.tokenizer.decode(aligned_token_ids)
+        return "".join(decoded_aligned_token_ids)
+
     def get_next_instruction(self, state: int) -> Instruction:
         """Return the next instruction for guided generation.
 
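A minimal usage sketch for the new method; the regex, prompt, and tokenizer object are illustrative assumptions, not part of the diff:

# Token alignment: the guide may crop trailing prompt tokens and fold them into its FSM,
# so that tokens crossing the prompt/generation boundary remain reachable.
guide = RegexGuide(r"[0-9]+", tokenizer)  # tokenizer: any object following the Tokenizer protocol
aligned_prompt = guide.align_prompt_tokens("The answer is ")
# Feed `aligned_prompt` to the model; the characters cropped from the end of the prompt are
# re-generated under the updated states_to_token_maps.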
@@ -246,7 +290,7 @@ def is_final_state(self, state: int) -> bool:
         return state in self.final_states
 
     def copy(self):
-        return self
+        return deepcopy(self)
 
 
 class CFGGuide(Guide):
@@ -283,6 +327,10 @@ def __init__(self, cfg_string: str, tokenizer):
         self.start_state = 0
         self.final_state = -1
 
+    def align_prompt_tokens(self, prompt: str) -> str:
+        """Not applicable to this type of Guide"""
+        return prompt
+
     def get_next_instruction(self, state: int) -> Instruction:
         """Generate an instruction for the next step.
 
@@ -424,3 +472,161 @@ def is_final_state(self, state: int) -> bool:
     def copy(self) -> "CFGGuide":
         """Create a copy of the FSM."""
         return CFGGuide(self.cfg_string, self.tokenizer)
+
+
+def align_tokens_states_to_token_maps(
+    token_ids: torch.Tensor,
+    vocabulary: Dict[str, int],
+    states_to_token_maps: Dict[int, Dict[int, int]],
+) -> Tuple[torch.Tensor, Dict[int, Dict[int, int]]]:
+    """Apply token alignment to the provided prompt token ids given the
+    states_to_token_maps of an FSM. Return the (possibly cropped) token ids
+    as well as the updated states_to_token_maps."""
+    prompt_token_ids = token_ids.tolist()
+    crossing_tokens = find_crossing_tokens(prompt_token_ids, vocabulary)
+    valid_crossing_tokens = get_crossing_tokens_target_states(
+        states_to_token_maps, crossing_tokens, prompt_token_ids, vocabulary
+    )
+    if not valid_crossing_tokens:
+        return token_ids, states_to_token_maps
+    (
+        states_to_token_maps,
+        number_cropped_tokens,
+    ) = add_crossing_tokens_states_to_tokens_map(
+        states_to_token_maps, prompt_token_ids, valid_crossing_tokens
+    )
+    return (
+        token_ids[:-number_cropped_tokens],
+        states_to_token_maps,
+    )
+
+
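A worked toy example helps trace the module-level helpers below; the vocabulary, prompt, and FSM are invented for illustration and are reused in the sketches that follow:

import torch

vocabulary = {"a": 0, "b": 1, "ab": 2}
states_to_token_maps = {0: {1: 1}}  # an FSM that accepts exactly "b"
prompt_ids = torch.tensor([0])  # the prompt is the single token "a"

new_ids, new_map = align_tokens_states_to_token_maps(prompt_ids, vocabulary, states_to_token_maps)
# new_ids is empty: the trailing "a" is cropped from the prompt. new_map now lets the model
# emit either "a" then "b", or the single crossing token "ab", starting from the new state 0.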
+def find_crossing_tokens(
+    token_ids: List[int], vocabulary: Dict[str, int]
+) -> Dict[int, List[int]]:
+    """Find the tokens that could replace one or more tokens at the end of token_ids
+    while conserving the same initial text (and extending it by at least one character).
+    Return a dictionary mapping each index of token_ids that has matches to the
+    associated crossing tokens.
+    """
+    reversed_vocabulary = {value: key for key, value in vocabulary.items()}
+    len_token_ids = len(token_ids)
+    max_length_token_text = max(len(item) for item in vocabulary.keys())
+    characters_considered = ""
+    crossing_tokens_map = {}
+
+    for index, token_id in enumerate(reversed(token_ids)):
+        characters_considered = reversed_vocabulary[token_id] + characters_considered
+        if len(characters_considered) >= max_length_token_text:
+            break
+        crossing_token_ids = [
+            token_id
+            for text, token_id in vocabulary.items()
+            if text.startswith(characters_considered)
+            and len(text) > len(characters_considered)
+        ]
+        if crossing_token_ids:
+            crossing_tokens_map[len_token_ids - index - 1] = crossing_token_ids
+
+    return crossing_tokens_map
+
+
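Continuing the toy example above (same invented vocabulary), a sketch of what find_crossing_tokens returns:

# The prompt token ids [0] spell "a"; "ab" (id 2) starts with "a" and is strictly longer,
# so it crosses the boundary at index 0 (the last prompt token).
find_crossing_tokens([0], {"a": 0, "b": 1, "ab": 2})
# -> {0: [2]}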
+def get_crossing_tokens_target_states(
+    states_to_tokens_map: Dict[int, Dict[int, int]],
+    crossing_tokens: Dict[int, List[int]],
+    prompt_token_ids: List[int],
+    vocabulary: Dict[str, int],
+) -> Dict[int, Dict[int, int]]:
+    """For each crossing token associated with an index, check that the characters after the boundary
+    match the states_to_tokens_map and find the state it would lead to. Return a dict with, for each
+    provided index, the associated valid tokens and the states they would lead to.
+    """
+    reversed_vocabulary = {value: key for key, value in vocabulary.items()}
+    prompt_token_texts = [
+        reversed_vocabulary[token_id] for token_id in prompt_token_ids
+    ]
+
+    valid_crossing_tokens: Dict[int, Dict[int, int]] = defaultdict(dict)
+    for pos, tokens in crossing_tokens.items():
+        for token in tokens:
+            is_valid = True
+            characters = reversed_vocabulary[token]
+            characters_before_border = "".join(prompt_token_texts[pos:])
+            characters_after_border = characters[len(characters_before_border) :]
+            state = 0
+            for char in characters_after_border:
+                char_token = vocabulary.get(char)
+                try:
+                    state = states_to_tokens_map[state][char_token]  # type: ignore
+                except KeyError:
+                    is_valid = False
+                    break
+            if is_valid:
+                valid_crossing_tokens[pos][token] = state
+
+    return valid_crossing_tokens
+
+
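Still with the toy example, a sketch of the validation step:

# "ab" covers the prompt's trailing "a" plus the extra character "b"; walking "b" through the
# FSM from state 0 reaches state 1, so the crossing token is valid and mapped to that state.
get_crossing_tokens_target_states({0: {1: 1}}, {0: [2]}, [0], {"a": 0, "b": 1, "ab": 2})
# -> {0: {2: 1}}  (returned as a defaultdict)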
+def add_crossing_tokens_states_to_tokens_map(
+    states_to_tokens_map: Dict[int, Dict[int, int]],
+    prompt_token_ids: List[int],
+    crossing_tokens_map: Dict[int, Dict[int, int]],
+) -> Tuple[Dict[int, Dict[int, int]], int]:
+    """Modify the states_to_tokens_map to account for the crossing tokens. This operation modifies
+    the starting state of the fsm as we include some characters from the end of the prompt in
+    the states_to_tokens_map.
+    Attention! The starting state of the states_to_tokens_map provided must be 0.
+    Return the updated states_to_tokens_map and the number of cropped tokens/additional states.
+    """
+    if not crossing_tokens_map:
+        return states_to_tokens_map, 0
+    first_crossing_token_pos = min(
+        [key for key, value in crossing_tokens_map.items() if value]
+    )
+    number_additional_states = len(prompt_token_ids) - first_crossing_token_pos
+    highest_state = max(
+        max(states_to_tokens_map.keys()),
+        max(max(items.values()) for items in states_to_tokens_map.values()),
+    )
+
+    for i in range(number_additional_states):
+        # add the token that was originally part of the prompt
+        if i == number_additional_states - 1:
+            states_to_tokens_map[highest_state + 1 + i] = {
+                prompt_token_ids[first_crossing_token_pos + i]: 0
+            }
+        else:
+            states_to_tokens_map[highest_state + 1 + i] = {
+                prompt_token_ids[first_crossing_token_pos + i]: highest_state + 2 + i
+            }
+        # add the crossing tokens
+        crossing_tokens = crossing_tokens_map.get(first_crossing_token_pos + i)
+        if crossing_tokens:
+            for token, target_state in crossing_tokens.items():
+                states_to_tokens_map[highest_state + 1 + i][token] = target_state
+
+    # set the id of our new initial state to 0
+    states_to_tokens_map = swap_state_ids_states_to_tokens_map(
+        states_to_tokens_map, highest_state + 1, 0
+    )
+    return states_to_tokens_map, number_additional_states
+
+
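For the same toy inputs, a sketch of how the map is rewritten:

# One extra state is created for the cropped prompt token "a"; after the final id swap, the new
# start state 0 offers both "a" (id 0, leading to the old start state, renumbered 2) and the
# crossing token "ab" (id 2, leading directly to state 1). One prompt token is reported cropped.
add_crossing_tokens_states_to_tokens_map({0: {1: 1}}, [0], {0: {2: 1}})
# -> ({2: {1: 1}, 0: {0: 2, 2: 1}}, 1)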
+def swap_state_ids_states_to_tokens_map(
+    states_to_tokens_map: Dict[int, Dict[int, int]],
+    first_state_id: int,
+    second_state_id: int,
+) -> Dict[int, Dict[int, int]]:
+    """Swap the id of two states of the states_to_tokens_map while conserving all transitions"""
+    first_state_transitions = states_to_tokens_map.pop(first_state_id)
+    second_state_transitions = states_to_tokens_map.pop(second_state_id)
+    states_to_tokens_map[first_state_id] = second_state_transitions
+    states_to_tokens_map[second_state_id] = first_state_transitions
+
+    for transitions in states_to_tokens_map.values():
+        for token, target_state_id in list(transitions.items()):
+            if target_state_id == first_state_id:
+                transitions[token] = second_state_id
+            elif target_state_id == second_state_id:
+                transitions[token] = first_state_id
+
+    return states_to_tokens_map
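A minimal sketch of the id swap in isolation, on a toy map invented for illustration:

# States 0 and 2 trade ids; every transition that pointed at one now points at the other.
swap_state_ids_states_to_tokens_map({0: {5: 2}, 2: {7: 0}}, 2, 0)
# -> {2: {5: 0}, 0: {7: 2}}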