
Commit 6bb90f8 (parent: 4aa74f2)

Add unit tests for token alignment

4 files changed: +332, -45 lines

outlines/fsm/fsm.py (+3, -2)

@@ -423,7 +423,7 @@ def find_crossing_tokens(
 ) -> Dict[int, List[int]]:
     """Find the tokens that could replace one or more tokens at the end of token_ids
     while conserving the same intial text (and extending it by at least one character).
-    Return a dictionary with, for the indexes in the token_ids, the associated crossing tokens.
+    Return a dictionary with, for the indexes in the token_ids with matches, the associated crossing tokens.
     """
     reversed_vocabulary = {value: key for key, value in vocabulary.items()}
     len_token_ids = len(token_ids)
@@ -441,7 +441,8 @@ def find_crossing_tokens(
             if text.startswith(characters_considered)
             and len(text) > len(characters_considered)
         ]
-        crossing_tokens_map[len_token_ids - index - 1] = crossing_token_ids
+        if crossing_token_ids:
+            crossing_tokens_map[len_token_ids - index - 1] = crossing_token_ids

     return crossing_tokens_map
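The change above makes find_crossing_tokens skip suffix positions that have no crossing tokens instead of recording empty lists. A small self-contained sketch of that behaviour follows; the helper mirrors the lines visible in the hunk, but its name, signature, and loop structure are assumptions for illustration, not the library's real implementation.

from typing import Dict, List


def find_crossing_tokens_sketch(
    token_ids: List[int], vocabulary: Dict[str, int]
) -> Dict[int, List[int]]:
    """Walk the prompt suffixes from the end and collect vocabulary tokens that
    start with a suffix and extend it by at least one character."""
    reversed_vocabulary = {value: key for key, value in vocabulary.items()}
    len_token_ids = len(token_ids)
    crossing_tokens_map: Dict[int, List[int]] = {}
    characters_considered = ""
    for index, token_id in enumerate(reversed(token_ids)):
        characters_considered = reversed_vocabulary[token_id] + characters_considered
        crossing_token_ids = [
            candidate_id
            for text, candidate_id in vocabulary.items()
            if text.startswith(characters_considered)
            and len(text) > len(characters_considered)
        ]
        # The change in this commit: only record indexes that actually have matches.
        if crossing_token_ids:
            crossing_tokens_map[len_token_ids - index - 1] = crossing_token_ids
    return crossing_tokens_map


vocabulary = {"hel": 0, "lo": 1, "low": 2}
print(find_crossing_tokens_sketch([0, 1], vocabulary))  # prompt tokenized as "hel" + "lo"
# {1: [2]} -- the suffix "lo" can be extended by "low"; the full prompt "hello"
# has no crossing token, so index 0 is now omitted instead of mapping to [].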

outlines/generate/api.py (+54, -41)

@@ -1,8 +1,8 @@
-from typing import Iterator, List, Optional, Union
+from typing import Iterator, List, Optional, Tuple, Union

 import torch

-from outlines.fsm.fsm import FSMState
+from outlines.fsm.fsm import FSM, FSMState
 from outlines.generate.generator import sequence_generator

@@ -21,6 +21,53 @@ def __init__(
         self.device = device
         self.num_samples = sampler.samples

+    def align_prompt_tokens(
+        self,
+        prompt_token_ids: torch.Tensor,
+        attention_masks: torch.Tensor,
+        fsms: List[FSM],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Implement token alignment for each fsm. Return the updated tokens_ids and attention_masks"""
+        aligned_prompts, aligned_masks = zip(
+            *[
+                fsm.align_prompt_tokens(prompt, mask)
+                for prompt, mask, fsm in zip(prompt_token_ids, attention_masks, fsms)
+            ]
+        )
+        # We have to pad some of the prompts if they are not all of the same length after this operation
+        max_length_aligned_prompt = max(prompt.shape[0] for prompt in aligned_prompts)
+        padded_aligned_prompts = [
+            torch.cat(
+                [
+                    torch.full(
+                        (max_length_aligned_prompt - prompt.shape[0],),
+                        0,
+                        device=prompt_token_ids.device,
+                        dtype=prompt.dtype,
+                    ),
+                    prompt,
+                ]
+            )
+            for prompt in aligned_prompts
+        ]
+        padded_aligned_masks = [
+            torch.cat(
+                [
+                    torch.full(
+                        (max_length_aligned_prompt - mask.shape[0],),
+                        0,
+                        device=prompt_token_ids.device,
+                        dtype=mask.dtype,
+                    ),
+                    mask,
+                ]
+            )
+            for mask in aligned_masks
+        ]
+        aligned_prompt_token_ids = torch.stack(padded_aligned_prompts)
+        aligned_attention_masks = torch.stack(padded_aligned_masks)
+        return aligned_prompt_token_ids, aligned_attention_masks
+
     def get_generated_token_ids(
         self,
         prompt_token_ids: torch.Tensor,
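The new helper left-pads the aligned prompts and masks with zeros so they can be stacked back into a single batch tensor; unlike the inline version it replaces below, it also creates the padding on prompt_token_ids.device. A minimal, self-contained sketch of that padding contract with stub FSMs; the StubFSM class and toy values are illustrative assumptions, not the commit's actual test code.

import torch


class StubFSM:
    """Stand-in for the real FSM: "alignment" here just drops trailing tokens."""

    def __init__(self, tokens_to_drop: int):
        self.tokens_to_drop = tokens_to_drop

    def align_prompt_tokens(self, prompt: torch.Tensor, mask: torch.Tensor):
        if self.tokens_to_drop == 0:
            return prompt, mask
        return prompt[: -self.tokens_to_drop], mask[: -self.tokens_to_drop]


prompts = torch.tensor([[1, 2, 3], [4, 5, 6]])
masks = torch.ones_like(prompts)
fsms = [StubFSM(0), StubFSM(1)]  # the second prompt loses its last token

aligned, aligned_masks = zip(
    *[fsm.align_prompt_tokens(p, m) for p, m, fsm in zip(prompts, masks, fsms)]
)

# Left-pad every aligned prompt with zeros up to the longest one, then stack.
max_len = max(p.shape[0] for p in aligned)
padded = torch.stack(
    [
        torch.cat([torch.full((max_len - p.shape[0],), 0, dtype=p.dtype), p])
        for p in aligned
    ]
)
print(padded)
# tensor([[1, 2, 3],
#         [0, 4, 5]])

The attention masks are padded the same way, so the zero-padded positions are masked out during generation.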
@@ -189,49 +236,15 @@ def __call__(
         num_samples = self.num_samples
         batch_size = len(prompts)

-        fsm_states = [FSMState(0) for _ in range(batch_size * num_samples)]
-        fsms = [self.fsm.copy() for _ in range(batch_size * num_samples)]
-
         prompt_token_ids = torch.repeat_interleave(prompt_token_ids, num_samples, dim=0)
         attention_masks = torch.repeat_interleave(attention_masks, num_samples, dim=0)

-        # Token alignment may shorten some of the prompts by removing tokens at their end.
-        # We have to pad some of the prompts if they are not all of the same length after this operation
-        aligned_prompts, aligned_masks = zip(
-            *[
-                fsm.align_prompt_tokens(prompt, mask)
-                for prompt, mask, fsm in zip(prompt_token_ids, attention_masks, fsms)
-            ]
+        fsm_states = [FSMState(0) for _ in range(batch_size * num_samples)]
+        fsms = [self.fsm.copy() for _ in range(batch_size * num_samples)]
+
+        aligned_prompt_token_ids, aligned_attention_masks = self.align_prompt_tokens(
+            prompt_token_ids, attention_masks, fsms
         )
-        max_length_aligned_prompt = max(prompt.shape[0] for prompt in aligned_prompts)
-        padded_aligned_prompts = [
-            torch.cat(
-                [
-                    torch.full(
-                        (max_length_aligned_prompt - prompt.shape[0],),
-                        0,
-                        dtype=prompt.dtype,
-                    ),
-                    prompt,
-                ]
-            )
-            for prompt in aligned_prompts
-        ]
-        padded_aligned_masks = [
-            torch.cat(
-                [
-                    torch.full(
-                        (max_length_aligned_prompt - mask.shape[0],),
-                        0,
-                        dtype=mask.dtype,
-                    ),
-                    mask,
-                ]
-            )
-            for mask in aligned_masks
-        ]
-        aligned_prompt_token_ids = torch.stack(padded_aligned_prompts)
-        aligned_attention_masks = torch.stack(padded_aligned_masks)

         weights = torch.zeros(
             (batch_size * num_samples), dtype=torch.float, device=self.device
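For context, the repeat_interleave calls kept above duplicate every prompt row once per sample, so each sample gets its own FSM copy before alignment runs. A quick illustration with toy values (the tensors are made up for the example):

import torch

prompt_token_ids = torch.tensor([[1, 2, 3], [4, 5, 6]])  # batch of two prompts
num_samples = 2

# Repeat each row num_samples times along the batch dimension.
expanded = torch.repeat_interleave(prompt_token_ids, num_samples, dim=0)
print(expanded)
# tensor([[1, 2, 3],
#         [1, 2, 3],
#         [4, 5, 6],
#         [4, 5, 6]])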
