Commit 4f640c4

Implement prompt token alignment in FSMLogitsProcessor and in SequenceGeneratorAdapter
1 parent 017597a commit 4f640c4

File tree

3 files changed: +215 −24 lines changed


outlines/generate/api.py

+124 −8
@@ -1,16 +1,16 @@
 import datetime
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Iterator, List, Optional, Union
+from typing import Iterator, List, Optional, Sequence, Union
+
+import torch
 
 from outlines.generate.generator import sequence_generator
 from outlines.samplers import BeamSearchSampler, GreedySampler, MultinomialSampler
 
-if TYPE_CHECKING:
-    import torch
-
 FormattedOutput = Union[
     str, int, float, bool, datetime.date, datetime.time, datetime.datetime
 ]
+TotalCompletionsType = Optional[Union[List[str], str]]
 
 
 class SequenceGenerator:
@@ -461,6 +461,47 @@ def prepare_generation_parameters(
 
         return generation_params
 
+    def strip_completions(
+        self,
+        completions,
+        prompts: Union[str, List[str]],
+        aligned_prompts: Union[str, List[str]],
+    ):
+        """Remove characters generated through token alignment from the completions.
+
+        As token alignment makes the model re-generate some of the characters at
+        the end of the prompt, we want to remove those from the beginning of the
+        completions to only return the characters after the end of the user prompts.
+
+        Parameters
+        ----------
+        completions
+            Text generated by the model
+        prompts
+            The original prompts provided by the user
+        aligned_prompts
+            The prompts of the user after token alignment (what's given to the model)
+
+        Returns
+        -------
+        The stripped completions
+        """
+        if isinstance(prompts, str):
+            if isinstance(completions, str):
+                return completions[len(prompts) - len(aligned_prompts) :]
+
+            return [
+                self.strip_completions(completion, prompts, aligned_prompts)
+                for completion in completions
+            ]
+
+        return [
+            self.strip_completions(completion, prompt, aligned_prompt)
+            for completion, prompt, aligned_prompt in zip(
+                completions, prompts, aligned_prompts
+            )
+        ]
+
     def format_sequence(self, sequence: str) -> FormattedOutput:
         """Translate the generated sequence to another type.
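
For intuition, here is a minimal standalone sketch of what the single-prompt stripping amounts to (the function name and the values are illustrative, not part of the commit):

def strip_completion(completion: str, prompt: str, aligned_prompt: str) -> str:
    # Token alignment shortens the prompt, so the model re-generates the removed
    # characters; drop that overlap from the front of the completion.
    return completion[len(prompt) - len(aligned_prompt) :]

# The prompt "Phone number: 1" is aligned to "Phone number: ", so the completion
# starts by re-generating "1"; only "-800-555-0199" is returned to the user.
assert strip_completion("1-800-555-0199", "Phone number: 1", "Phone number: ") == "-800-555-0199"
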
@@ -485,6 +526,7 @@ def __call__(
         max_tokens: Optional[int] = None,
         stop_at: Optional[Union[str, List[str]]] = None,
         seed: Optional[int] = None,
+        token_healing_enabled=True,
         **model_specific_params,
     ):
         """Generate text from a prompt of list of prompts."""
@@ -500,32 +542,106 @@ def format(sequences):
             max_tokens, stop_at, seed
         )
 
+        # if token_healing is disabled or unavailable for the type of fsm used by the processor,
+        # the aligned_prompts are just the prompts
+        aligned_prompts = self.logits_processor.setup_processor(
+            prompts, token_healing_enabled
+        )
+
         completions = self.model.generate(
-            prompts,
+            aligned_prompts,
             generation_params,
             self.logits_processor,
             self.sampling_params,
             **model_specific_params,
         )
 
-        return format(completions)
+        stripped_completions = self.strip_completions(
+            completions, prompts, aligned_prompts
+        )
+
+        return format(stripped_completions)
 
     def stream(
         self,
         prompts: Union[str, List[str]],
         max_tokens: Optional[int] = None,
         stop_at: Optional[Union[str, List[str]]] = None,
         seed: Optional[int] = None,
+        token_healing_enabled=True,
         **model_specific_params,
     ):
         """Return a text generator from a prompt or a list of prompts."""
+
+        def add_chunks_to_completions(
+            text_chunks: Union[str, List[str], List[List[str]], Sequence[str]],
+            total_completions: Optional[
+                Union[str, List[str], List[List[str]], Sequence[str]]
+            ],
+        ):
+            """Append each of the text chunks at the end of the corresponding completions"""
+            if isinstance(text_chunks, str):
+                if isinstance(total_completions, str):
+                    return total_completions + text_chunks
+                return text_chunks
+
+            if total_completions:
+                return [
+                    add_chunks_to_completions(text_chunk, total_completion)
+                    for text_chunk, total_completion in zip(
+                        text_chunks, total_completions
+                    )
+                ]
+
+            return [
+                add_chunks_to_completions(text_chunk, None)
+                for text_chunk in text_chunks
+            ]
+
+        def strip_text_chunks(
+            text_chunks: Union[str, List[str], List[List[str]], Sequence[str]],
+            stripped_completions: Union[str, List[str], List[List[str]], Sequence[str]],
+        ):
+            """Get the stripped text_chunks from the stripped_completions."""
+            if isinstance(text_chunks, str):
+                return (
+                    stripped_completions[-len(text_chunks) :]
+                    if len(text_chunks) > 0
+                    else ""
+                )
+
+            return [
+                strip_text_chunks(text_chunk, stripped_completion)
+                for text_chunk, stripped_completion in zip(
+                    text_chunks, stripped_completions
+                )
+            ]
+
         generation_params = self.prepare_generation_parameters(
             max_tokens, stop_at, seed
         )
-        return self.model.stream(
+
+        # if token_healing is disabled or unavailable for the type of fsm used by the processor,
+        # the aligned_prompts are just the prompts
+        aligned_prompts = self.logits_processor.setup_processor(
+            prompts, token_healing_enabled
+        )
+
+        total_completions: TotalCompletionsType = None
+
+        for text_chunks in self.model.stream(
             prompts,
             generation_params,
             self.logits_processor,
             self.sampling_params,
             **model_specific_params,
-        )
+        ):
+            total_completions = add_chunks_to_completions(
+                text_chunks, total_completions
+            )
+
+            stripped_completions = self.strip_completions(
+                total_completions, prompts, aligned_prompts
+            )
+
+            yield strip_text_chunks(text_chunks, stripped_completions)
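
To see what the streaming helpers are doing, here is a minimal sketch of the bookkeeping for a single sequence, with illustrative chunk values (the real helpers operate on nested lists of chunks, one per prompt and sample):

# Token alignment shortened the prompt by one character ("1"), so the first
# streamed chunk re-generates it before producing new text.
prompt, aligned_prompt = "Phone number: 1", "Phone number: "
chunks = ["1-", "800", "-55", "5-0", "199"]

total = None
for chunk in chunks:
    # add_chunks_to_completions: accumulate everything generated so far.
    total = chunk if total is None else total + chunk
    # strip_completions: drop the re-generated prompt characters from the front.
    stripped_total = total[len(prompt) - len(aligned_prompt) :]
    # strip_text_chunks: emit only the tail that this chunk contributed
    # (shorter than the chunk while it still overlaps the user prompt).
    emitted = stripped_total[-len(chunk) :] if len(chunk) > 0 else ""
    print(repr(emitted))
# Prints: '-', '800', '-55', '5-0', '199'
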

outlines/processors/structured.py

+75 −15
@@ -24,12 +24,18 @@
 limitations under the License.
 """
 import math
-from typing import TYPE_CHECKING, Dict, List, Optional, Type, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Type, TypeGuard, Union
 
 import torch
 from pydantic import BaseModel
 
-from outlines.fsm.guide import CFGGuide, Guide, RegexGuide, StopAtEOSGuide
+from outlines.fsm.guide import (
+    CFGGuide,
+    Guide,
+    RegexGuide,
+    StopAtEOSGuide,
+    TokenHealerMixin,
+)
 from outlines.fsm.json_schema import build_regex_from_schema
 from outlines.integrations.utils import convert_json_schema_to_str
 
@@ -61,8 +67,10 @@ def __init__(self, tokenizer: "Tokenizer", fsm: Guide):
             The finite state machine which is used to bias the logits.
         """
         self.tokenizer = tokenizer
-        self._fsm_states: Dict[int, int] = {hash(tuple([])): 0}
+        self._fsm_states: List[Dict[int, int]] = []
         self.fsm: Guide = fsm
+        self._seq_fsms: List[Guide] = []
+        self._is_first_token = True
         self._seq_start_idx: Optional[int] = None
 
     def process_logits(
@@ -82,36 +90,87 @@ def process_logits(
         torch.Tensor
             The biased logits.
         """
-        if self._seq_start_idx is None:
+        samples = int(len(input_ids) / len(self._seq_fsms))
+        sequence_states: List[int] = []  # vector of states corresponding to `input_ids`
+
+        if self._is_first_token:
+            self._is_first_token = False
             self._seq_start_idx = len(input_ids[0])
 
-        sequence_states: List[int] = []  # vector of states corresponding to `input_ids`
+            self._fsm_states = [
+                {hash(tuple([])): 0} for _ in range(len(self._seq_fsms))
+            ]
+            sequence_states = [0] * len(input_ids)
 
-        for seq_ids in input_ids:
-            gen_ids = seq_ids[self._seq_start_idx :]
-            curr_state_key = hash(tuple(gen_ids))
+        else:
+            for i, seq_ids in enumerate(input_ids):
+                try:
+                    prev_state_key = hash(tuple(seq_ids[self._seq_start_idx : -1]))
+                    prev_state = self._fsm_states[i // samples][prev_state_key]
 
-            if curr_state_key not in self._fsm_states:
-                prev_state = self._fsm_states[hash(tuple(gen_ids[:-1]))]
-                curr_state = self.fsm.get_next_state(prev_state, gen_ids[-1])
-                self._fsm_states[curr_state_key] = curr_state
+                    curr_state_key = hash(tuple(seq_ids[self._seq_start_idx :]))
+                    curr_state = self._seq_fsms[i // samples].get_next_state(
+                        prev_state, seq_ids[-1]
+                    )
 
-            sequence_states.append(self._fsm_states[curr_state_key])
+                    self._fsm_states[i // samples][curr_state_key] = curr_state
+                    sequence_states.append(curr_state)
+
+                # This exception happens after the sequence generation is finished with beam search
+                except KeyError:
+                    sequence_states.append(self._seq_fsms[i // samples].final_state)
 
         mask = torch.full_like(logits, -math.inf)
         for i, fsm_state in enumerate(sequence_states):
-            allowed_tokens = self.fsm.get_next_instruction(fsm_state).tokens
+            allowed_tokens = (
+                self._seq_fsms[i // samples].get_next_instruction(fsm_state).tokens
+            )
             mask[i, allowed_tokens] = logits[i, allowed_tokens]
 
         return mask
 
+    def setup_processor(
+        self, prompts: Union[str, List[str]], token_healing_enabled: bool
+    ) -> Union[str, List[str]]:
+        """Prepare the processor to process logits for a specific set of prompts. Create a distinct
+        fsm for each prompt. If selected and available, apply prompt alignment to each fsm.
+
+        Parameters
+        ----------
+        prompts
+            The text prompts provided by the user
+
+        Returns
+        -------
+        The initial prompts after application of prompt alignment if selected and available,
+        the initial prompts unchanged otherwise.
+        """
+        is_input_str = isinstance(prompts, str)
+        if isinstance(prompts, str):
+            prompts = [prompts]
+
+        self._seq_fsms = [self.fsm.copy() for _ in range(len(prompts))]
+
+        if isinstance(self.fsm, TokenHealerMixin) and token_healing_enabled:
+            aligned_prompts = [
+                fsm.align_prompt_tokens(prompt)  # type: ignore
+                for fsm, prompt in zip(self._seq_fsms, prompts)
+            ]
+        else:
+            aligned_prompts = prompts
+
+        if is_input_str:
+            return aligned_prompts[0]
+        return aligned_prompts
+
     def copy(self) -> "FSMLogitsProcessor":
         """Return a copy of the logits processor."""
         return FSMLogitsProcessor(tokenizer=self.tokenizer, fsm=self.fsm.copy())
 
 
 class TextLogitsProcessor(FSMLogitsProcessor):
     """Bias generation for free text (required because of prompt alignment).
+
     Attributes
     ----------
     tokenizer
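
One detail worth spelling out: `input_ids` arrives with one row per prompt and sample, and the processor recovers the guide that owns row `i` with integer division. A small sketch of that mapping, with hypothetical batch sizes:

# Hypothetical batch: 2 prompts, 3 samples each -> 6 rows in input_ids.
num_prompts, num_rows = 2, 6
samples = num_rows // num_prompts  # mirrors int(len(input_ids) / len(self._seq_fsms))

# Rows 0-2 are scored with the FSM copied for prompt 0, rows 3-5 with prompt 1's copy.
assert [i // samples for i in range(num_rows)] == [0, 0, 0, 1, 1, 1]
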
@@ -122,6 +181,7 @@ class TextLogitsProcessor(FSMLogitsProcessor):
 
     def __init__(self, tokenizer: "Tokenizer"):
         """Compile the FSM that drives the regex-guided generation.
+
         Parameters
        ----------
        tokenizer
@@ -213,4 +273,4 @@ def __init__(self, cfg_str: str, tokenizer: "Tokenizer"):
             The tokenizer used to convert tokens to ids.
         """
         cfg_automata = CFGGuide(cfg_string=cfg_str, tokenizer=tokenizer)
-        super().__init__(tokenizer=tokenizer, fsm=cfg_automata)
+        super().__init__(tokenizer=tokenizer, fsm=cfg_automata)

tests/generate/test_generator.py

+16 −1
@@ -21,6 +21,9 @@ def test_sequence_generator_class():
     class MockFSM:
         first_state = 0
 
+        def align_prompt_tokens(self, prompt):
+            return prompt
+
         def get_next_state(self, state, next_token_ids):
             return 4

@@ -39,7 +42,7 @@ def encode(self, _):
             return torch.tensor([[0, 1, 2, 3]]), torch.tensor([[1, 1, 1, 1]])
 
         def decode(self, tokens):
-            return ["testx"[i] for i in tokens]
+            return ["".join(["testx"[int(i)] for i in tokens[0]])]
 
     class MockModel:
        def __init__(self):
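
The updated mock now decodes a whole batch row into one joined string instead of a list of characters; for the ids used in the test the result would be:

import torch

tokens = torch.tensor([[0, 1, 2, 3]])
# "testx"[0], "testx"[1], "testx"[2], "testx"[3] -> "t", "e", "s", "t"
assert ["".join(["testx"[int(i)] for i in tokens[0]])] == ["test"]
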
@@ -77,6 +80,9 @@ def __call__(self, biased_logits, *_):
 
 def test_sequence_generator_1d_single_iteration():
     class MockFSM:
+        def align_prompt_tokens(self, prompt):
+            return prompt
+
         def get_next_state(self, state, next_token_ids):
             return 0

@@ -132,6 +138,9 @@ def sampler(biased_logits, *_):
 
 def test_sequence_generator_1d_several_iterations():
     class MockFSM:
+        def align_prompt_tokens(self, prompt):
+            return prompt
+
         def get_next_state(self, state, next_token_ids):
             return state + 1

@@ -194,6 +203,9 @@ def sampler(biased_logits, *_):
 
 def test_sequence_generator_2d_single_iteration():
     class MockFSM:
+        def align_prompt_tokens(self, prompt):
+            return prompt
+
         def get_next_state(self, state, next_token_ids):
             return 0

@@ -260,6 +272,9 @@ def sampler(biased_logits, *_):
 
 def test_sequence_generator_2d_several_iterations():
     class MockFSM:
+        def align_prompt_tokens(self, prompt):
+            return prompt
+
         def get_next_state(self, state, next_token_ids):
             return state + 1
