
Commit 29853ec

RobinPicard authored and rlouf committed
Align prompt and generation
1 parent 11143df commit 29853ec

File tree

4 files changed: +582 -23 lines changed


outlines/fsm/guide.py: +224 -9
@@ -1,7 +1,10 @@
+from collections import defaultdict
+from copy import deepcopy
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, List, Protocol, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Protocol, Tuple, Union
 
 import interegular
+import torch
 from lark import Lark
 
 from outlines import grammars
@@ -62,11 +65,16 @@ def get_next_state(self, state: int, token_id: int) -> int:
     def is_final_state(self, state: int) -> bool:
         ...
 
+    def align_prompt_tokens(
+        self, token_ids: torch.Tensor, attention_masks: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        ...
+
 
 class StopAtEOSGuide(Guide):
     """Guide to generate tokens until the EOS token has been generated."""
 
-    final_state = 1
+    final_state = -1
     start_state = 0
 
     def __init__(self, tokenizer: "Tokenizer"):
@@ -77,24 +85,52 @@ def __init__(self, tokenizer: "Tokenizer"):
 
         """
         self.eos_token_id = tokenizer.eos_token_id
-        self.vocabulary = tokenizer.vocabulary.values()
+        self.vocabulary = tokenizer.vocabulary
+        self.tokenizer = tokenizer
+        self.states_to_token_maps = self.create_states_to_tokens_map()
+
+    def create_states_to_tokens_map(self) -> Dict[int, Dict[int, int]]:
+        """Create the states_to_tokens_map. All tokens from the starting state lead
+        back to it, except for the eos_token, which leads to the final state."""
+        return {
+            self.start_state: {
+                token_id: self.start_state
+                if token_id != self.eos_token_id
+                else self.final_state
+                for token_id in self.vocabulary.values()
+            }
+        }
+
+    def align_prompt_tokens(
+        self, token_ids: torch.Tensor, attention_masks: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Update the states_to_token_maps and return the aligned prompt tokens and attention masks"""
+        (
+            token_ids,
+            attention_masks,
+            self.states_to_token_maps,
+        ) = align_tokens_states_to_token_maps(
+            token_ids, attention_masks, self.vocabulary, self.states_to_token_maps
+        )
+        return token_ids, attention_masks
 
     def get_next_instruction(self, state: int) -> Instruction:
         if self.is_final_state(state):
             return Write([self.eos_token_id])
-        return Generate(list(self.vocabulary))
+
+        return Generate(list(self.states_to_token_maps[state].keys()))
 
     def get_next_state(self, state: int, token_id: int) -> int:
-        if token_id == self.eos_token_id or state == self.final_state:
+        if self.is_final_state(state):
             return self.final_state
 
-        return self.start_state
+        return self.states_to_token_maps[state][token_id]
 
     def is_final_state(self, state: int):
         return state == self.final_state
 
     def copy(self):
-        return self
+        return deepcopy(self)
 
 
 class RegexGuide(Guide):
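
For illustration (not part of the commit): with a minimal stand-in tokenizer that exposes only the two attributes the new __init__ reads, create_states_to_tokens_map builds a map in which every token loops on the start state and only the EOS token reaches the new final_state = -1. The toy vocabulary below is an assumption, and the import assumes this revision of outlines is installed.

# Illustration only: toy_tokenizer is a stand-in, not a real outlines Tokenizer;
# it carries just the attributes StopAtEOSGuide.__init__ reads in this diff.
from types import SimpleNamespace

from outlines.fsm.guide import StopAtEOSGuide

toy_tokenizer = SimpleNamespace(vocabulary={"a": 0, "b": 1, "<eos>": 2}, eos_token_id=2)
guide = StopAtEOSGuide(toy_tokenizer)

print(guide.states_to_token_maps)
# {0: {0: 0, 1: 0, 2: -1}} -> every token loops on state 0, <eos> leads to final_state (-1)
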
@@ -136,10 +172,23 @@ def create_states_mapping(
         ) = create_states_mapping(
             regex_string, tuple(sorted(tokenizer.vocabulary.items()))
         )
-        self.vocabulary = list(tokenizer.vocabulary.values())
+        self.vocabulary = tokenizer.vocabulary
         self.eos_token_id = tokenizer.eos_token_id
         self.final_states = fsm_finals | {-1}
 
+    def align_prompt_tokens(
+        self, token_ids: torch.Tensor, attention_masks: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Update the states_to_token_maps and return the aligned prompt tokens and attention masks"""
+        (
+            token_ids,
+            attention_masks,
+            self.states_to_token_maps,
+        ) = align_tokens_states_to_token_maps(
+            token_ids, attention_masks, self.vocabulary, self.states_to_token_maps
+        )
+        return token_ids, attention_masks
+
     def get_next_instruction(self, state: int) -> Instruction:
         """Return the next instruction for guided generation.
 
@@ -244,7 +293,7 @@ def is_final_state(self, state: int) -> bool:
         return state in self.final_states
 
     def copy(self):
-        return self
+        return deepcopy(self)
 
 
 class CFGGuide(Guide):
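
Note that both StopAtEOSGuide.copy and RegexGuide.copy now return deepcopy(self) instead of self, presumably because align_prompt_tokens mutates states_to_token_maps per prompt, so a guide instance can no longer be shared between sequences. A minimal sketch of the new behaviour, again with a stand-in tokenizer (an assumption, not the commit's own test code):

# Sketch: each copy now owns its own states_to_token_maps, so per-prompt
# alignment cannot leak across sequences. The stand-in tokenizer is an assumption.
from types import SimpleNamespace

from outlines.fsm.guide import StopAtEOSGuide

guide = StopAtEOSGuide(SimpleNamespace(vocabulary={"a": 0, "<eos>": 1}, eos_token_id=1))
clone = guide.copy()

clone.states_to_token_maps[0].pop(0)       # mutate the clone only
print(0 in guide.states_to_token_maps[0])  # True: the original guide is unchanged
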
@@ -281,6 +330,12 @@ def __init__(self, cfg_string: str, tokenizer):
         self.start_state = 0
         self.final_state = -1
 
+    def align_prompt_tokens(
+        self, token_ids: torch.Tensor, attention_masks: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Not applicable to this type of FSM"""
+        return token_ids, attention_masks
+
     def get_next_instruction(self, state: int) -> Instruction:
         """Generate an instruction for the next step.
 
@@ -416,3 +471,163 @@ def is_final_state(self, state: int) -> bool:
     def copy(self) -> "CFGGuide":
         """Create a copy of the FSM."""
         return CFGGuide(self.cfg_string, self.tokenizer)
+
+
+def align_tokens_states_to_token_maps(
+    token_ids: torch.Tensor,
+    attention_masks: torch.Tensor,
+    vocabulary: Dict[str, int],
+    states_to_token_maps: Dict[int, Dict[int, int]],
+) -> Tuple[torch.Tensor, torch.Tensor, Dict[int, Dict[int, int]]]:
+    """Apply token alignment to the provided prompt tokens and attention masks given the
+    states_to_token_maps of an FSM. Return the updated tokens/masks as well as the updated
+    states_to_token_maps"""
+    prompt_token_ids = token_ids.tolist()
+    crossing_tokens = find_crossing_tokens(prompt_token_ids, vocabulary)
+    valid_crossing_tokens = get_crossing_tokens_target_states(
+        states_to_token_maps, crossing_tokens, prompt_token_ids, vocabulary
+    )
+    if not valid_crossing_tokens:
+        return token_ids, attention_masks, states_to_token_maps
+    (
+        states_to_token_maps,
+        number_cropped_tokens,
+    ) = add_crossing_tokens_states_to_tokens_map(
+        states_to_token_maps, prompt_token_ids, valid_crossing_tokens
+    )
+    return (
+        token_ids[:-number_cropped_tokens],
+        attention_masks[:-number_cropped_tokens],
+        states_to_token_maps,
+    )
+
+
+def find_crossing_tokens(
+    token_ids: List[int], vocabulary: Dict[str, int]
+) -> Dict[int, List[int]]:
+    """Find the tokens that could replace one or more tokens at the end of token_ids
+    while conserving the same initial text (and extending it by at least one character).
+    Return a dictionary with, for the indexes in the token_ids with matches, the associated crossing tokens.
+    """
+    reversed_vocabulary = {value: key for key, value in vocabulary.items()}
+    len_token_ids = len(token_ids)
+    max_length_token_text = max(len(item) for item in vocabulary.keys())
+    characters_considered = ""
+    crossing_tokens_map = {}
+
+    for index, token_id in enumerate(reversed(token_ids)):
+        characters_considered = reversed_vocabulary[token_id] + characters_considered
+        if len(characters_considered) >= max_length_token_text:
+            break
+        crossing_token_ids = [
+            token_id
+            for text, token_id in vocabulary.items()
+            if text.startswith(characters_considered)
+            and len(text) > len(characters_considered)
+        ]
+        if crossing_token_ids:
+            crossing_tokens_map[len_token_ids - index - 1] = crossing_token_ids
+
+    return crossing_tokens_map
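
A worked toy example of find_crossing_tokens (vocabulary and prompt invented for illustration; the import assumes this revision is installed): the prompt "hel" is tokenized as ["he", "l"], and two vocabulary entries cross the prompt/generation boundary while preserving the prompt text.

# Toy example: which vocabulary entries could re-tokenize the end of the prompt
# and extend it past the boundary? All values here are made up.
from outlines.fsm.guide import find_crossing_tokens

vocabulary = {"he": 0, "l": 1, "o": 2, "!": 3, "llo": 4, "hello": 5, "<eos>": 6}
prompt_token_ids = [0, 1]  # "he" + "l" -> "hel"

print(find_crossing_tokens(prompt_token_ids, vocabulary))
# {1: [4], 0: [5]}
#   position 1: "llo" starts with the trailing "l" and extends it
#   position 0: "hello" starts with "he" + "l" and extends it
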
+
+
+def get_crossing_tokens_target_states(
+    states_to_tokens_map: Dict[int, Dict[int, int]],
+    crossing_tokens: Dict[int, List[int]],
+    prompt_token_ids: List[int],
+    vocabulary: Dict[str, int],
+) -> Dict[int, Dict[int, int]]:
+    """For each crossing token associated with an index, check that the characters after the boundary
+    match the states_to_tokens_map and find the state it would lead to. Return a dict with, for each
+    provided index, the associated valid tokens with the state they would lead to.
+    """
+    reversed_vocabulary = {value: key for key, value in vocabulary.items()}
+    prompt_token_texts = [
+        reversed_vocabulary[token_id] for token_id in prompt_token_ids
+    ]
+
+    valid_crossing_tokens: Dict[int, Dict[int, int]] = defaultdict(dict)
+    for pos, tokens in crossing_tokens.items():
+        for token in tokens:
+            is_valid = True
+            characters = reversed_vocabulary[token]
+            characters_before_border = "".join(prompt_token_texts[pos:])
+            characters_after_border = characters[len(characters_before_border) :]
+            state = 0
+            for char in characters_after_border:
+                char_token = vocabulary.get(char)
+                try:
+                    state = states_to_tokens_map[state][char_token]  # type: ignore
+                except KeyError:
+                    is_valid = False
+                    break
+            if is_valid:
+                valid_crossing_tokens[pos][token] = state
+
+    return valid_crossing_tokens
+
+
+def add_crossing_tokens_states_to_tokens_map(
+    states_to_tokens_map: Dict[int, Dict[int, int]],
+    prompt_token_ids: List[int],
+    crossing_tokens_map: Dict[int, Dict[int, int]],
+) -> Tuple[Dict[int, Dict[int, int]], int]:
+    """Modify the states_to_tokens_map to account for the crossing tokens. This operation modifies
+    the starting state of the fsm as we would include some characters at the end of the prompt in
+    the states_to_tokens_map.
+    Attention! The starting state of the states_to_tokens_map provided must be 0.
+    Return the updated states_to_tokens_map and the number of cropped tokens/additional states
+    """
+    if not crossing_tokens_map:
+        return states_to_tokens_map, 0
+    first_crossing_token_pos = min(
+        [key for key, value in crossing_tokens_map.items() if value]
+    )
+    number_additional_states = len(prompt_token_ids) - first_crossing_token_pos
+    highest_state = max(
+        max(states_to_tokens_map.keys()),
+        max(max(items.values()) for items in states_to_tokens_map.values()),
+    )
+
+    for i in range(number_additional_states):
+        # add the token that was originally part of the prompt
+        if i == number_additional_states - 1:
+            states_to_tokens_map[highest_state + 1 + i] = {
+                prompt_token_ids[first_crossing_token_pos + i]: 0
+            }
+        else:
+            states_to_tokens_map[highest_state + 1 + i] = {
+                prompt_token_ids[first_crossing_token_pos + i]: highest_state + 2 + i
+            }
+        # add the crossing tokens
+        crossing_tokens = crossing_tokens_map.get(first_crossing_token_pos + i)
+        if crossing_tokens:
+            for token, target_state in crossing_tokens.items():
+                states_to_tokens_map[highest_state + 1 + i][token] = target_state
+
+    # set the id of our new initial state to 0
+    states_to_tokens_map = swap_state_ids_states_to_tokens_map(
+        states_to_tokens_map, highest_state + 1, 0
+    )
+    return states_to_tokens_map, number_additional_states
+
+
+def swap_state_ids_states_to_tokens_map(
+    states_to_tokens_map: Dict[int, Dict[int, int]],
+    first_state_id: int,
+    second_state_id: int,
+) -> Dict[int, Dict[int, int]]:
+    """Swap the id of two states of the states_to_tokens_map while conserving all transitions"""
+    first_state_transitions = states_to_tokens_map.pop(first_state_id)
+    second_state_transitions = states_to_tokens_map.pop(second_state_id)
+    states_to_tokens_map[first_state_id] = second_state_transitions
+    states_to_tokens_map[second_state_id] = first_state_transitions
+
+    for transitions in states_to_tokens_map.values():
+        for token, target_state_id in list(transitions.items()):
+            if target_state_id == first_state_id:
+                transitions[token] = second_state_id
+            elif target_state_id == second_state_id:
+                transitions[token] = first_state_id
+
+    return states_to_tokens_map
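
Putting the helpers together, an end-to-end sketch of align_tokens_states_to_token_maps on the same kind of invented vocabulary: a character-level FSM only accepts the continuation "lo!" after the prompt "Say hel". The last two prompt tokens are cropped so the model can re-generate them, and the returned map gains new leading states that accept either the original prompt tokens or the crossing tokens "hello" and "llo".

# End-to-end toy run; every value below is invented for illustration, only the
# imported function comes from this file (assuming this revision is installed).
import torch

from outlines.fsm.guide import align_tokens_states_to_token_maps

vocabulary = {"Say ": 7, "he": 0, "l": 1, "o": 2, "!": 3, "llo": 4, "hello": 5, "<eos>": 6}
token_ids = torch.tensor([7, 0, 1])  # "Say " + "he" + "l" -> "Say hel"
attention_masks = torch.ones_like(token_ids)

# Character-level FSM that only accepts the continuation "lo!" (start state 0).
states_to_token_maps = {0: {1: 1}, 1: {2: 2}, 2: {3: 3}}

aligned_ids, aligned_masks, aligned_map = align_tokens_states_to_token_maps(
    token_ids, attention_masks, vocabulary, states_to_token_maps
)
print(aligned_ids)     # tensor([7]): "he" and "l" are cropped and will be re-generated
print(aligned_map[0])  # {0: 5, 5: 2}: re-emit "he" or jump ahead with "hello"
print(aligned_map[5])  # {1: 4, 4: 2}: then re-emit "l" (into the original FSM) or use "llo"
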
