
Commit bfa0e94

Stop generation with Continuation when a specific string was generated
1 parent 0bdcc56

File tree: 5 files changed, +167 -27 lines
outlines/text/generate/continuation.py

+64 -11

@@ -1,4 +1,4 @@
-from typing import List, Optional
+from typing import List, Optional, Union
 
 import torch
 
@@ -17,36 +17,89 @@ class Continuation(Sequence):
 
     """
 
-    def __init__(self, model, max_tokens: Optional[int]):
+    def __init__(
+        self, model, max_tokens: Optional[int] = None, stop: Union[str, List[str]] = []
+    ):
         super().__init__(model, max_tokens)
         self.eos_token_id = torch.tensor(
             [self.model.tokenizer.eos_token_id], device=self.device
         )
 
+        if isinstance(stop, str):
+            stop = [stop]
+
+        self.stop_sequences = stop
+
     def is_finished(self, token_ids: torch.LongTensor) -> torch.BoolTensor:
         """Determine whether the sequences reached maximum length or end with
         an EOS token.
 
-        In practice, `Sequence`'s `__call__` method only passes the `token_ids`
-        of the sequences that haven't been marked as finished already, which is
-        why we only need to look for the EOS token in the last element rather
-        than in the whole sequence.
+        We only need to look for the EOS token in the last element rather than
+        in the whole sequence. Indeed, (1) EOS is a single token and (2)
+        `Sequence`'s `__call__` method only passes the `token_ids` of the
+        sequences that haven't been marked as finished already.
 
         Parameters
         ----------
         token_ids
             The input sequences.
 
         """
-        return token_ids[:, -1] == self.model.tokenizer.eos_token_id
+
+        sequences = self.model.tokenizer.decode(token_ids)
+        contains_stop_sequence = []
+        for sequence in sequences:
+            found = False
+            for stop_str in self.stop_sequences:
+                if stop_str in sequence:
+                    found = True
+
+            contains_stop_sequence.append(found)
+
+        contains_stop_sequence = torch.tensor(contains_stop_sequence, dtype=torch.bool)
+        contains_eos = token_ids[:, -1] == self.model.tokenizer.eos_token_id
+
+        return torch.logical_or(contains_eos, contains_stop_sequence)
 
     def postprocess_completions(self, completions: List[str]) -> List[str]:
-        """Remove the EOS token from the completion."""
-        return [
+        """Remove the EOS token from the completion.
+
+        Sequences in `stop` take precedence over EOS. For instance, if
+        `stop=["\n"]` and the generated sequence is `One\nTwo<EOS>`,
+        `Continuation.postprocess_completions` will return `One`.
+
+        """
+        completions_without_eos = [
             completion.replace(self.model.tokenizer.eos_token, "")
             for completion in completions
         ]
 
+        completions_without_stop = []
+        for completion in completions_without_eos:
+            for stop_str in self.stop_sequences:
+                idx = completion.rfind(stop_str)  # ignore the prompt
+                if idx > 0:
+                    completion = completion[:idx]
+
+            completions_without_stop.append(completion)
+
+        return completions_without_stop
 
 
-def continuation(model, max_tokens: Optional[int] = None):
-    return Continuation(model, max_tokens)
+def continuation(
+    model, max_tokens: Optional[int] = None, *, stop: Union[str, List[str]] = []
+):
+    """Generate text sequences.
+
+    Parameters
+    ----------
+    model
+        The model to use to compute the next-token logits.
+    max_tokens
+        The maximum number of tokens to generate.
+    stop
+        A string or list of strings which, when generated, stops
+        the generation for this sequence.
+
+    """
+    return Continuation(model, max_tokens, stop)
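Taken together, this lets callers stop generation on arbitrary strings. Below is a minimal usage sketch, assuming the `outlines.models.transformers` loader and the tiny test model used in the integration test further down; note that, per that test's `sequence[len(prompt):]` slice, the returned string includes the prompt:

import torch

import outlines.models as models
import outlines.text.generate as generate

model = models.transformers(
    "hf-internal-testing/tiny-random-GPTJForCausalLM", device="cpu"
)

rng = torch.Generator()
rng.manual_seed(0)

# Generation halts for a sequence once "\n" shows up in its decoded text,
# or once EOS is sampled, whichever comes first; everything from the stop
# string onward is cut off by postprocess_completions.
prompt = "Write a short sentence "
sequence = generate.continuation(model, max_tokens=32, stop="\n")(prompt, rng=rng)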

outlines/text/generate/regex.py

+30 -0

@@ -135,6 +135,18 @@ def create_proposal(
 
 
 def regex(model, regex_string: str, max_tokens: Optional[int] = None):
+    """Generate text sequences that match the input regex.
+
+    Parameters
+    ----------
+    model
+        The model to use to compute the next-token logits.
+    regex_string
+        The regular expression generated expressions must match.
+    max_tokens
+        The maximum number of tokens to generate.
+
+    """
     return Regex(model, regex_string, max_tokens)
 
 
@@ -145,6 +157,15 @@ def integer(model, max_tokens: Optional[int] = None):
     signs and forbids leading zeros (even if the `int` function in Python allows
     them).
 
+    Parameters
+    ----------
+    model
+        The model to use to compute the next-token logits.
+    max_tokens
+        The maximum number of tokens to generate.
+
     """
     return Regex(model, r"[-+]?\d+", max_tokens)
 
 
@@ -156,5 +177,14 @@ def float(model, max_tokens: Optional[int] = None):
     signs, and forbids leading zeros (even if the `float` function in Python
     allows them).
 
+    Parameters
+    ----------
+    model
+        The model to use to compute the next-token logits.
+    max_tokens
+        The maximum number of tokens to generate.
+
     """
     return Regex(model, r"([+-]?((0|[1-9]+)([.][0-9]*)?)|([.][0-9]+))", max_tokens)

outlines/text/generate/sequence.py

+3 -1

@@ -229,7 +229,9 @@ def __call__(
             )
             token_ids = self.update_token_ids(is_finished, token_ids, updated_token_ids)
             attention_mask = self.expand_attention_mask(attention_mask)
-            is_finished[~is_finished] = self.is_finished(updated_token_ids).flatten()
+            is_finished[~is_finished] = self.is_finished(
+                updated_token_ids[:, num_prompt_tokens:]
+            ).flatten()
 
         result = self.model.tokenizer.decode(token_ids)
         result = self.postprocess_completions(result)
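The slice is what keeps a stop string that already occurs in the prompt from marking a sequence as finished on the first step: `is_finished` now only ever sees the generated tokens. A toy illustration of the indexing, with hypothetical values rather than library code:

import torch

num_prompt_tokens = 3  # hypothetical prompt length
# One sequence: 3 prompt tokens followed by 2 generated tokens.
updated_token_ids = torch.tensor([[11, 42, 7, 99, 3]])

# Before this commit, is_finished() inspected the full sequence, prompt
# included. After it, only the generated suffix is checked for EOS/stop.
generated_only = updated_token_ids[:, num_prompt_tokens:]
print(generated_only)  # tensor([[99,  3]])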
tests/text/generate/test_continuation.py

+63 -12

@@ -1,5 +1,4 @@
-import numpy as np
-from numpy.testing import assert_array_equal
+import torch
 
 from outlines.text.generate.continuation import Continuation, continuation
 
@@ -9,35 +8,87 @@ class Tokenizer:
     eos_token_id = 0
     pad_token_id = -1
 
+    def decode(self, token_ids):
+        return ["Test"] * token_ids.shape[0]
+
 
 class Model:
     tokenizer = Tokenizer()
     device = "cpu"
 
 
-def test_continuation_is_finished():
-    model = continuation(Model(), 10)
+def test_continuation_eos_is_finished():
+    model = continuation(Model())
     assert isinstance(model, Continuation)
 
-    token_ids = np.array([[3, 2]])
+    token_ids = torch.tensor([[3, 2]])
     result = model.is_finished(token_ids)
-    assert_array_equal(result, [False])
+    assert torch.equal(result, torch.tensor([False]))
 
-    token_ids = np.array([[3, 2, 0]])
+    token_ids = torch.tensor([[3, 2, 0]])
     result = model.is_finished(token_ids)
-    assert_array_equal(result, [True])
+    assert torch.equal(result, torch.tensor([True]))
 
-    token_ids = np.array([[3, 2, 1], [3, 2, 0]])
+    token_ids = torch.tensor([[3, 2, 1], [3, 2, 0]])
     result = model.is_finished(token_ids)
-    assert_array_equal(result, [False, True])
+    assert torch.equal(result, torch.tensor([False, True]))
 
-    token_ids = np.array([[3, 2, 1, 0], [3, 2, 0, -1]])
+    token_ids = torch.tensor([[3, 2, 1, 0], [3, 2, 0, -1]])
     result = model.is_finished(token_ids)
-    assert_array_equal(result, [True, False])
+    assert torch.equal(result, torch.tensor([True, False]))
 
 
 def test_continuation_postprocess():
     model = continuation(Model())
     result = model.postprocess_completions(["Here<EOS>"])
     assert len(result) == 1
     assert result[0] == "Here"
+
+
+def test_continuation_stop_is_finished():
+    tokenizer = Tokenizer()
+    tokenizer.decode = lambda x: ["finished \n", "not_finished"]
+    model = Model()
+    model.tokenizer = tokenizer
+
+    model = continuation(model, stop=["\n"])
+
+    token_ids = torch.tensor([[2, 3]])
+    result = model.is_finished(token_ids)
+    assert torch.equal(result, torch.tensor([True, False]))
+
+
+def test_continuation_stop_postprocess():
+    model = Continuation(Model(), stop="\n")
+    result = model.postprocess_completions(["Stop\n"])
+    assert len(result) == 1
+    assert result[0] == "Stop"
+
+    model = Continuation(Model(), stop=["\n", ","])
+    result = model.postprocess_completions(["Stop"])
+    assert len(result) == 1
+    assert result[0] == "Stop"
+
+    result = model.postprocess_completions(["Stop\n"])
+    assert len(result) == 1
+    assert result[0] == "Stop"
+
+    result = model.postprocess_completions(["Stop\naaa"])
+    assert len(result) == 1
+    assert result[0] == "Stop"
+
+    result = model.postprocess_completions(["Stop,aa\naaa"])
+    assert len(result) == 1
+    assert result[0] == "Stop"
+
+    result = model.postprocess_completions(["Stop\naa,a"])
+    assert len(result) == 1
+    assert result[0] == "Stop"
+
+    result = model.postprocess_completions(["Stop\n", "Nonstop"])
+    assert len(result) == 2
+    assert result == ["Stop", "Nonstop"]
+
+    result = model.postprocess_completions(["StopHere\nNoHere<EOS>"])
+    assert len(result) == 1
+    assert result[0] == "StopHere"

tests/text/generate/test_integration_transfomers.py

+7 -3

@@ -13,21 +13,25 @@ def test_transformers_integration_continuation():
 
     model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
     model = models.transformers(model_name, device="cpu")
-    sequence = generate.continuation(model)("Write a short sentence", rng=rng)
+    sequence = generate.continuation(model)("Write a short sentence ", rng=rng)
     assert isinstance(sequence, str)
     assert model.tokenizer.eos_token not in sequence
 
     sequence = generate.continuation(model, max_tokens=10)(
-        "Write a short sentence", rng=rng
+        "Write a short sentence ", rng=rng
     )
     assert isinstance(sequence, str)
 
-    prompts = ["Write a short sentence", "And another one"]
+    prompts = ["Write a short sentence ", "And another one "]
     sequence = generate.continuation(model, max_tokens=10)(prompts, rng=rng)
     assert isinstance(sequence, list)
     assert len(sequence) == 2
     assert isinstance(sequence[0], str)
 
+    prompt = "Write a short sentence "
+    sequence = generate.continuation(model, stop="a")(prompt, rng=rng)
+    assert sequence[len(prompt) :].find("a") == -1
+
 
 @pytest.mark.xfail
 def test_transformers_integration_continuation_array_samples():
