
Commit 017597a

Create align prompt tokens feature of Guide classes
1 parent af74a0c commit 017597a

3 files changed: +534 −29 lines changed

outlines/fsm/guide.py

+230-25
@@ -1,3 +1,5 @@
+from collections import defaultdict
+from copy import copy, deepcopy
 from dataclasses import dataclass
 from typing import (
     TYPE_CHECKING,
@@ -69,6 +71,9 @@ class Guide(Protocol):

     """

+    start_state: int = 0
+    final_state: int = -1
+
     def get_next_instruction(self, state: int) -> Instruction:
         ...

@@ -82,11 +87,39 @@ def copy(self) -> "Guide":
         ...


-class StopAtEOSGuide(Guide):
-    """Guide to generate tokens until the EOS token has been generated."""
+class TokenHealerMixin:
+    """Class used to add the token alignment feature to a Guide"""

-    final_state = 1
-    start_state = 0
+    states_to_token_maps: Dict[int, Dict[int, int]]
+    tokenizer: "Tokenizer"
+
+    def align_prompt_tokens(self, prompt: str) -> str:
+        """Update the states_to_token_maps and return the aligned prompt"""
+        token_ids, _ = self.tokenizer.encode(prompt)
+        (
+            aligned_token_ids,
+            aligned_states_to_token_maps,
+        ) = align_tokens_states_to_token_maps(
+            token_ids.tolist()[0],
+            self.tokenizer.vocabulary,
+            deepcopy(self.states_to_token_maps),
+        )
+        aligned_prompt = self.tokenizer.decode([aligned_token_ids])[0]
+        # some models do not accept an empty string as a prompt
+        # if token alignment would remove all tokens, do not apply it
+        if not aligned_prompt:
+            return prompt
+        self.states_to_token_maps = aligned_states_to_token_maps
+        if hasattr(self, "_cache_state_to_token_tensor"):
+            self._cache_state_to_token_tensor()
+        # remove leading whitespace if added by the tokenizer
+        if aligned_prompt[0] == " " and prompt[0] != " ":
+            return aligned_prompt[1:]
+        return aligned_prompt
+
+
+class StopAtEOSGuide(Guide, TokenHealerMixin):
+    """Guide to generate tokens until the EOS token has been generated."""

     def __init__(self, tokenizer: "Tokenizer"):
         """Initialize the generation guide.
@@ -95,25 +128,37 @@ def __init__(self, tokenizer: "Tokenizer"):
             The logit generator used to generate the next token.

         """
-        self.eos_token_id = tokenizer.eos_token_id
-        self.vocabulary = tokenizer.vocabulary.values()
+        self.tokenizer = tokenizer
+        self.states_to_token_maps = self.create_states_to_tokens_map()
+
+    def create_states_to_tokens_map(self) -> Dict[int, Dict[int, int]]:
+        """Create the states_to_tokens_map. All tokens lead to the starting
+        state, except for the eos_token that leads to the final state."""
+        return {
+            self.start_state: {
+                token_id: self.start_state
+                if token_id != self.tokenizer.eos_token_id
+                else self.final_state
+                for token_id in self.tokenizer.vocabulary.values()
+            }
+        }

     def get_next_instruction(self, state: int) -> Instruction:
         if self.is_final_state(state):
-            return Write([self.eos_token_id])
-        return Generate(None)
+            return Write([self.tokenizer.eos_token_id])
+        return Generate(list(self.states_to_token_maps[state].keys()))

     def get_next_state(self, state: int, token_id: int) -> int:
-        if token_id == self.eos_token_id or state == self.final_state:
+        if self.is_final_state(state):
             return self.final_state

-        return self.start_state
+        return self.states_to_token_maps[state][token_id]

     def is_final_state(self, state: int):
         return state == self.final_state

     def copy(self):
-        return self
+        return copy(self)


 @cache()
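
For readers skimming the diff, here is a minimal illustration (not part of the commit) of the map that create_states_to_tokens_map builds, assuming a hypothetical tokenizer whose vocabulary is {"a": 1, "b": 2, "<eos>": 3} and whose eos_token_id is 3; the state ids follow the start_state = 0 / final_state = -1 defaults added to the Guide protocol above:

# hypothetical three-token vocabulary: every token loops back to the start state,
# except the EOS token, which leads to the final state
states_to_token_maps = {
    0: {1: 0, 2: 0, 3: -1},
}

With this map, get_next_instruction(0) returns Generate([1, 2, 3]) and get_next_state(0, 3) returns the final state -1.
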
@@ -171,20 +216,20 @@ def create_states_mapping(
     return states_to_token_maps, empty_token_ids, regex_fsm.finals


-class RegexGuide(Guide):
+class RegexGuide(Guide, TokenHealerMixin):
     """Guide to generate text in the language of a regular expression."""

-    initial_state = 0
+    states_to_token_mask: Dict[int, torch.Tensor]

     def __init__(self, regex_string: str, tokenizer: "Tokenizer"):
+        self.tokenizer = tokenizer
         (
             self.states_to_token_maps,
             self.empty_token_ids,
             fsm_finals,
         ) = create_states_mapping(regex_string, tokenizer)
-        self.eos_token_id = tokenizer.eos_token_id
-        self.final_states = fsm_finals | {-1}
         self._cache_state_to_token_tensor()
+        self.final_states = fsm_finals | {self.final_state}

     def get_next_instruction(self, state: int) -> Instruction:
         """Return the next instruction for guided generation.
@@ -211,7 +256,7 @@ def get_next_instruction(self, state: int) -> Instruction:
         """
         next_tokens_mask = self.states_to_token_mask.get(state)
         if next_tokens_mask is None:
-            return Write(torch.tensor([self.eos_token_id]))
+            return Write(torch.tensor([self.tokenizer.eos_token_id]))

         return Generate(next_tokens_mask)

@@ -233,13 +278,16 @@ def get_next_state(self, state: int, token_id: int) -> int:
             The new state of the guide.

         """
-        if token_id == self.eos_token_id or state not in self.states_to_token_maps:
-            return -1
+        if (
+            token_id == self.tokenizer.eos_token_id
+            or state not in self.states_to_token_maps
+        ):
+            return self.final_state

         last_token_to_end_state = self.states_to_token_maps[state]
         next_state = last_token_to_end_state.get(token_id)
         if next_state is None:
-            next_state = -1
+            next_state = self.final_state

         return next_state

@@ -278,11 +326,11 @@ def create_states_mapping_from_interegular_fsm(
             from_interegular_instance.states_to_token_maps,
             from_interegular_instance.empty_token_ids,
         ) = create_states_mapping_from_interegular_fsm(interegular_fsm)
-        from_interegular_instance.eos_token_id = tokenizer.eos_token_id
+        from_interegular_instance.tokenizer = tokenizer
         from_interegular_instance._cache_state_to_token_tensor()
         return from_interegular_instance

-    def _cache_state_to_token_tensor(self):
+    def _cache_state_to_token_tensor(self) -> None:
         """
         cache state -> token int tensor
         this increases performance of mask construction substantially
@@ -297,7 +345,7 @@ def is_final_state(self, state: int) -> bool:
         return state in self.final_states

     def copy(self):
-        return self
+        return copy(self)


 class CFGGuide(Guide):
@@ -331,9 +379,6 @@ def __init__(self, cfg_string: str, tokenizer):
         self.proposal_last: List[int] = []
         self.regex_fsm_last: RegexGuide

-        self.start_state = 0
-        self.final_state = -1
-
     def get_next_instruction(self, state: int) -> Instruction:
         """Generate an instruction for the next step.

@@ -475,3 +520,163 @@ def is_final_state(self, state: int) -> bool:
     def copy(self) -> "CFGGuide":
         """Create a copy of the FSM."""
         return CFGGuide(self.cfg_string, self.tokenizer)
+
+
+def align_tokens_states_to_token_maps(
+    token_ids: List[int],
+    vocabulary: Dict[str, int],
+    states_to_token_maps: Dict[int, Dict[int, int]],
+) -> Tuple[List[int], Dict[int, Dict[int, int]]]:
+    """Apply token alignment to the provided prompt tokens given the
+    states_to_token_maps of a FSM. Return the updated prompt token_ids along with the
+    updated states_to_token_maps. You can find an explanation from Guidance on why
+    token healing is necessary here:
+    https://github.com/guidance-ai/guidance/blob/main/notebooks/tutorials/token_healing.ipynb
+    """
+    crossing_tokens = find_crossing_tokens(token_ids, vocabulary)
+    valid_crossing_tokens = get_crossing_tokens_target_states(
+        states_to_token_maps, crossing_tokens, token_ids, vocabulary
+    )
+    if not valid_crossing_tokens:
+        return token_ids, states_to_token_maps
+    (
+        states_to_token_maps,
+        number_cropped_tokens,
+    ) = add_crossing_tokens_states_to_tokens_map(
+        states_to_token_maps, token_ids, valid_crossing_tokens
+    )
+    return (
+        token_ids[:-number_cropped_tokens],
+        states_to_token_maps,
+    )
+
+
+def find_crossing_tokens(
+    token_ids: List[int], vocabulary: Dict[str, int]
+) -> Dict[int, List[int]]:
+    """Find the tokens that could replace one or more tokens at the end of token_ids
+    while conserving the same initial text (and extending it by at least one character).
+    Return a dictionary mapping each index of token_ids that has matches to its
+    associated crossing tokens.
+    """
+    reversed_vocabulary = {value: key for key, value in vocabulary.items()}
+    len_token_ids = len(token_ids)
+    max_length_token_text = max(len(item) for item in vocabulary.keys())
+    characters_considered = ""
+    crossing_tokens_map = {}
+
+    for index, token_id in enumerate(reversed(token_ids)):
+        characters_considered = reversed_vocabulary[token_id] + characters_considered
+        if len(characters_considered) >= max_length_token_text:
+            break
+        crossing_token_ids = [
+            token_id
+            for text, token_id in vocabulary.items()
+            if text.startswith(characters_considered)
+            and len(text) > len(characters_considered)
+        ]
+        if crossing_token_ids:
+            crossing_tokens_map[len_token_ids - index - 1] = crossing_token_ids
+
+    return crossing_tokens_map
+
+
+def get_crossing_tokens_target_states(
+    states_to_tokens_map: Dict[int, Dict[int, int]],
+    crossing_tokens: Dict[int, List[int]],
+    prompt_token_ids: List[int],
+    vocabulary: Dict[str, int],
+) -> Dict[int, Dict[int, int]]:
+    """For each crossing token associated with an index, check that the characters after
+    the boundary match the states_to_tokens_map and find the state it would lead to.
+    Return a dict with, for each provided index, the associated valid tokens and the
+    state they would lead to.
+    """
+    reversed_vocabulary = {value: key for key, value in vocabulary.items()}
+    prompt_token_texts = [
+        reversed_vocabulary[token_id] for token_id in prompt_token_ids
+    ]
+
+    valid_crossing_tokens: Dict[int, Dict[int, int]] = defaultdict(dict)
+    for pos, tokens in crossing_tokens.items():
+        for token in tokens:
+            is_valid = True
+            characters = reversed_vocabulary[token]
+            characters_before_border = "".join(prompt_token_texts[pos:])
+            characters_after_border = characters[len(characters_before_border) :]
+            state = 0
+            for char in characters_after_border:
+                char_token = vocabulary.get(char)
+                try:
+                    state = states_to_tokens_map[state][char_token]  # type: ignore
+                except KeyError:
+                    is_valid = False
+                    break
+            if is_valid:
+                valid_crossing_tokens[pos][token] = state
+
+    return valid_crossing_tokens
+
+
+def add_crossing_tokens_states_to_tokens_map(
+    states_to_tokens_map: Dict[int, Dict[int, int]],
+    prompt_token_ids: List[int],
+    crossing_tokens_map: Dict[int, Dict[int, int]],
+) -> Tuple[Dict[int, Dict[int, int]], int]:
+    """Modify the states_to_tokens_map to account for the crossing tokens. This operation
+    modifies the starting state of the FSM, as we now include some characters from the end
+    of the prompt in the states_to_tokens_map.
+    Attention! The starting state of the provided states_to_tokens_map must be 0.
+    Return the updated states_to_tokens_map and the number of cropped tokens/additional states.
+    """
+    if not crossing_tokens_map:
+        return states_to_tokens_map, 0
+    first_crossing_token_pos = min(
+        [key for key, value in crossing_tokens_map.items() if value]
+    )
+    number_additional_states = len(prompt_token_ids) - first_crossing_token_pos
+    highest_state = max(
+        max(states_to_tokens_map.keys()),
+        max(max(items.values()) for items in states_to_tokens_map.values()),
+    )
+
+    for i in range(number_additional_states):
+        # add the token that was originally part of the prompt
+        if i == number_additional_states - 1:
+            states_to_tokens_map[highest_state + 1 + i] = {
+                prompt_token_ids[first_crossing_token_pos + i]: 0
+            }
+        else:
+            states_to_tokens_map[highest_state + 1 + i] = {
+                prompt_token_ids[first_crossing_token_pos + i]: highest_state + 2 + i
+            }
+        # add the crossing tokens
+        crossing_tokens = crossing_tokens_map.get(first_crossing_token_pos + i)
+        if crossing_tokens:
+            for token, target_state in crossing_tokens.items():
+                states_to_tokens_map[highest_state + 1 + i][token] = target_state
+
+    # set the id of our new initial state to 0
+    states_to_tokens_map = swap_state_ids_states_to_tokens_map(
+        states_to_tokens_map, highest_state + 1, 0
+    )
+    return states_to_tokens_map, number_additional_states
+
+
+def swap_state_ids_states_to_tokens_map(
+    states_to_tokens_map: Dict[int, Dict[int, int]],
+    first_state_id: int,
+    second_state_id: int,
+) -> Dict[int, Dict[int, int]]:
+    """Swap the ids of two states of the states_to_tokens_map while conserving all transitions."""
+    first_state_transitions = states_to_tokens_map.pop(first_state_id)
+    second_state_transitions = states_to_tokens_map.pop(second_state_id)
+    states_to_tokens_map[first_state_id] = second_state_transitions
+    states_to_tokens_map[second_state_id] = first_state_transitions
+
+    for transitions in states_to_tokens_map.values():
+        for token, target_state_id in list(transitions.items()):
+            if target_state_id == first_state_id:
+                transitions[token] = second_state_id
+            elif target_state_id == second_state_id:
+                transitions[token] = first_state_id
+
+    return states_to_tokens_map
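
A short worked example may help make the crossing-token machinery concrete. Everything below is illustrative only: the vocabulary, token ids, and states_to_token_maps are hypothetical (a hand-built map standing in for a guide compiled from the regex "lo"), and the import assumes these helpers are exposed from outlines/fsm/guide.py as added above.

# Hypothetical vocabulary; token 2 (" hello") "crosses" the prompt boundary because
# it starts with the trailing prompt text " hel" and extends it.
from outlines.fsm.guide import align_tokens_states_to_token_maps  # assumed import path

vocabulary = {"My": 8, " hel": 1, " hello": 2, "l": 5, "o": 6}
states_to_token_maps = {0: {5: 1}, 1: {6: 2}}  # 0 --"l"--> 1 --"o"--> 2
prompt_token_ids = [8, 1]  # "My" + " hel"

aligned_ids, aligned_maps = align_tokens_states_to_token_maps(
    prompt_token_ids, vocabulary, states_to_token_maps
)
# aligned_ids == [8]: the trailing " hel" token is cropped from the prompt.
# aligned_maps == {0: {1: 3, 2: 2}, 1: {6: 2}, 3: {5: 1}}: the new start state 0
# accepts either the cropped prompt token (1, " hel"), which leads to the old start
# state (now relabelled 3), or the crossing token (2, " hello"), which jumps straight
# to the state reached after generating "lo".
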

tests/fsm/test_fsm.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class MockTokenizer:
1818
with pytest.warns(UserWarning):
1919
fsm = StopAtEosFSM(MockTokenizer())
2020

21-
assert fsm.allowed_token_ids(fsm.start_state) is None
21+
assert fsm.allowed_token_ids(fsm.start_state) == [1, 2]
2222
assert fsm.allowed_token_ids(fsm.final_state) == [2]
2323
assert fsm.next_state(fsm.start_state, 2) == fsm.final_state
2424
assert fsm.next_state(fsm.start_state, 1) == fsm.start_state
