
Commit b74dc8e

Author: Robin Picard
Implement prompt token alignment
1 parent b75beeb commit b74dc8e

File tree: 6 files changed, +623 -24 lines


outlines/fsm/guide.py (+214 -8)
@@ -1,7 +1,10 @@
+from collections import defaultdict
+from copy import deepcopy
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, List, Optional, Protocol, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Protocol, Tuple, Union
 
 import interegular
+import torch
 from lark import Lark
 
 from outlines import grammars
@@ -67,14 +70,17 @@ def get_next_state(self, state: int, token_id: int) -> int:
     def is_final_state(self, state: int) -> bool:
         ...
 
+    def align_prompt_tokens(self, prompt: str) -> str:
+        ...
+
     def copy(self) -> "Guide":
         ...
 
 
 class StopAtEOSGuide(Guide):
     """Guide to generate tokens until the EOS token has been generated."""
 
-    final_state = 1
+    final_state = -1
     start_state = 0
 
     def __init__(self, tokenizer: "Tokenizer"):
@@ -85,24 +91,49 @@ def __init__(self, tokenizer: "Tokenizer"):
 
         """
         self.eos_token_id = tokenizer.eos_token_id
-        self.vocabulary = tokenizer.vocabulary.values()
+        self.tokenizer = tokenizer
+        self.states_to_token_maps = self.create_states_to_tokens_map()
+
+    def create_states_to_tokens_map(self) -> Dict[int, Dict[int, int]]:
+        """Create the states_to_tokens_map. All tokens lead to the starting
+        state, except for the eos_token that leads to the final state."""
+        return {
+            self.start_state: {
+                token_id: self.start_state
+                if token_id != self.eos_token_id
+                else self.final_state
+                for token_id in self.tokenizer.vocabulary.values()
+            }
+        }
+
+    def align_prompt_tokens(self, prompt: str) -> str:
+        """Update the states_to_token_maps and return the aligned prompt"""
+        token_ids, _ = self.tokenizer.encode(prompt)
+        (
+            aligned_token_ids,
+            self.states_to_token_maps,
+        ) = align_tokens_states_to_token_maps(
+            token_ids[0], self.tokenizer.vocabulary, self.states_to_token_maps
+        )
+        decoded_aligned_token_ids = self.tokenizer.decode(aligned_token_ids)
+        return "".join(decoded_aligned_token_ids)
 
     def get_next_instruction(self, state: int) -> Instruction:
         if self.is_final_state(state):
             return Write([self.eos_token_id])
-        return Generate(None)
+        return Generate(list(self.states_to_token_maps[state].keys()))
 
     def get_next_state(self, state: int, token_id: int) -> int:
-        if token_id == self.eos_token_id or state == self.final_state:
+        if self.is_final_state(state):
             return self.final_state
 
-        return self.start_state
+        return self.states_to_token_maps[state][token_id]
 
     def is_final_state(self, state: int):
         return state == self.final_state
 
     def copy(self):
-        return self
+        return deepcopy(self)
 
 
 @cache()
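
For reference, the shape of the map built by create_states_to_tokens_map can be pictured on a toy vocabulary. The snippet below is an illustrative sketch only, with a made-up vocabulary and token ids; it is not part of the commit.

# Illustrative sketch (hypothetical three-token vocabulary): every token loops
# on the start state (0), except EOS, which moves to the final state (-1).
toy_vocabulary = {"a": 0, "b": 1, "<eos>": 2}  # token text -> token id
eos_token_id = 2
start_state, final_state = 0, -1

states_to_token_maps = {
    start_state: {
        token_id: start_state if token_id != eos_token_id else final_state
        for token_id in toy_vocabulary.values()
    }
}
assert states_to_token_maps == {0: {0: 0, 1: 0, 2: -1}}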
@@ -143,9 +174,22 @@ def __init__(self, regex_string: str, tokenizer):
             self.empty_token_ids,
             fsm_finals,
         ) = create_states_mapping(regex_string, tokenizer)
+        self.tokenizer = tokenizer
         self.eos_token_id = tokenizer.eos_token_id
         self.final_states = fsm_finals | {-1}
 
+    def align_prompt_tokens(self, prompt: str) -> str:
+        """Update the states_to_token_maps and return the aligned prompt"""
+        token_ids, _ = self.tokenizer.encode(prompt)
+        (
+            aligned_token_ids,
+            self.states_to_token_maps,
+        ) = align_tokens_states_to_token_maps(
+            token_ids[0], self.tokenizer.vocabulary, self.states_to_token_maps
+        )
+        decoded_aligned_token_ids = self.tokenizer.decode(aligned_token_ids)
+        return "".join(decoded_aligned_token_ids)
+
     def get_next_instruction(self, state: int) -> Instruction:
         """Return the next instruction for guided generation.
 
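
The caller side of align_prompt_tokens is not visible in this file's diff. The sketch below is a hypothetical usage example: it assumes a transformers model loaded through outlines and that the returned prompt is fed to the model in place of the original one.

# Hypothetical usage sketch (not part of this diff).
from outlines import models
from outlines.fsm.guide import RegexGuide

model = models.transformers("gpt2")  # assumed model/tokenizer pair
guide = RegexGuide(r"[0-9]{1,3}%", model.tokenizer)

# align_prompt_tokens may strip trailing characters from the prompt and graft
# them onto the FSM's start state, so a token such as " 4" that crosses the
# prompt/generation boundary stays reachable.
aligned_prompt = guide.align_prompt_tokens("The discount is ")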
@@ -246,7 +290,7 @@ def is_final_state(self, state: int) -> bool:
         return state in self.final_states
 
     def copy(self):
-        return self
+        return deepcopy(self)
 
 
 class CFGGuide(Guide):
@@ -283,6 +327,10 @@ def __init__(self, cfg_string: str, tokenizer):
         self.start_state = 0
         self.final_state = -1
 
+    def align_prompt_tokens(self, prompt: str) -> str:
+        """Not applicable to this type of Guide"""
+        return prompt
+
     def get_next_instruction(self, state: int) -> Instruction:
         """Generate an instruction for the next step.
 
@@ -424,3 +472,161 @@ def is_final_state(self, state: int) -> bool:
     def copy(self) -> "CFGGuide":
         """Create a copy of the FSM."""
         return CFGGuide(self.cfg_string, self.tokenizer)
+
+
+def align_tokens_states_to_token_maps(
+    token_ids: torch.Tensor,
+    vocabulary: Dict[str, int],
+    states_to_token_maps: Dict[int, Dict[int, int]],
+) -> Tuple[torch.Tensor, Dict[int, Dict[int, int]]]:
+    """Apply token alignment to the provided prompt token ids given the
+    states_to_token_maps of an FSM. Return the updated token ids as well as the
+    updated states_to_token_maps"""
+    prompt_token_ids = token_ids.tolist()
+    crossing_tokens = find_crossing_tokens(prompt_token_ids, vocabulary)
+    valid_crossing_tokens = get_crossing_tokens_target_states(
+        states_to_token_maps, crossing_tokens, prompt_token_ids, vocabulary
+    )
+    if not valid_crossing_tokens:
+        return token_ids, states_to_token_maps
+    (
+        states_to_token_maps,
+        number_cropped_tokens,
+    ) = add_crossing_tokens_states_to_tokens_map(
+        states_to_token_maps, prompt_token_ids, valid_crossing_tokens
+    )
+    return (
+        token_ids[:-number_cropped_tokens],
+        states_to_token_maps,
+    )
+
+
+def find_crossing_tokens(
+    token_ids: List[int], vocabulary: Dict[str, int]
+) -> Dict[int, List[int]]:
+    """Find the tokens that could replace one or more tokens at the end of token_ids
+    while conserving the same initial text (and extending it by at least one character).
+    Return a dictionary with, for the indexes in token_ids that have matches, the associated crossing tokens.
+    """
+    reversed_vocabulary = {value: key for key, value in vocabulary.items()}
+    len_token_ids = len(token_ids)
+    max_length_token_text = max(len(item) for item in vocabulary.keys())
+    characters_considered = ""
+    crossing_tokens_map = {}
+
+    for index, token_id in enumerate(reversed(token_ids)):
+        characters_considered = reversed_vocabulary[token_id] + characters_considered
+        if len(characters_considered) >= max_length_token_text:
+            break
+        crossing_token_ids = [
+            token_id
+            for text, token_id in vocabulary.items()
+            if text.startswith(characters_considered)
+            and len(text) > len(characters_considered)
+        ]
+        if crossing_token_ids:
+            crossing_tokens_map[len_token_ids - index - 1] = crossing_token_ids
+
+    return crossing_tokens_map
+
+
+def get_crossing_tokens_target_states(
+    states_to_tokens_map: Dict[int, Dict[int, int]],
+    crossing_tokens: Dict[int, List[int]],
+    prompt_token_ids: List[int],
+    vocabulary: Dict[str, int],
+) -> Dict[int, Dict[int, int]]:
+    """For each crossing token associated to an index, check that the characters after the boundary
+    match the states_to_tokens_map and find the state it would lead to. Return a dict with, for each
+    provided index, the associated valid tokens with the state they would lead to.
+    """
+    reversed_vocabulary = {value: key for key, value in vocabulary.items()}
+    prompt_token_texts = [
+        reversed_vocabulary[token_id] for token_id in prompt_token_ids
+    ]
+
+    valid_crossing_tokens: Dict[int, Dict[int, int]] = defaultdict(dict)
+    for pos, tokens in crossing_tokens.items():
+        for token in tokens:
+            is_valid = True
+            characters = reversed_vocabulary[token]
+            characters_before_border = "".join(prompt_token_texts[pos:])
+            characters_after_border = characters[len(characters_before_border) :]
+            state = 0
+            for char in characters_after_border:
+                char_token = vocabulary.get(char)
+                try:
+                    state = states_to_tokens_map[state][char_token]  # type: ignore
+                except KeyError:
+                    is_valid = False
+                    break
+            if is_valid:
+                valid_crossing_tokens[pos][token] = state
+
+    return valid_crossing_tokens
+
+
+def add_crossing_tokens_states_to_tokens_map(
+    states_to_tokens_map: Dict[int, Dict[int, int]],
+    prompt_token_ids: List[int],
+    crossing_tokens_map: Dict[int, Dict[int, int]],
+) -> Tuple[Dict[int, Dict[int, int]], int]:
+    """Modify the states_to_tokens_map to account for the crossing tokens. This operation modifies
+    the starting state of the fsm as we would include some characters at the end of the prompt in
+    the states_to_tokens_map.
+    Attention! The starting state of the states_to_tokens_map provided must be 0.
+    Return the updated states_to_tokens_map and the number of cropped tokens/additional states
+    """
+    if not crossing_tokens_map:
+        return states_to_tokens_map, 0
+    first_crossing_token_pos = min(
+        [key for key, value in crossing_tokens_map.items() if value]
+    )
+    number_additional_states = len(prompt_token_ids) - first_crossing_token_pos
+    highest_state = max(
+        max(states_to_tokens_map.keys()),
+        max(max(items.values()) for items in states_to_tokens_map.values()),
+    )
+
+    for i in range(number_additional_states):
+        # add the token that was originally part of the prompt
+        if i == number_additional_states - 1:
+            states_to_tokens_map[highest_state + 1 + i] = {
+                prompt_token_ids[first_crossing_token_pos + i]: 0
+            }
+        else:
+            states_to_tokens_map[highest_state + 1 + i] = {
+                prompt_token_ids[first_crossing_token_pos + i]: highest_state + 2 + i
+            }
+        # add the crossing tokens
+        crossing_tokens = crossing_tokens_map.get(first_crossing_token_pos + i)
+        if crossing_tokens:
+            for token, target_state in crossing_tokens.items():
+                states_to_tokens_map[highest_state + 1 + i][token] = target_state
+
+    # set the id of our new initial state to 0
+    states_to_tokens_map = swap_state_ids_states_to_tokens_map(
+        states_to_tokens_map, highest_state + 1, 0
+    )
+    return states_to_tokens_map, number_additional_states
+
+
+def swap_state_ids_states_to_tokens_map(
+    states_to_tokens_map: Dict[int, Dict[int, int]],
+    first_state_id: int,
+    second_state_id: int,
+) -> Dict[int, Dict[int, int]]:
+    """Swap the id of two states of the states_to_tokens_map while conserving all transitions"""
+    first_state_transitions = states_to_tokens_map.pop(first_state_id)
+    second_state_transitions = states_to_tokens_map.pop(second_state_id)
+    states_to_tokens_map[first_state_id] = second_state_transitions
+    states_to_tokens_map[second_state_id] = first_state_transitions
+
+    for transitions in states_to_tokens_map.values():
+        for token, target_state_id in list(transitions.items()):
+            if target_state_id == first_state_id:
+                transitions[token] = second_state_id
+            elif target_state_id == second_state_id:
+                transitions[token] = first_state_id
+
+    return states_to_tokens_map
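
Taken together, these helpers implement the alignment: find the tokens that straddle the prompt/generation boundary, keep the ones compatible with the FSM, graft the cropped prompt characters onto a new start state, and renumber so the start state stays 0. The toy walkthrough below is a self-contained illustration with a made-up vocabulary, token ids and FSM; it does not call the functions above, and the final map shown was worked out by hand.

# Toy illustration of a "crossing token". The prompt ends with " " and the
# FSM (for a digit-only pattern) expects a digit next.
toy_vocabulary = {"The": 0, "answer": 1, "is": 2, " ": 3, " 4": 4, "4": 5, "2": 6}
prompt_text_tokens = ["The", "answer", "is", " "]

# find_crossing_tokens-style scan over the prompt tail: tokens whose text
# starts with the tail and extends past it by at least one character.
tail = prompt_text_tokens[-1]
crossing = [t for t in toy_vocabulary if t.startswith(tail) and len(t) > len(tail)]
assert crossing == [" 4"]

# Hand-worked result of grafting " 4" onto a digit-only map that was
# {0: {5: 1, 6: 1}, 1: {5: 1, 6: 1}} (token ids 5 and 6 are the digits).
# The last prompt token " " (id 3) is cropped from the prompt; the new start
# state 0 accepts either that token (leading to the old start state, renumbered
# to 2) or the crossing token " 4" (id 4), which jumps straight to state 1 as
# if a digit had already been read.
aligned_map = {
    0: {3: 2, 4: 1},   # grafted start state built from the cropped prompt token
    1: {5: 1, 6: 1},   # unchanged digit loop
    2: {5: 1, 6: 1},   # the old start state after the id swap
}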

0 commit comments