
Commit 4aa74f2

Implement token alignment for StopAtEosFSM and RegexFSM
1 parent e99d92d commit 4aa74f2

File tree

2 files changed: +291 -22 lines changed


outlines/fsm/fsm.py

+225 -12
@@ -1,6 +1,9 @@
-from typing import TYPE_CHECKING, List, NewType, Protocol, Tuple
+from collections import defaultdict
+from copy import deepcopy
+from typing import TYPE_CHECKING, Dict, List, NewType, Protocol, Tuple

 import interegular
+import torch
 from lark import Lark

 # from outlines.fsm.parsing import PartialLark
@@ -22,6 +25,11 @@ def is_final_state(self, state: FSMState) -> bool:
         """Determine whether the current state of the FSM is a final state."""
         return state == self.final_state

+    def align_prompt_tokens(
+        self, token_ids: torch.Tensor, attention_masks: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        ...
+
     def allowed_token_ids(self, state: FSMState) -> List[int]:
         ...

@@ -37,13 +45,41 @@ class StopAtEosFSM(FSM):

     def __init__(self, tokenizer: "Tokenizer"):
         self.eos_token_id = tokenizer.eos_token_id
-        self.vocabulary = tokenizer.vocabulary.values()
+        self.vocabulary = tokenizer.vocabulary
+        self.tokenizer = tokenizer
+        self.states_to_token_maps = self.create_states_to_tokens_map()
+
+    def create_states_to_tokens_map(self) -> Dict[int, Dict[int, int]]:
+        """Create the states_to_tokens_map. All tokens from the starting state lead
+        to itself, except for the eos_token that leads to the final state."""
+        return {
+            self.first_state: {
+                token_id: self.first_state
+                if token_id != self.eos_token_id
+                else self.final_state
+                for token_id in self.vocabulary.values()
+            }
+        }
+
+    def align_prompt_tokens(
+        self, token_ids: torch.Tensor, attention_masks: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Update the states_to_token_maps and return the aligned prompt tokens and attention masks"""
+        (
+            token_ids,
+            attention_masks,
+            self.states_to_token_maps,
+        ) = align_tokens_states_to_token_maps(
+            token_ids, attention_masks, self.vocabulary, self.states_to_token_maps
+        )
+        return token_ids, attention_masks

     def allowed_token_ids(self, state: FSMState) -> List[int]:
         """Generate a list of allowed tokens for the next step.

-        When in the initial state we allow every token to be generated.
         In the final state the only allowed token is `stop_token_id`.
+        Otherwise we allow the valid transitions tokens corresponding to
+        the current state of the states_to_token_maps

         Parameters
         ----------
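
Note: for readers unfamiliar with the structure, the following is a minimal sketch of what create_states_to_tokens_map produces for an invented three-token vocabulary, assuming first_state is 0 and final_state is -1 as defined in this module (none of these values appear in the diff itself):

# Sketch only, not part of the commit: vocabulary {"a": 0, "b": 1, "<eos>": 2} with eos_token_id == 2.
# create_states_to_tokens_map would then return:
states_to_token_maps = {
    0: {0: 0, 1: 0, 2: -1},  # every token loops on the start state, except <eos>, which leads to the final state
}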
@@ -57,14 +93,13 @@ def allowed_token_ids(self, state: FSMState) -> List[int]:
         """
         if self.is_final_state(state):
             return [self.eos_token_id]
-        return list(self.vocabulary)
+        return list(self.states_to_token_maps[state].keys())

     def next_state(self, state: FSMState, token_id: int) -> FSMState:
         """Update the state of the FSM.

-        The FSM stays in the initial state `0` unless the specified stop token
-        has been generated or the maximum number of tokens has been reached. In
-        which case the FSM moves to the final state `-1`.
+        The FSM transitions from a state to the other through the
+        states_to_token_maps until the final state is reached.

         Parameters
         ----------
@@ -78,14 +113,14 @@ def next_state(self, state: FSMState, token_id: int) -> FSMState:
         The new state of the FSM.

         """
-        if token_id == self.eos_token_id:
+        if self.is_final_state(state):
             return self.final_state

-        return self.first_state
+        return FSMState(self.states_to_token_maps[state][token_id])

     def copy(self) -> "StopAtEosFSM":
         """Create a copy of the FSM."""
-        return self
+        return deepcopy(self)


 class RegexFSM(FSM):
121156
self.states_to_token_maps, self.empty_token_ids = create_states_mapping(
122157
regex_string, tuple(sorted(tokenizer.vocabulary.items()))
123158
)
124-
self.vocabulary = tokenizer.vocabulary.values()
159+
self.vocabulary = tokenizer.vocabulary
125160
self.eos_token_id = tokenizer.eos_token_id
126161

162+
def align_prompt_tokens(
163+
self, token_ids: torch.Tensor, attention_masks: torch.Tensor
164+
) -> Tuple[torch.Tensor, torch.Tensor]:
165+
"""Update the states_to_token_maps and return the aligned prompt tokens and attention masks"""
166+
(
167+
token_ids,
168+
attention_masks,
169+
self.states_to_token_maps,
170+
) = align_tokens_states_to_token_maps(
171+
token_ids, attention_masks, self.vocabulary, self.states_to_token_maps
172+
)
173+
return token_ids, attention_masks
174+
127175
def allowed_token_ids(self, state: FSMState) -> List[int]:
128176
"""Generate a list of allowed tokens for the next step.
129177
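
Note: a possible caller-side usage sketch of the new method (the surrounding names are hypothetical and not part of this commit; only align_prompt_tokens itself is added here). A generator holding a single sequence's prompt token ids and attention mask would call it once before decoding and use the possibly shortened tensors for the first forward pass:

# Hypothetical sketch, not part of the diff
fsm = RegexFSM(r"[a-z]+", tokenizer)
token_ids, attention_mask = fsm.align_prompt_tokens(token_ids, attention_mask)
# The returned tensors may have lost their last few prompt tokens: those tokens are now
# encoded as extra states in fsm.states_to_token_maps, so the first generated token is
# free to "cross" the prompt boundary (for instance by extending the last prompt word).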
@@ -184,7 +232,7 @@ def next_state(self, state: FSMState, token_id: int) -> FSMState:

     def copy(self) -> "RegexFSM":
         """Create a copy of the FSM."""
-        return self
+        return deepcopy(self)


 class CFGFSM(FSM):
@@ -218,6 +266,12 @@ def __init__(self, cfg_string: str, tokenizer):
         self.proposal_last: List[int] = []
         self.regex_fsm_last: RegexFSM

+    def align_prompt_tokens(
+        self, token_ids: torch.Tensor, attention_masks: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Not applicable to this type of FSM"""
+        return token_ids, attention_masks
+
     def allowed_token_ids(self, state: FSMState) -> List[int]:
         """Generate a list of allowed tokens for the next step.

@@ -333,3 +387,162 @@ def next_state(self, state: FSMState, token_id: int) -> FSMState:
     def copy(self) -> "CFGFSM":
         """Create a copy of the FSM."""
         return CFGFSM(self.cfg_string, self.tokenizer)
+
+
+def align_tokens_states_to_token_maps(
+    token_ids: torch.Tensor,
+    attention_masks: torch.Tensor,
+    vocabulary: Dict[str, int],
+    states_to_token_maps: Dict[int, Dict[int, int]],
+) -> Tuple[torch.Tensor, torch.Tensor, Dict[int, Dict[int, int]]]:
+    """Apply token alignment to the provided prompt tokens and attention masks given the
+    states_to_token_maps of a FSM. Return the updated tokens/maps as well as the updated
+    states_to_token_maps"""
+    prompt_token_ids = token_ids.tolist()
+    crossing_tokens = find_crossing_tokens(prompt_token_ids, vocabulary)
+    valid_crossing_tokens = get_crossing_tokens_target_states(
+        states_to_token_maps, crossing_tokens, prompt_token_ids, vocabulary
+    )
+    if not valid_crossing_tokens:
+        return token_ids, attention_masks, states_to_token_maps
+    (
+        states_to_token_maps,
+        number_cropped_tokens,
+    ) = add_crossing_tokens_states_to_tokens_map(
+        states_to_token_maps, prompt_token_ids, valid_crossing_tokens
+    )
+    return (
+        token_ids[:-number_cropped_tokens],
+        attention_masks[:-number_cropped_tokens],
+        states_to_token_maps,
+    )
+
+
+def find_crossing_tokens(
+    token_ids: List[int], vocabulary: Dict[str, int]
+) -> Dict[int, List[int]]:
+    """Find the tokens that could replace one or more tokens at the end of token_ids
+    while conserving the same intial text (and extending it by at least one character).
+    Return a dictionary with, for the indexes in the token_ids, the associated crossing tokens.
+    """
+    reversed_vocabulary = {value: key for key, value in vocabulary.items()}
+    len_token_ids = len(token_ids)
+    max_length_token_text = max(len(item) for item in vocabulary.keys())
+    characters_considered = ""
+    crossing_tokens_map = {}
+
+    for index, token_id in enumerate(reversed(token_ids)):
+        characters_considered = reversed_vocabulary[token_id] + characters_considered
+        if len(characters_considered) >= max_length_token_text:
+            break
+        crossing_token_ids = [
+            token_id
+            for text, token_id in vocabulary.items()
+            if text.startswith(characters_considered)
+            and len(text) > len(characters_considered)
+        ]
+        crossing_tokens_map[len_token_ids - index - 1] = crossing_token_ids
+
+    return crossing_tokens_map
+
+
+def get_crossing_tokens_target_states(
+    states_to_tokens_map: Dict[int, Dict[int, int]],
+    crossing_tokens: Dict[int, List[int]],
+    prompt_token_ids: List[int],
+    vocabulary: Dict[str, int],
+) -> Dict[int, Dict[int, int]]:
+    """For each crossing token associated to an index, check that the characters after the boundary
+    match the states_to_tokens_map and find the state it would lead to. Return a dict with, for each
+    provided indexes, the associated valid tokens with the state they would lead to.
+    """
+    reversed_vocabulary = {value: key for key, value in vocabulary.items()}
+    prompt_token_texts = [
+        reversed_vocabulary[token_id] for token_id in prompt_token_ids
+    ]
+
+    valid_crossing_tokens: Dict[int, Dict[int, int]] = defaultdict(dict)
+    for pos, tokens in crossing_tokens.items():
+        for token in tokens:
+            is_valid = True
+            characters = reversed_vocabulary[token]
+            characters_before_border = "".join(prompt_token_texts[pos:])
+            characters_after_border = characters[len(characters_before_border) :]
+            state = 0
+            for char in characters_after_border:
+                char_token = vocabulary.get(char)
+                try:
+                    state = states_to_tokens_map[state][char_token]  # type: ignore
+                except KeyError:
+                    is_valid = False
+                    break
+            if is_valid:
+                valid_crossing_tokens[pos][token] = state
+
+    return valid_crossing_tokens
+
+
+def add_crossing_tokens_states_to_tokens_map(
+    states_to_tokens_map: Dict[int, Dict[int, int]],
+    prompt_token_ids: List[int],
+    crossing_tokens_map: Dict[int, Dict[int, int]],
+) -> Tuple[Dict[int, Dict[int, int]], int]:
+    """Modify the states_to_tokens_map to account for the crossing tokens. This operation modifies
+    the starting state of the fsm as we would include some characters at the end of the prompt in
+    the states_to_tokens_map.
+    Attention! the starting state of the states_to_tokens_map provided must be 0.
+    Return the updated states_to_tokens_map and the number of cropped tokens/additional states
+    """
+    if not crossing_tokens_map:
+        return states_to_tokens_map, 0
+    first_crossing_token_pos = min(
+        [key for key, value in crossing_tokens_map.items() if value]
+    )
+    number_additional_states = len(prompt_token_ids) - first_crossing_token_pos
+    highest_state = max(
+        max(states_to_tokens_map.keys()),
+        max(max(items.values()) for items in states_to_tokens_map.values()),
+    )
+
+    for i in range(number_additional_states):
+        # add the tokens that was originally part of the prompt
+        if i == number_additional_states - 1:
+            states_to_tokens_map[highest_state + 1 + i] = {
+                prompt_token_ids[first_crossing_token_pos + i]: 0
+            }
+        else:
+            states_to_tokens_map[highest_state + 1 + i] = {
+                prompt_token_ids[first_crossing_token_pos + i]: highest_state + 2 + i
+            }
+        # add the crossing tokens
+        crossing_tokens = crossing_tokens_map.get(first_crossing_token_pos + i)
+        if crossing_tokens:
+            for token, target_state in crossing_tokens.items():
+                states_to_tokens_map[highest_state + 1 + i][token] = target_state
+
+    # set the id of our new initial state to 0
+    states_to_tokens_map = swap_state_ids_states_to_tokens_map(
+        states_to_tokens_map, highest_state + 1, 0
+    )
+    return states_to_tokens_map, number_additional_states
+
+
+def swap_state_ids_states_to_tokens_map(
+    states_to_tokens_map: Dict[int, Dict[int, int]],
+    first_state_id: int,
+    second_state_id: int,
+) -> Dict[int, Dict[int, int]]:
+    """Swap the id of two states of the states_to_tokens_map while conserving all transitions"""
+    first_state_transitions = states_to_tokens_map.pop(first_state_id)
+    second_state_transitions = states_to_tokens_map.pop(second_state_id)
+    states_to_tokens_map[first_state_id] = second_state_transitions
+    states_to_tokens_map[second_state_id] = first_state_transitions
+
+    for transitions in states_to_tokens_map.values():
+        for token, target_state_id in list(transitions.items()):
+            if target_state_id == first_state_id:
+                transitions[token] = second_state_id
+            elif target_state_id == second_state_id:
+                transitions[token] = first_state_id
+
+    return states_to_tokens_map
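
Note: a small worked example of the helper functions above, with an invented vocabulary and an invented states_to_token_maps (none of these values appear in the commit). The prompt "It in" is tokenized as [0, 1], and the FSM initially constrains generation to the text "sane", starting from state 0:

# Worked toy example, not part of the diff
vocabulary = {"It": 0, " in": 1, " insane": 2, "s": 3, "a": 4, "n": 5, "e": 6, "sane": 7}
states_to_token_maps = {0: {3: 1, 7: 4}, 1: {4: 2}, 2: {5: 3}, 3: {6: 4}}

find_crossing_tokens([0, 1], vocabulary)
# -> {1: [2], 0: []}: " insane" (id 2) starts with the trailing prompt text " in" and
#    extends it, so it crosses the prompt boundary at position 1.

get_crossing_tokens_target_states(states_to_token_maps, {1: [2], 0: []}, [0, 1], vocabulary)
# -> {1: {2: 4}}: the characters after the boundary, "sane", are accepted by the FSM
#    ("s" -> state 1, "a" -> 2, "n" -> 3, "e" -> 4), so " insane" is a valid crossing
#    token leading to state 4.

add_crossing_tokens_states_to_tokens_map(states_to_token_maps, [0, 1], {1: {2: 4}})
# -> returns the updated map and 1 (one prompt token cropped). With keys sorted, the
#    updated map is {0: {1: 5, 2: 4}, 1: {4: 2}, 2: {5: 3}, 3: {6: 4}, 5: {3: 1, 7: 4}}:
#    from the new start state 0 the model may either re-emit the cropped token " in"
#    (id 1) and fall back to the original start state (renumbered 5), or emit the
#    crossing token " insane" (id 2) and jump directly to state 4.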

0 commit comments