
Commit ec281ef

Stop generation with Continuation when a specific string was generated

1 parent 34bc2fb commit ec281ef

3 files changed: +66 -18 lines

outlines/text/generate/continuation.py (+32 -5)

@@ -17,11 +17,12 @@ class Continuation(Sequence):
 
     """
 
-    def __init__(self, model, max_tokens: Optional[int]):
+    def __init__(self, model, stop: List[str], max_tokens: Optional[int]):
         super().__init__(model, max_tokens)
         self.eos_token_id = torch.tensor(
             [self.model.tokenizer.eos_token_id], device=self.device
         )
+        self.stop_sequences = stop
 
     def is_finished(self, token_ids: torch.LongTensor) -> torch.BoolTensor:
         """Determine whether the sequences reached maximum length of end with
@@ -38,15 +39,41 @@ def is_finished(self, token_ids: torch.LongTensor) -> torch.BoolTensor:
             The input sequences.
 
         """
-        return token_ids[:, -1] == self.model.tokenizer.eos_token_id
+
+        sequences = self.model.tokenizer.decode(token_ids)
+        is_stop_sequence_found = []
+        for sequence in sequences:
+            found = False
+            for stop_str in self.stop_sequences:
+                if stop_str in sequence:
+                    found = True
+
+            is_stop_sequence_found.append(found)
+
+        is_stop_sequence_found = torch.tensor(is_stop_sequence_found, dtype=torch.bool)
+        is_eos_found = token_ids[:, -1] == self.model.tokenizer.eos_token_id
+
+        return torch.logical_or(is_eos_found, is_stop_sequence_found)
 
     def postprocess_completions(self, completions: List[str]) -> List[str]:
         """Remove the EOS token from the completion."""
-        return [
+        without_eos = [
             completion.replace(self.model.tokenizer.eos_token, "")
             for completion in completions
         ]
 
+        completions = []
+        for completion in without_eos:
+            for stop_str in self.stop_sequences:
+                idx = completion.find(stop_str)
+                if idx > 0:
+                    completions.append(completion[:idx])
+                    break
+
+            completions.append(completion)
+
+        return completions
+
 
-def continuation(model, max_tokens: Optional[int] = None):
-    return Continuation(model, max_tokens)
+def continuation(model, *, stop: List[str] = [], max_tokens: Optional[int] = None):
+    return Continuation(model, stop, max_tokens)
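
Taken together, these changes let callers end generation on arbitrary strings as well as on the EOS token: is_finished decodes each sequence and checks it against every stop string, and postprocess_completions trims the decoded output. A minimal usage sketch, assuming the library's transformers wrapper from the same era of the codebase; the "gpt2" checkpoint, the prompt, and max_tokens value are illustrative, not part of this commit:

    import outlines.models as models
    from outlines.text.generate.continuation import continuation

    # Assumption: any model whose tokenizer exposes `decode`,
    # `eos_token` and `eos_token_id` fits the interface used by
    # Continuation; "gpt2" is only an example checkpoint.
    model = models.transformers("gpt2")

    # Generation halts once "\n" (or the EOS token) appears in the
    # decoded continuation; postprocessing strips the stop string.
    generate = continuation(model, stop=["\n"], max_tokens=64)
    answer = generate("Q: Name one prime number.\nA:")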

outlines/text/generate/sequence.py (+3 -1)

@@ -229,7 +229,9 @@ def __call__(
             )
             token_ids = self.update_token_ids(is_finished, token_ids, updated_token_ids)
             attention_mask = self.expand_attention_mask(attention_mask)
-            is_finished[~is_finished] = self.is_finished(updated_token_ids).flatten()
+            is_finished[~is_finished] = self.is_finished(
+                updated_token_ids[:, num_prompt_tokens:]
+            ).flatten()
 
         result = self.model.tokenizer.decode(token_ids)
         result = self.postprocess_completions(result)
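
The slice matters because is_finished now decodes the token ids and searches for stop strings: if the prompt itself contains a stop string (the "\n" in a few-shot prompt, say), decoding the full sequence would mark every generation finished on the first step. Passing only the tokens generated after the prompt avoids that. A small self-contained illustration of the indexing, with all token values hypothetical:

    import torch

    # Three prompt tokens followed by two generated tokens.
    token_ids = torch.tensor([[101, 11, 102, 103, 12]])
    num_prompt_tokens = 3

    # Only the generated suffix is handed to `is_finished`, so a stop
    # string occurring in the prompt cannot end generation early.
    generated = token_ids[:, num_prompt_tokens:]
    assert generated.tolist() == [[103, 12]]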
tests/text/generate/test_continuation.py (+31 -12)

@@ -1,5 +1,4 @@
-import numpy as np
-from numpy.testing import assert_array_equal
+import torch
 
 from outlines.text.generate.continuation import Continuation, continuation
 
@@ -9,35 +8,55 @@ class Tokenizer:
     eos_token_id = 0
     pad_token_id = -1
 
+    def decode(self, token_ids):
+        return ["Test"] * token_ids.shape[0]
+
 
 class Model:
     tokenizer = Tokenizer()
     device = "cpu"
 
 
-def test_continuation_is_finished():
-    model = continuation(Model(), 10)
+def test_continuation_eos_is_finished():
+    model = continuation(Model())
     assert isinstance(model, Continuation)
 
-    token_ids = np.array([[3, 2]])
+    token_ids = torch.tensor([[3, 2]])
     result = model.is_finished(token_ids)
-    assert_array_equal(result, [False])
+    assert torch.equal(result, torch.tensor([False]))
 
-    token_ids = np.array([[3, 2, 0]])
+    token_ids = torch.tensor([[3, 2, 0]])
     result = model.is_finished(token_ids)
-    assert_array_equal(result, [True])
+    assert torch.equal(result, torch.tensor([True]))
 
-    token_ids = np.array([[3, 2, 1], [3, 2, 0]])
+    token_ids = torch.tensor([[3, 2, 1], [3, 2, 0]])
    result = model.is_finished(token_ids)
-    assert_array_equal(result, [False, True])
+    assert torch.equal(result, torch.tensor([False, True]))
 
-    token_ids = np.array([[3, 2, 1, 0], [3, 2, 0, -1]])
+    token_ids = torch.tensor([[3, 2, 1, 0], [3, 2, 0, -1]])
     result = model.is_finished(token_ids)
-    assert_array_equal(result, [True, False])
+    assert torch.equal(result, torch.tensor([True, False]))
 
 
 def test_continuation_postprocess():
     model = continuation(Model())
     result = model.postprocess_completions(["Here<EOS>"])
     assert len(result) == 1
     assert result[0] == "Here"
+
+
+def test_continuation_stop_is_finished():
+    tokenizer = Tokenizer()
+    tokenizer.decode = lambda x: ["finished \n", "not_finished"]
+    model = Model()
+    model.tokenizer = tokenizer
+
+    model = continuation(model, stop=["\n"])
+
+    token_ids = torch.tensor([[2, 3]])
+    result = model.is_finished(token_ids)
+    assert torch.equal(result, torch.tensor([True, False]))
+
+
+def test_continuation_stop_postprocess():
+    assert False
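
The final test is an intentional placeholder (assert False). A plausible shape for it, once postprocess_completions trims everything from the first stop string onward; the expected strings are assumptions about the intended behavior, not part of this commit:

    def test_continuation_stop_postprocess():
        model = continuation(Model(), stop=["\n"])

        # Assumed behavior: the completion is cut at the first stop
        # string; completions without one pass through unchanged.
        result = model.postprocess_completions(["Stop\nhere", "no stop"])
        assert result == ["Stop", "no stop"]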
