
Commit 4cfa9e1

Implement prompt token alignment in FSMLogitsProcessor
Draft
1 parent 91c7b3d commit 4cfa9e1

3 files changed: +169 -19 lines changed

outlines/fsm/guide.py (+2 -1)

@@ -71,6 +71,8 @@ class Guide(Protocol):
 
     """
 
+    final_state: int = -1
+
     def get_next_instruction(self, state: int) -> Instruction:
         ...
 
@@ -90,7 +92,6 @@ def copy(self) -> "Guide":
 class StopAtEOSGuide(Guide):
     """Guide to generate tokens until the EOS token has been generated."""
 
-    final_state = -1
     start_state = 0
 
     def __init__(self, tokenizer: "Tokenizer"):

outlines/generate/api.py (+117 -8)

@@ -1,16 +1,16 @@
 import datetime
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Iterator, List, Optional, Union
+from typing import Iterator, List, Optional, Sequence, Union
+
+import torch
 
 from outlines.generate.generator import sequence_generator
 from outlines.samplers import BeamSearchSampler, GreedySampler, MultinomialSampler
 
-if TYPE_CHECKING:
-    import torch
-
 FormattedOutput = Union[
     str, int, float, bool, datetime.date, datetime.time, datetime.datetime
 ]
+TotalCompletionsType = Optional[Union[List[str], str]]
 
 
 class SequenceGenerator:
@@ -461,6 +461,47 @@ def prepare_generation_parameters(
 
         return generation_params
 
+    def strip_completions(
+        self,
+        completions,
+        prompts: Union[str, List[str]],
+        aligned_prompts: Union[str, List[str]],
+    ):
+        """Remove characters generated through token alignment from the completions.
+
+        As token alignment makes the model re-generate some of the characters at
+        the end of the prompt, we want to remove those from the beginning of the
+        completions to only return the characters after the end of the user prompts.
+
+        Parameters
+        ----------
+        completions
+            Text generated by the model
+        prompts
+            The original prompts provided by the user
+        aligned_prompts
+            The prompts of the user after token alignment (what's given to the model)
+
+        Returns
+        -------
+        The stripped completions
+        """
+        if isinstance(prompts, str):
+            if isinstance(completions, str):
+                return completions[len(prompts) - len(aligned_prompts) :]
+
+            return [
+                self.strip_completions(completion, prompts, aligned_prompts)
+                for completion in completions
+            ]
+
+        return [
+            self.strip_completions(completion, prompt, aligned_prompt)
+            for completion, prompt, aligned_prompt in zip(
+                completions, prompts, aligned_prompts
+            )
+        ]
+
     def format_sequence(self, sequence: str) -> FormattedOutput:
         """Translate the generated sequence to another type.
 
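To make the stripping arithmetic concrete, here is a small self-contained sketch of the slice performed by strip_completions for a single prompt. The values below are illustrative only, not taken from this commit:

# Assumed values (not from this commit) showing the slice performed by
# strip_completions for a single prompt: token alignment shortened the prompt,
# so the completion starts by re-generating the dropped characters.
prompt = "The answer is "
aligned_prompt = "The answer is"              # trailing space removed by alignment
completion = " 42"                            # model re-generates the space first

offset = len(prompt) - len(aligned_prompt)    # 1
stripped = completion[offset:]
print(stripped)                               # "42"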
@@ -500,15 +541,24 @@ def format(sequences):
             max_tokens, stop_at, seed
         )
 
+        aligned_prompts = self.logits_processor.align_prompts(prompts)
+
         completions = self.model.generate(
-            prompts,
+            aligned_prompts,
             generation_params,
             self.logits_processor,
             self.sampling_params,
             **model_specific_params,
         )
 
-        return format(completions)
+        print(completions, prompts, aligned_prompts)
+        stripped_completions = self.strip_completions(
+            completions, prompts, aligned_prompts
+        )
+
+        print(stripped_completions)
+
+        return format(stripped_completions)
 
     def stream(
         self,
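The hunk above generates from the aligned prompts and then applies strip_completions before formatting; for batched prompts the recursion applies one offset per prompt to each of that prompt's samples. A standalone sketch with hypothetical data, where strip is a simplified stand-in for the single-string branch of strip_completions:

# Assumed example of how strip_completions recurses over a batch: one offset
# per prompt, applied to each of that prompt's sampled completions.
prompts = ["A: ", "B: "]
aligned_prompts = ["A:", "B:"]                # each shortened by one character
completions = [[" 1", " 2"], [" 3", " 4"]]    # two samples per prompt

def strip(completion, prompt, aligned_prompt):
    # same slice as the single-string branch of strip_completions
    return completion[len(prompt) - len(aligned_prompt):]

stripped = [
    [strip(c, p, a) for c in cs]
    for cs, p, a in zip(completions, prompts, aligned_prompts)
]
print(stripped)                               # [['1', '2'], ['3', '4']]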
@@ -519,13 +569,72 @@
         **model_specific_params,
     ):
         """Return a text generator from a prompt or a list of prompts."""
+
+        def add_chunks_to_completions(
+            text_chunks: Union[str, List[str], List[List[str]], Sequence[str]],
+            total_completions: Optional[
+                Union[str, List[str], List[List[str]], Sequence[str]]
+            ],
+        ):
+            """Append each of the text chunks to the end of the corresponding completions"""
+            if isinstance(text_chunks, str):
+                if isinstance(total_completions, str):
+                    return total_completions + text_chunks
+                return text_chunks
+
+            if total_completions:
+                return [
+                    add_chunks_to_completions(text_chunk, total_completion)
+                    for text_chunk, total_completion in zip(
+                        text_chunks, total_completions
+                    )
+                ]
+
+            return [
+                add_chunks_to_completions(text_chunk, None)
+                for text_chunk in text_chunks
+            ]
+
+        def strip_text_chunks(
+            text_chunks: Union[str, List[str], List[List[str]], Sequence[str]],
+            stripped_completions: Union[str, List[str], List[List[str]], Sequence[str]],
+        ):
+            """Get the stripped text_chunks from the stripped_completions."""
+            if isinstance(text_chunks, str):
+                return (
+                    stripped_completions[-len(text_chunks) :]
+                    if len(text_chunks) > 0
+                    else ""
+                )
+
+            return [
+                strip_text_chunks(text_chunk, stripped_completion)
+                for text_chunk, stripped_completion in zip(
+                    text_chunks, stripped_completions
+                )
+            ]
+
         generation_params = self.prepare_generation_parameters(
             max_tokens, stop_at, seed
         )
-        return self.model.stream(
+
+        aligned_prompts = self.logits_processor.align_prompts(prompts)
+
+        total_completions: TotalCompletionsType = None
+
+        for text_chunks in self.model.stream(
             prompts,
             generation_params,
             self.logits_processor,
             self.sampling_params,
             **model_specific_params,
-        )
+        ):
+            total_completions = add_chunks_to_completions(
+                text_chunks, total_completions
+            )
+
+            stripped_completions = self.strip_completions(
+                total_completions, prompts, aligned_prompts
+            )
+
+            yield strip_text_chunks(text_chunks, stripped_completions)
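The streaming path accumulates chunks into running completions, strips the alignment offset from the running text, and yields only the tail that corresponds to the current chunk. A minimal plain-Python illustration of that bookkeeping, assuming a single prompt and made-up chunks:

# Plain-Python illustration (assumed data, single prompt) of the streaming
# bookkeeping: accumulate chunks, strip the alignment offset from the running
# completion, then yield only the tail matching the current chunk's length.
prompt = "Answer: "
aligned_prompt = "Answer:"                    # one character shorter after alignment
offset = len(prompt) - len(aligned_prompt)    # characters to drop from the completion

total = ""
for chunk in [" 4", "2", "!"]:                # chunks as a model might stream them
    total += chunk                            # add_chunks_to_completions (string case)
    stripped_total = total[offset:]           # strip_completions (string case)
    # strip_text_chunks: keep only the last len(chunk) characters
    yielded = stripped_total[-len(chunk):] if chunk else ""
    print(repr(yielded))                      # prints '4', then '2', then '!'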
outlines/processors/structured.py (+50 -10)

@@ -61,8 +61,9 @@ def __init__(self, tokenizer: "Tokenizer", fsm: Guide):
             The finite state machine which is used to bias the logits.
         """
         self.tokenizer = tokenizer
-        self._fsm_states: Dict[int, int] = {}
+        self._fsm_states: List[Dict[int, int]] = []
         self.fsm: Guide = fsm
+        self._seq_fsms: List[Guide] = []
         self._is_first_token = True
         self._seq_start_idx: Optional[int] = None
 
@@ -83,33 +84,72 @@ def process_logits(
         torch.Tensor
             The biased logits.
         """
+        samples = int(len(input_ids) / len(self._seq_fsms))
         sequence_states: List[int] = []  # vector of states corresponding to `input_ids`
 
         if self._is_first_token:
             self._is_first_token = False
             self._seq_start_idx = len(input_ids[0])
 
-            self._fsm_states = {hash(tuple([])): 0}
+            self._fsm_states = [
+                {hash(tuple([])): 0} for _ in range(len(self._seq_fsms))
+            ]
             sequence_states = [0] * len(input_ids)
 
         else:
-            for seq_ids in input_ids:
-                prev_state_key = hash(tuple(seq_ids[self._seq_start_idx : -1]))
-                prev_state = self._fsm_states[prev_state_key]
+            for i, seq_ids in enumerate(input_ids):
+                try:
+                    prev_state_key = hash(tuple(seq_ids[self._seq_start_idx : -1]))
+                    prev_state = self._fsm_states[i // samples][prev_state_key]
 
-                curr_state_key = hash(tuple(seq_ids[self._seq_start_idx :]))
-                curr_state = self.fsm.get_next_state(prev_state, seq_ids[-1])
+                    curr_state_key = hash(tuple(seq_ids[self._seq_start_idx :]))
+                    curr_state = self._seq_fsms[i // samples].get_next_state(
+                        prev_state, seq_ids[-1]
+                    )
 
-                self._fsm_states[curr_state_key] = curr_state
-                sequence_states.append(curr_state)
+                    self._fsm_states[i // samples][curr_state_key] = curr_state
+                    sequence_states.append(curr_state)
+
+                # This exception happens after the sequence generation is finished with beam search
+                except KeyError:
+                    sequence_states.append(self._seq_fsms[i // samples].final_state)
 
         mask = torch.full_like(logits, -math.inf)
         for i, fsm_state in enumerate(sequence_states):
-            allowed_tokens = self.fsm.get_next_instruction(fsm_state).tokens
+            allowed_tokens = (
+                self._seq_fsms[i // samples].get_next_instruction(fsm_state).tokens
+            )
             mask[i, allowed_tokens] = logits[i, allowed_tokens]
 
         return mask
 
+    def align_prompts(self, prompts: Union[str, List[str]]) -> Union[str, List[str]]:
+        """Create a distinct fsm for each prompt. Apply prompt alignment to each of them.
+        If applicable, prompt alignment shortens the user prompt and updates the fsm accordingly.
+
+        Parameters
+        ----------
+        prompts
+            The text prompts provided by the user
+
+        Returns
+        -------
+        The initial text prompts after application of prompt alignment
+        """
+        is_input_str = isinstance(prompts, str)
+        if isinstance(prompts, str):
+            prompts = [prompts]
+
+        self._seq_fsms = [self.fsm.copy() for _ in range(len(prompts))]
+        aligned_prompts = [
+            fsm.align_prompt_tokens(prompt, self.tokenizer)
+            for fsm, prompt in zip(self._seq_fsms, prompts)
+        ]
+
+        if is_input_str:
+            return aligned_prompts[0]
+        return aligned_prompts
+
     def copy(self) -> "FSMLogitsProcessor":
         """Return a copy of the logits processor."""
         return FSMLogitsProcessor(tokenizer=self.tokenizer, fsm=self.fsm.copy())
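With this change, process_logits selects one FSM per prompt rather than sharing self.fsm: input_ids holds len(prompts) * samples rows, and row i is guided by self._seq_fsms[i // samples]. A small sketch of that index mapping, assuming two prompts and three samples each (all values hypothetical):

# Sketch of the row-to-FSM mapping assumed by process_logits: input_ids holds
# len(prompts) * samples rows, and row i is guided by the FSM built for prompt
# i // samples.
num_prompts = 2                               # len(self._seq_fsms)
num_rows = 6                                  # len(input_ids)
samples = num_rows // num_prompts             # 3 samples (or beams) per prompt
fsm_index = [i // samples for i in range(num_rows)]
print(fsm_index)                              # [0, 0, 0, 1, 1, 1]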
