Commit 01bfc21

Implement token alignment for RegexFSM and StopAtFSM
1 parent a04e8d4 commit 01bfc21

2 files changed: +89 −44 lines changed

outlines/fsm/fsm.py (+76 −39)
@@ -1,7 +1,6 @@
 from copy import deepcopy
-from typing import TYPE_CHECKING, List, NewType, Protocol
+from typing import TYPE_CHECKING, Dict, List, NewType, Protocol, Tuple
 
-import cloudpickle
 import interegular
 from lark import Lark
 
@@ -15,6 +14,9 @@
 
 
 class FSM(Protocol):
+    def align_prompt_tokens(self, prompt: str) -> str:
+        ...
+
     def allowed_token_ids(self, state: FSMState) -> List[int]:
         ...
 
@@ -39,8 +41,23 @@ class StopAtTokenFSM(FSM):
 
     def __init__(self, tokenizer: "Tokenizer", stop_token_id: int):
         self.stop_token_id = stop_token_id
-        self.vocabulary = tokenizer.vocabulary.values()
-        self.final_states = {1}
+        self.tokenizer = tokenizer
+        self.vocabulary = tokenizer.vocabulary
+        self.final_states = {2}
+        self.valid_alignment_tokens: List[int] = []
+
+    def align_prompt_tokens(self, prompt: str) -> str:
+        """Remove the last token from the prompt and set the value of self.valid_alignment_tokens"""
+        token_ids, _ = self.tokenizer.encode(prompt)
+        last_token_id = int(token_ids[0][-1])
+        last_token_text = self.tokenizer.decode([last_token_id])[0]
+        # select the tokens that start with the text removed from the prompt
+        self.valid_alignment_tokens = [
+            token
+            for text, token in self.vocabulary.items()
+            if text.startswith(last_token_text)
+        ]
+        return prompt[: -len(last_token_text)]
 
     def allowed_token_ids(self, state: FSMState) -> List[int]:
         """Generate a list of allowed tokens for the next step.
@@ -59,7 +76,9 @@ def allowed_token_ids(self, state: FSMState) -> List[int]:
 
         """
         if state == 0:
-            return list(self.vocabulary)
+            return self.valid_alignment_tokens
+        elif state == 1:
+            return list(self.vocabulary.values())
         else:
             return [self.stop_token_id]
 
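With this change StopAtTokenFSM goes from two states to three: state 0 only allows the alignment tokens computed from the stripped prompt tail, state 1 allows the full vocabulary, and state 2 (reached on the stop token) is final. A rough sketch of the intended behaviour using a hypothetical toy tokenizer — the ToyTokenizer class, its token ids and texts are illustrative only, not part of the commit (the real Tokenizer returns tensors from encode):

from outlines.fsm.fsm import StopAtTokenFSM


class ToyTokenizer:
    """Hypothetical stand-in for outlines' Tokenizer, for illustration only."""

    vocabulary = {"Hel": 0, "Hello": 1, "Help": 2, "lo": 3, "<eos>": 4}

    def encode(self, text):
        # Pretend the prompt "Say Hel" tokenizes to [..., 0], id 0 being "Hel".
        return [[99, 0]], None

    def decode(self, token_ids):
        inverse = {v: k for k, v in self.vocabulary.items()}
        return [inverse[token_ids[0]]]


fsm = StopAtTokenFSM(ToyTokenizer(), stop_token_id=4)
aligned = fsm.align_prompt_tokens("Say Hel")

assert aligned == "Say "                            # "Hel" stripped from the prompt
assert fsm.valid_alignment_tokens == [0, 1, 2]      # tokens whose text starts with "Hel"
assert fsm.allowed_token_ids(0) == [0, 1, 2]        # state 0: alignment step only
assert fsm.next_state(0, 1) == 1                    # any non-stop token -> state 1
assert fsm.allowed_token_ids(1) == [0, 1, 2, 3, 4]  # state 1: full vocabulary
assert fsm.next_state(1, 4) == 2                    # stop token -> final state 2
assert fsm.is_final_state(2)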
@@ -83,17 +102,17 @@ def next_state(self, state: FSMState, token_id: int) -> FSMState:
 
         """
         if token_id == self.stop_token_id:
-            return FSMState(1)
+            return FSMState(2)
 
-        return FSMState(0)
+        return FSMState(1)
 
     def is_final_state(self, state: FSMState) -> bool:
         """Determine whether the current state of the FSM is a final state."""
         return state in self.final_states
 
     def copy(self) -> "StopAtTokenFSM":
         """Create a copy of the FSM."""
-        return self
+        return deepcopy(self)
 
 
 class RegexFSM(FSM):
@@ -122,41 +141,61 @@ def __init__(self, regex_string: str, tokenizer: "Tokenizer"):
             -1
         }  # Include the EOS token in final states
         self.tokenizer = tokenizer
-        self.vocabulary = tokenizer.vocabulary.values()
+        self.vocabulary = tokenizer.vocabulary
         self.end_token_id = tokenizer.eos_token_id
 
     def align_prompt_tokens(self, prompt: str) -> str:
         """Remove the last token from the prompt and update the states_to_token_maps accordingly"""
         token_ids, _ = self.tokenizer.encode(prompt)
         last_token_id = int(token_ids[0][-1])
         last_token_text = self.tokenizer.decode([last_token_id])[0]
-        vocabulary = {
-            self.tokenizer.decode([token_id])[0]: token_id
-            for token_id in range(len(self.vocabulary))
-        }
-        starting_state_tokens = {
-            self.tokenizer.decode([token_id])[0]: self.states_to_token_maps[0][token_id]
-            for token_id in self.states_to_token_maps[0]
-        }
-        # select the tokens that start with the text removed from the prompt and whose text after the
-        # initial prompt corresponds to that of one of the allowed tokens of the starting state
-        possible_tokens = {
-            vocabulary[token_text]: starting_state_tokens[token_text[len(last_token_text):]]
-            for token_text in vocabulary
-            if (
-                token_text.startswith(last_token_text)
-                and starting_state_tokens.get(token_text[len(last_token_text):])
-            )
+        last_token_length = len(last_token_text)
+        # select the tokens that start with the text removed from the prompt
+        crossing_tokens = {
+            token: text
+            for text, token in self.vocabulary.items()
+            if text.startswith(last_token_text)
         }
+        # keep only the tokens whose text after the boundary matches the fsm
+        valid_tokens_states = self.find_valid_crossing_tokens(
+            crossing_tokens, last_token_length
+        )
         # update the states_to_token_maps in the following manner:
         # the value of the starting state is assigned to a new state, the starting state is now the
-        # possible_tokens found above + the last_token we removed (that leads to the new state)
-        additional_state_id = max(list(self.states_to_token_maps.keys()) + list(self.final_states)) + 1
+        # valid_tokens_states found above
+        additional_state_id = (
+            max(list(self.states_to_token_maps.keys()) + list(self.final_states)) + 1
+        )
         self.states_to_token_maps[additional_state_id] = self.states_to_token_maps[0]
-        self.states_to_token_maps[0] = {**possible_tokens, last_token_id: additional_state_id}
-
-        return prompt[:-len(last_token_text)]
-
+        self.states_to_token_maps[0] = {}
+        for token, state in valid_tokens_states:
+            if state == 0:
+                self.states_to_token_maps[0][token] = additional_state_id
+            else:
+                self.states_to_token_maps[0][token] = state
+        return prompt[: -len(last_token_text)]
+
+    def find_valid_crossing_tokens(
+        self, crossing_tokens: Dict[int, str], last_token_length: int
+    ) -> List[Tuple[int, int]]:
+        """For each crossing token, check that the characters after the boundary match the FSM
+        and find the state it would lead to. Return the valid tokens with the associated state
+        """
+        valid_tokens = []
+        for token, text in crossing_tokens.items():
+            is_valid = True
+            crossing_text = text[last_token_length:]
+            state = 0
+            for char in crossing_text:
+                char_token = self.vocabulary.get(char)
+                try:
+                    state = self.states_to_token_maps[state][char_token]  # type: ignore
+                except KeyError:
+                    is_valid = False
+                    break
+            if is_valid:
+                valid_tokens.append((token, state))
        return valid_tokens
 
     def allowed_token_ids(self, state: FSMState) -> List[int]:
         """Generate a list of allowed tokens for the next step.
@@ -222,12 +261,7 @@ def is_final_state(self, state: FSMState) -> bool:
 
     def copy(self) -> "RegexFSM":
         """Create a copy of the FSM."""
-        # temporary solution to the problem of unpickleable dict_values
-        self.vocabulary = cloudpickle.dumps(self.vocabulary)
-        copy = deepcopy(self)
-        self.vocabulary = cloudpickle.loads(self.vocabulary)
-        copy.vocabulary = cloudpickle.loads(copy.vocabulary)
-        return copy
+        return deepcopy(self)
 
 
 class CFGFSM(FSM):
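The cloudpickle workaround above can go because self.vocabulary now stores the tokenizer's vocabulary dict rather than a dict_values view: a plain dict deep-copies without help, whereas a dict_values view cannot be reduced by the copy/pickle machinery. A quick standard-library check of that difference (illustrative, not part of the commit):

from copy import deepcopy

vocab = {"a": 1, "b": 2}
deepcopy(vocab)           # a plain dict deep-copies with no workaround

view = vocab.values()     # dict_values view: what RegexFSM.vocabulary used to hold
try:
    deepcopy(view)
except TypeError as exc:  # on CPython: "cannot pickle 'dict_values' object"
    print(exc)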
@@ -257,6 +291,10 @@ def __init__(self, cfg_string: str, tokenizer: "Tokenizer"):
         self.done = False
         self.regex_fsm: RegexFSM
 
+    def align_prompt_tokens(self, prompt: str) -> str:
+        """Not implemented for CFGFSM"""
+        return prompt
+
     def _set_next_regex_fsm(self) -> None:
         """Use the CFG incremental parser to set the next regex FSM.
 
@@ -278,7 +316,6 @@ def _set_next_regex_fsm(self) -> None:
             self.allow_eos = True
             options.add("")
             assert len(options) > 1
-
         regex_string = r"(" + r"|".join([r"(" + x + r")" for x in options]) + r")"
         self.regex_fsm = RegexFSM(regex_string, self.tokenizer)
         self.reset_state = True

outlines/generate/api.py (+13 −5)
@@ -1,5 +1,6 @@
 import json as pyjson
 import warnings
+from copy import deepcopy
 from typing import Callable, Iterator, List, Optional, Tuple, Union
 
 import torch
@@ -77,12 +78,15 @@ def get_generated_token_ids(
         return token_ids
 
     def get_generated_sequences(
-        self, generated_token_ids: List[torch.Tensor], initial_prompts: List[str], prompts: List[str]
+        self,
+        generated_token_ids: List[torch.Tensor],
+        initial_prompts: List[str],
+        prompts: List[str],
     ) -> List[str]:
         """Give the text sequences generated based on the tokens generated and the initial prompts"""
         generated_tokens_text = self.tokenizer.decode(generated_token_ids)
         return [
-            generated_tokens_text[i][len(initial_prompts[i]) - len(prompts[i]):]
+            generated_tokens_text[i][len(initial_prompts[i]) - len(prompts[i]) :]
             for i in range(len(generated_tokens_text))
         ]
 
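Because align_prompt_tokens shortens the prompt sent to the model, the decoded generation starts by re-emitting the characters that were stripped from the prompt; get_generated_sequences drops len(initial_prompts[i]) - len(prompts[i]) leading characters so callers only see text that extends their original prompt. A tiny worked example with hypothetical strings:

# Hypothetical strings illustrating the slice in get_generated_sequences.
initial_prompt = "Say Hel"       # prompt as the user provided it
aligned_prompt = "Say "          # prompt after align_prompt_tokens stripped "Hel"
generated_text = "Hello world"   # decoded tokens produced after the aligned prompt

offset = len(initial_prompt) - len(aligned_prompt)  # 3 == len("Hel")
print(generated_text[offset:])                      # "lo world"
# The user-visible result, initial_prompt + "lo world", reads "Say Hello world".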
@@ -196,7 +200,7 @@ def __call__(
 
         if isinstance(prompts, str):
             prompts = [prompts]
-        initial_prompts = copy.deepcopy(prompts)
+        initial_prompts = deepcopy(prompts)
 
         if isinstance(stop_at, str):
             stop_at = [stop_at]
@@ -205,7 +209,9 @@
         max_tokens = max_tokens or self.max_tokens
         num_sequences = len(prompts)
         fsms = [self.fsm.copy() for _ in prompts]
-        prompts = [fsm.align_prompt_tokens(prompt) for fsm, prompt in zip(fsms, prompts)]
+        prompts = [
+            fsm.align_prompt_tokens(prompt) for fsm, prompt in zip(fsms, prompts)
+        ]
 
         if rng is None:
             rng = torch.Generator(device=self.device)
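Alignment happens once per prompt, on that prompt's own copy of the FSM, because align_prompt_tokens writes prompt-specific state (valid alignment tokens, rewritten transition maps) into the FSM. A minimal stand-in class — not the outlines FSMs, just an illustration of why the per-prompt copies matter:

from copy import deepcopy


class TinyAlignFSM:
    """Illustrative stand-in for an FSM with align_prompt_tokens; not outlines code."""

    def __init__(self):
        self.valid_alignment_tokens = []

    def align_prompt_tokens(self, prompt: str) -> str:
        # Pretend the last "token" is the trailing word of the prompt.
        head, _, tail = prompt.rpartition(" ")
        self.valid_alignment_tokens = [tail]   # stand-in for real token ids
        return head + " "

    def copy(self):
        return deepcopy(self)


base_fsm = TinyAlignFSM()
prompts = ["Say Hel", "The answer is 4"]
fsms = [base_fsm.copy() for _ in prompts]                 # one FSM per prompt
prompts = [fsm.align_prompt_tokens(p) for fsm, p in zip(fsms, prompts)]

print(prompts)                                   # ['Say ', 'The answer is ']
print([f.valid_alignment_tokens for f in fsms])  # [['Hel'], ['4']]: per-prompt state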
@@ -239,7 +245,9 @@
         generated_token_ids = self.get_generated_token_ids(
             init_state, initial_prompts, last_state
         )
-        generated = self.get_generated_sequences(generated_token_ids, initial_prompts, prompts)
+        generated = self.get_generated_sequences(
+            generated_token_ids, initial_prompts, prompts
+        )
         stripped = [
             self.strip_stop_sequences(sequence, stop_sequences)
             for sequence in generated
