+ from collections import defaultdict
+ from copy import deepcopy
from dataclasses import dataclass
- from typing import TYPE_CHECKING, List, Protocol, Tuple, Union
+ from typing import TYPE_CHECKING, Dict, List, Protocol, Tuple, Union

import interegular
+ import torch

from lark import Lark

from outlines import grammars
@@ -62,11 +65,16 @@ def get_next_state(self, state: int, token_id: int) -> int:
    def is_final_state(self, state: int) -> bool:
        ...

+     def align_prompt_tokens(
+         self, token_ids: torch.Tensor, attention_masks: torch.Tensor
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         ...
+

class StopAtEOSGuide(Guide):
    """Guide to generate tokens until the EOS token has been generated."""

-     final_state = 1
+     final_state = -1
    start_state = 0

    def __init__(self, tokenizer: "Tokenizer"):
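
The new align_prompt_tokens protocol method is the entry point for the prompt alignment machinery added further down in this diff. A hedged caller-side sketch of how it is meant to be used (the tokenizer and guide objects here are illustrative stand-ins, not part of this change):

    # Align the encoded prompt once, before the first generation step.
    token_ids, attention_masks = tokenizer.encode("Here is a number: ")
    token_ids, attention_masks = guide.align_prompt_tokens(token_ids, attention_masks)
    # token_ids may come back shorter: trailing prompt tokens that could cross the
    # prompt/generation boundary are folded into the guide's transition table.
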
@@ -77,24 +85,52 @@ def __init__(self, tokenizer: "Tokenizer"):

        """
        self.eos_token_id = tokenizer.eos_token_id
-         self.vocabulary = tokenizer.vocabulary.values()
+         self.vocabulary = tokenizer.vocabulary
+         self.tokenizer = tokenizer
+         self.states_to_token_maps = self.create_states_to_tokens_map()
+
+     def create_states_to_tokens_map(self) -> Dict[int, Dict[int, int]]:
+         """Create the states_to_tokens_map. Every token maps the starting state
+         back to itself, except for the eos_token, which leads to the final state."""
+         return {
+             self.start_state: {
+                 token_id: self.start_state
+                 if token_id != self.eos_token_id
+                 else self.final_state
+                 for token_id in self.vocabulary.values()
+             }
+         }
+
+     def align_prompt_tokens(
+         self, token_ids: torch.Tensor, attention_masks: torch.Tensor
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         """Update the states_to_token_maps and return the aligned prompt tokens
+         and attention masks."""
+         (
+             token_ids,
+             attention_masks,
+             self.states_to_token_maps,
+         ) = align_tokens_states_to_token_maps(
+             token_ids, attention_masks, self.vocabulary, self.states_to_token_maps
+         )
+         return token_ids, attention_masks

    def get_next_instruction(self, state: int) -> Instruction:
        if self.is_final_state(state):
            return Write([self.eos_token_id])
-         return Generate(list(self.vocabulary))
+
+         return Generate(list(self.states_to_token_maps[state].keys()))

    def get_next_state(self, state: int, token_id: int) -> int:
-         if token_id == self.eos_token_id or state == self.final_state:
+         if self.is_final_state(state):
            return self.final_state

-         return self.start_state
+         return self.states_to_token_maps[state][token_id]

    def is_final_state(self, state: int):
        return state == self.final_state

    def copy(self):
-         return self
+         return deepcopy(self)


class RegexGuide(Guide):
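
For intuition, create_states_to_tokens_map turns StopAtEOSGuide's "generate anything until EOS" behaviour into an explicit transition table, which is what allows align_prompt_tokens to rewrite it. A minimal sketch of its output, assuming a toy three-token vocabulary (values invented for illustration):

    # vocabulary {"a": 0, "b": 1, "<eos>": 2}, eos_token_id == 2,
    # start_state == 0, final_state == -1:
    states_to_token_maps = {
        0: {0: 0, 1: 0, 2: -1},  # every token loops on state 0; <eos> moves to -1
    }

get_next_instruction and get_next_state now read the allowed tokens and transitions from this map instead of the raw vocabulary.
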
@@ -136,10 +172,23 @@ def create_states_mapping(
        ) = create_states_mapping(
            regex_string, tuple(sorted(tokenizer.vocabulary.items()))
        )
-         self.vocabulary = list(tokenizer.vocabulary.values())
+         self.vocabulary = tokenizer.vocabulary
        self.eos_token_id = tokenizer.eos_token_id
        self.final_states = fsm_finals | {-1}

+     def align_prompt_tokens(
+         self, token_ids: torch.Tensor, attention_masks: torch.Tensor
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         """Update the states_to_token_maps and return the aligned prompt tokens
+         and attention masks."""
+         (
+             token_ids,
+             attention_masks,
+             self.states_to_token_maps,
+         ) = align_tokens_states_to_token_maps(
+             token_ids, attention_masks, self.vocabulary, self.states_to_token_maps
+         )
+         return token_ids, attention_masks
+
    def get_next_instruction(self, state: int) -> Instruction:
        """Return the next instruction for guided generation.

@@ -244,7 +293,7 @@ def is_final_state(self, state: int) -> bool:
        return state in self.final_states

    def copy(self):
-         return self
+         return deepcopy(self)


class CFGGuide(Guide):
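
Both copy() changes follow from the guides becoming stateful: align_prompt_tokens rewrites self.states_to_token_maps in place, so the old "return self" would let two sequences share, and corrupt, a single transition table. A quick illustration of the hazard (the guide and tensor objects are stand-ins):

    copied = guide.copy()  # now an independent deepcopy
    copied.align_prompt_tokens(token_ids, attention_masks)
    # guide.states_to_token_maps is unchanged; with the old "return self",
    # both names would point at the same mutated object.
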
@@ -281,6 +330,12 @@ def __init__(self, cfg_string: str, tokenizer):
        self.start_state = 0
        self.final_state = -1

+     def align_prompt_tokens(
+         self, token_ids: torch.Tensor, attention_masks: torch.Tensor
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         """Not applicable to this type of FSM"""
+         return token_ids, attention_masks
+
    def get_next_instruction(self, state: int) -> Instruction:
        """Generate an instruction for the next step.

@@ -416,3 +471,163 @@ def is_final_state(self, state: int) -> bool:
    def copy(self) -> "CFGGuide":
        """Create a copy of the FSM."""
        return CFGGuide(self.cfg_string, self.tokenizer)
+
+
+ def align_tokens_states_to_token_maps(
+     token_ids: torch.Tensor,
+     attention_masks: torch.Tensor,
+     vocabulary: Dict[str, int],
+     states_to_token_maps: Dict[int, Dict[int, int]],
+ ) -> Tuple[torch.Tensor, torch.Tensor, Dict[int, Dict[int, int]]]:
+     """Apply token alignment to the provided prompt tokens and attention masks given
+     the states_to_token_maps of an FSM. Return the updated tokens/masks as well as the
+     updated states_to_token_maps."""
+     prompt_token_ids = token_ids.tolist()
+     crossing_tokens = find_crossing_tokens(prompt_token_ids, vocabulary)
+     valid_crossing_tokens = get_crossing_tokens_target_states(
+         states_to_token_maps, crossing_tokens, prompt_token_ids, vocabulary
+     )
+     if not valid_crossing_tokens:
+         return token_ids, attention_masks, states_to_token_maps
+     (
+         states_to_token_maps,
+         number_cropped_tokens,
+     ) = add_crossing_tokens_states_to_tokens_map(
+         states_to_token_maps, prompt_token_ids, valid_crossing_tokens
+     )
+     return (
+         token_ids[:-number_cropped_tokens],
+         attention_masks[:-number_cropped_tokens],
+         states_to_token_maps,
+     )
+
+
+ def find_crossing_tokens(
+     token_ids: List[int], vocabulary: Dict[str, int]
+ ) -> Dict[int, List[int]]:
+     """Find the tokens that could replace one or more tokens at the end of token_ids
+     while conserving the same initial text (and extending it by at least one character).
+     Return a dictionary mapping each index of token_ids that has matches to the
+     associated crossing tokens.
+     """
+     reversed_vocabulary = {value: key for key, value in vocabulary.items()}
+     len_token_ids = len(token_ids)
+     max_length_token_text = max(len(item) for item in vocabulary.keys())
+     characters_considered = ""
+     crossing_tokens_map = {}
+
+     for index, token_id in enumerate(reversed(token_ids)):
+         characters_considered = reversed_vocabulary[token_id] + characters_considered
+         if len(characters_considered) >= max_length_token_text:
+             break
+         crossing_token_ids = [
+             token_id
+             for text, token_id in vocabulary.items()
+             if text.startswith(characters_considered)
+             and len(text) > len(characters_considered)
+         ]
+         if crossing_token_ids:
+             crossing_tokens_map[len_token_ids - index - 1] = crossing_token_ids
+
+     return crossing_tokens_map
+
+
+ def get_crossing_tokens_target_states(
+     states_to_tokens_map: Dict[int, Dict[int, int]],
+     crossing_tokens: Dict[int, List[int]],
+     prompt_token_ids: List[int],
+     vocabulary: Dict[str, int],
+ ) -> Dict[int, Dict[int, int]]:
+     """For each crossing token associated with an index, check that the characters
+     after the boundary match the states_to_tokens_map and find the state it would lead
+     to. Return a dict mapping each provided index to its valid tokens and the states
+     they would lead to.
+     """
+     reversed_vocabulary = {value: key for key, value in vocabulary.items()}
+     prompt_token_texts = [
+         reversed_vocabulary[token_id] for token_id in prompt_token_ids
+     ]
+
+     valid_crossing_tokens: Dict[int, Dict[int, int]] = defaultdict(dict)
+     for pos, tokens in crossing_tokens.items():
+         for token in tokens:
+             is_valid = True
+             characters = reversed_vocabulary[token]
+             characters_before_border = "".join(prompt_token_texts[pos:])
+             characters_after_border = characters[len(characters_before_border) :]
+             state = 0
+             for char in characters_after_border:
+                 char_token = vocabulary.get(char)
+                 try:
+                     state = states_to_tokens_map[state][char_token]  # type: ignore
+                 except KeyError:
+                     is_valid = False
+                     break
+             if is_valid:
+                 valid_crossing_tokens[pos][token] = state
+
+     return valid_crossing_tokens
+
+
+ def add_crossing_tokens_states_to_tokens_map(
+     states_to_tokens_map: Dict[int, Dict[int, int]],
+     prompt_token_ids: List[int],
+     crossing_tokens_map: Dict[int, Dict[int, int]],
+ ) -> Tuple[Dict[int, Dict[int, int]], int]:
+     """Modify the states_to_tokens_map to account for the crossing tokens. This
+     operation modifies the starting state of the FSM as we would include some
+     characters at the end of the prompt in the states_to_tokens_map.
+     Attention! The starting state of the states_to_tokens_map provided must be 0.
+     Return the updated states_to_tokens_map and the number of cropped tokens/additional
+     states.
+     """
+     if not crossing_tokens_map:
+         return states_to_tokens_map, 0
+     first_crossing_token_pos = min(
+         [key for key, value in crossing_tokens_map.items() if value]
+     )
+     number_additional_states = len(prompt_token_ids) - first_crossing_token_pos
+     highest_state = max(
+         max(states_to_tokens_map.keys()),
+         max(max(items.values()) for items in states_to_tokens_map.values()),
+     )
+
+     for i in range(number_additional_states):
+         # add the token that was originally part of the prompt
+         if i == number_additional_states - 1:
+             states_to_tokens_map[highest_state + 1 + i] = {
+                 prompt_token_ids[first_crossing_token_pos + i]: 0
+             }
+         else:
+             states_to_tokens_map[highest_state + 1 + i] = {
+                 prompt_token_ids[first_crossing_token_pos + i]: highest_state + 2 + i
+             }
+         # add the crossing tokens
+         crossing_tokens = crossing_tokens_map.get(first_crossing_token_pos + i)
+         if crossing_tokens:
+             for token, target_state in crossing_tokens.items():
+                 states_to_tokens_map[highest_state + 1 + i][token] = target_state
+
+     # set the id of our new initial state to 0
+     states_to_tokens_map = swap_state_ids_states_to_tokens_map(
+         states_to_tokens_map, highest_state + 1, 0
+     )
+     return states_to_tokens_map, number_additional_states
+
+
+ def swap_state_ids_states_to_tokens_map(
+     states_to_tokens_map: Dict[int, Dict[int, int]],
+     first_state_id: int,
+     second_state_id: int,
+ ) -> Dict[int, Dict[int, int]]:
+     """Swap the ids of two states of the states_to_tokens_map while conserving all
+     transitions."""
+     first_state_transitions = states_to_tokens_map.pop(first_state_id)
+     second_state_transitions = states_to_tokens_map.pop(second_state_id)
+     states_to_tokens_map[first_state_id] = second_state_transitions
+     states_to_tokens_map[second_state_id] = first_state_transitions
+
+     for transitions in states_to_tokens_map.values():
+         for token, target_state_id in list(transitions.items()):
+             if target_state_id == first_state_id:
+                 transitions[token] = second_state_id
+             elif target_state_id == second_state_id:
+                 transitions[token] = first_state_id
+
+     return states_to_tokens_map
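
To make the alignment pipeline concrete, here is a hedged end-to-end walk-through of the helpers above on an invented example (the vocabulary, token ids, and FSM are illustrative, and the helpers are assumed to be in scope):

    vocabulary = {"hell": 0, "hello": 1, "o": 2, "!": 3, "<eos>": 4}
    prompt_token_ids = [0]  # the prompt "hell", encoded as a single token

    # "hello" (id 1) starts with the trailing prompt text "hell" and extends it.
    crossing = find_crossing_tokens(prompt_token_ids, vocabulary)
    assert crossing == {0: [1]}

    # Toy FSM expecting "o" first: state 0 --"o"--> state 1.
    states_to_token_maps = {0: {2: 1}, 1: {3: 1, 4: -1}}

    # Walking the characters of "hello" past the boundary ("o") lands in state 1,
    # so the crossing token is valid and mapped to its target state.
    valid = get_crossing_tokens_target_states(
        states_to_token_maps, crossing, prompt_token_ids, vocabulary
    )
    assert valid == {0: {1: 1}}

    # One prompt token is cropped; a new entry state is grafted on and swapped to id 0.
    new_map, cropped = add_crossing_tokens_states_to_tokens_map(
        states_to_token_maps, prompt_token_ids, valid
    )
    assert cropped == 1
    # From the new start state, generation may re-emit "hell" (id 0) and continue in
    # the original FSM (via state 2), or emit the crossing token "hello" (id 1) directly.
    assert new_map == {0: {0: 2, 1: 1}, 1: {3: 1, 4: -1}, 2: {2: 1}}

align_tokens_states_to_token_maps wraps these three steps and, when a valid crossing token exists, returns token_ids[:-cropped] and attention_masks[:-cropped] alongside the rewritten map.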