Commit 7d2dcce

Support per-request seed (vllm-project#2514)
1 parent dc903e7 commit 7d2dcce

File tree

10 files changed: +289 -84 lines

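In practice, the new field is used through SamplingParams: two requests that share a seed reproduce the same tokens, while unseeded requests remain non-deterministic. A minimal sketch of the user-facing behaviour this commit targets (the model name, prompt, and parameter values below are illustrative, not part of the diff):

```python
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")

# `seed` is the per-request field added by this commit; the other
# sampling settings are ordinary SamplingParams arguments.
seeded = SamplingParams(temperature=1.0, top_p=0.9, max_tokens=8, seed=1234)
unseeded = SamplingParams(temperature=1.0, top_p=0.9, max_tokens=8)

prompt = "The capital of France is"
out_a = llm.generate([prompt], seeded)
out_b = llm.generate([prompt], seeded)
out_c = llm.generate([prompt], unseeded)

# Same seed -> identical continuation; no seed -> typically different.
assert out_a[0].outputs[0].token_ids == out_b[0].outputs[0].token_ids
```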
tests/samplers/test_sampler.py

Lines changed: 147 additions & 75 deletions
@@ -1,10 +1,11 @@
 import random
-from typing import Tuple
+from typing import Tuple, List
 from unittest.mock import patch

 import pytest
 import torch
 from transformers import GenerationConfig, GenerationMixin
+from typing import Optional

 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.utils import set_random_seed
@@ -46,15 +47,13 @@ def _prepare_test(
 ]


-@pytest.mark.parametrize("seed", RANDOM_SEEDS)
-@pytest.mark.parametrize("device", CUDA_DEVICES)
-def test_sampler_all_greedy(seed: int, device: str):
-    set_random_seed(seed)
-    torch.set_default_device(device)
-    batch_size = random.randint(1, 256)
-    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
-        batch_size)
-
+def _do_sample(
+    batch_size: int,
+    input_tensor: torch.Tensor,
+    sampler: MockLogitsSampler,
+    model_runner: ModelRunner,
+    sampling_params: SamplingParams,
+):
     seq_group_metadata_list = []
     prompt_lens = []
     for i in range(batch_size):
@@ -63,17 +62,31 @@ def test_sampler_all_greedy(seed: int, device: str):
                 request_id=f"test_{i}",
                 is_prompt=True,
                 seq_data={0: SequenceData([1, 2, 3])},
-                sampling_params=SamplingParams(temperature=0, ),
+                sampling_params=sampling_params,
                 block_tables={0: [1]},
             ))
         prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

     sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
                                                      prompt_lens,
                                                      subquery_lens=prompt_lens)
-    sampler_output = sampler(embedding=None,
-                             hidden_states=input_tensor,
-                             sampling_metadata=sampling_metadata)
+    return sampler(embedding=None,
+                   hidden_states=input_tensor,
+                   sampling_metadata=sampling_metadata)
+
+
+@pytest.mark.parametrize("seed", RANDOM_SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+def test_sampler_all_greedy(seed: int, device: str):
+    set_random_seed(seed)
+    torch.set_default_device(device)
+    batch_size = random.randint(1, 256)
+    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
+        batch_size)
+
+    sampling_params = SamplingParams(temperature=0)
+    sampler_output = _do_sample(batch_size, input_tensor, sampler,
+                                model_runner, sampling_params)
     expected = torch.argmax(fake_logits, dim=-1)
     for i, sequence_output in enumerate(sampler_output):
         for nth_output in sequence_output.samples:
@@ -94,35 +107,72 @@ def test_sampler_all_random(seed: int, device: str):
     for i in range(batch_size):
         fake_logits[i, i] = 1e2

-    seq_group_metadata_list = []
-    prompt_lens = []
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        n=random.randint(1, 10),
+    )
+    sampler_output = _do_sample(batch_size, input_tensor, sampler,
+                                model_runner, sampling_params)
+
+    for i, sequence_output in enumerate(sampler_output):
+        for nth_output in sequence_output.samples:
+            assert nth_output.output_token == i
+
+    del model_runner
+
+
+@pytest.mark.parametrize("seed", RANDOM_SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+def test_sampler_all_random_seed(seed: int, device: str):
+    set_random_seed(seed)
+    torch.set_default_device(device)
+    batch_size = random.randint(1, 256)
+    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
+        batch_size)
+
     for i in range(batch_size):
-        seq_group_metadata_list.append(
-            SequenceGroupMetadata(
-                request_id=f"test_{i}",
-                is_prompt=True,
-                seq_data={0: SequenceData([1, 2, 3])},
-                sampling_params=SamplingParams(
-                    temperature=1.0,
-                    n=random.randint(1, 10),
-                ),
-                block_tables={0: [1]},
-            ))
-        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
+        fake_logits[i, i] = 1e2
+
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        n=random.randint(1, 10),
+        seed=random.randint(0, 10000),
+    )
+    sampler_output = _do_sample(batch_size, input_tensor, sampler,
+                                model_runner, sampling_params)

-    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
-                                                     prompt_lens,
-                                                     subquery_lens=prompt_lens)
-    sampler_output = sampler(embedding=None,
-                             hidden_states=input_tensor,
-                             sampling_metadata=sampling_metadata)
     for i, sequence_output in enumerate(sampler_output):
         for nth_output in sequence_output.samples:
             assert nth_output.output_token == i

     del model_runner


+@pytest.mark.parametrize("seed", RANDOM_SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+def test_sampler_all_random_seed_deterministic(seed: int, device: str):
+    set_random_seed(seed)
+    torch.set_default_device(device)
+    batch_size = random.randint(1, 256)
+    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
+        batch_size)
+
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        n=random.randint(1, 10),
+        seed=random.randint(0, 10000),
+    )
+    first_sampler_output = _do_sample(batch_size, input_tensor, sampler,
+                                      model_runner, sampling_params)
+
+    second_sampler_output = _do_sample(batch_size, input_tensor, sampler,
+                                       model_runner, sampling_params)
+
+    assert first_sampler_output == second_sampler_output
+
+    del model_runner
+
+
 @pytest.mark.parametrize("seed", RANDOM_SEEDS)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 def test_sampler_all_beam(seed: int, device: str):
@@ -131,29 +181,13 @@ def test_sampler_all_beam(seed: int, device: str):
     batch_size = random.randint(1, 256)
     input_tensor, _, sampler, model_runner = _prepare_test(batch_size)

-    seq_group_metadata_list = []
-    prompt_lens = []
-    for i in range(batch_size):
-        seq_group_metadata_list.append(
-            SequenceGroupMetadata(
-                request_id=f"test_{i}",
-                is_prompt=True,
-                seq_data={0: SequenceData([1, 2, 3])},
-                sampling_params=SamplingParams(
-                    temperature=0,
-                    best_of=2,
-                    use_beam_search=True,
-                ),
-                block_tables={0: [1]},
-            ))
-        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
-
-    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
-                                                     prompt_lens,
-                                                     subquery_lens=prompt_lens)
-    sampler(embedding=None,
-            hidden_states=input_tensor,
-            sampling_metadata=sampling_metadata)
+    sampling_params = SamplingParams(
+        temperature=0,
+        best_of=2,
+        use_beam_search=True,
+    )
+    _do_sample(batch_size, input_tensor, sampler, model_runner,
+               sampling_params)
     # no assertion here as I am not sure how to determine whether
     # the outputs are expected - in other words, this just tests
     # whether there are no exceptions in the sampler
@@ -171,14 +205,15 @@ def test_sampler_mixed(seed: int, device: str):
         batch_size)

     seq_group_metadata_list = []
-    expected_tokens = []
+    expected_tokens: List[Optional[List[int]]] = []
     prompt_lens = []
     for i in range(batch_size):
-        n = 1
-        sampling_type = random.randint(0, 2)
+        expected: Optional[List[int]] = None
+        sampling_type = random.randint(0, 3)
         if sampling_type == 0:
             sampling_params = SamplingParams(temperature=0)
-        elif sampling_type == 1:
+            expected = [torch.argmax(fake_logits[i], dim=-1).item()]
+        elif sampling_type in (1, 2):
             n = random.randint(1, 10)
             sampling_params = SamplingParams(
                 temperature=random.random() + 0.1,
@@ -187,13 +222,17 @@ def test_sampler_mixed(seed: int, device: str):
                 n=n,
                 presence_penalty=random.randint(0, 1),
             )
+            if sampling_type == 2:
+                sampling_params.seed = random.randint(0, 10000)
+            else:
+                for idx in range(n):
+                    fake_logits[i, i + idx] = 1e2
+                expected = list(range(i, i + n))
         else:
             sampling_params = SamplingParams(temperature=0,
                                              use_beam_search=True,
                                              best_of=2)
-        for idx in range(n):
-            fake_logits[i, i + idx] = 1e2
-            expected_tokens.append(i + idx)
+        expected_tokens.append(expected)
         seq_group_metadata_list.append(
             SequenceGroupMetadata(
                 request_id=f"test_{i}",
@@ -204,17 +243,50 @@ def test_sampler_mixed(seed: int, device: str):
             ))
         prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

-    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
-                                                     prompt_lens,
-                                                     subquery_lens=prompt_lens)
-    sampler_output = sampler(embedding=None,
-                             hidden_states=input_tensor,
-                             sampling_metadata=sampling_metadata)
-    for i, sequence_output in enumerate(sampler_output):
-        if seq_group_metadata_list[i].sampling_params.use_beam_search:
-            continue
-        for nth_output in sequence_output.samples:
-            assert nth_output.output_token in expected_tokens
+    def test_sampling(model_runner: ModelRunner):
+        sampling_metadata = model_runner._prepare_sample(
+            seq_group_metadata_list, prompt_lens, subquery_lens=prompt_lens)
+        sampler_output = sampler(embedding=None,
+                                 hidden_states=input_tensor,
+                                 sampling_metadata=sampling_metadata)
+
+        for i, (sequence_output, metadata) in enumerate(
+                zip(sampler_output, seq_group_metadata_list)):
+            if metadata.sampling_params.use_beam_search:
+                continue
+
+            if metadata.sampling_params.seed is not None \
+                    and expected_tokens[i] is None:
+                # Record seeded random result to compare with results of second invocation
+                expected_tokens[i] = [
+                    nth_output.output_token
+                    for nth_output in sequence_output.samples
+                ]
+                continue
+
+            for n, nth_output in enumerate(sequence_output.samples):
+                if metadata.sampling_params.temperature == 0 or metadata.sampling_params.seed is not None:
+                    # Ensure exact matches for greedy or random with seed
+                    assert nth_output.output_token == expected_tokens[i][n]
+                else:
+                    # For non-seeded random, check that one of the high-logit tokens was chosen
+                    assert nth_output.output_token in expected_tokens[i]
+
+    # Test batch
+    test_sampling(model_runner)
+
+    # Shuffle the batch and resample
+    target_index = list(range(batch_size))
+    for list_to_shuffle in (target_index, seq_group_metadata_list,
+                            expected_tokens, prompt_lens):
+        random.Random(seed).shuffle(list_to_shuffle)
+    target_index = torch.tensor(target_index)
+    input_tensor.data = input_tensor.index_select(0, target_index)
+    fake_logits.data = fake_logits.index_select(0, target_index)
+
+    # This time, results of seeded random samples will be compared with the corresponding
+    # sample in the pre-shuffled batch
+    test_sampling(model_runner)

     del model_runner

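The shuffled-batch check in test_sampler_mixed above relies on seeded sampling being tied to the request rather than to its position in the batch. A standalone sketch of that property in plain PyTorch, for illustration only (this is not vLLM's sampler code):

```python
import torch

def sample_with_generator(logits: torch.Tensor, generator: torch.Generator) -> int:
    # Draw one token from the softmax distribution using the request's own generator.
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1, generator=generator).item()

logits_a = torch.randn(32)
logits_b = torch.randn(32)

# Batch order 1: request A sampled first, then request B.
gen_a = torch.Generator().manual_seed(1234)
gen_b = torch.Generator().manual_seed(5678)
first_order = [sample_with_generator(logits_a, gen_a),
               sample_with_generator(logits_b, gen_b)]

# Batch order 2: the same requests processed in the opposite order.
gen_a = torch.Generator().manual_seed(1234)
gen_b = torch.Generator().manual_seed(5678)
second_order = [sample_with_generator(logits_b, gen_b),
                sample_with_generator(logits_a, gen_a)]

# Each request reproduces its own token regardless of batch position.
assert first_order[0] == second_order[1]
assert first_order[1] == second_order[0]
```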
tests/samplers/test_seeded_generate.py

Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@
+"""Verify that seeded random sampling is deterministic.
+
+Run `pytest tests/samplers/test_seeded_generate.py --forked`.
+"""
+import copy
+import random
+from itertools import combinations
+
+import pytest
+
+from vllm.model_executor.utils import set_random_seed
+from vllm import SamplingParams
+
+MODEL = "facebook/opt-125m"
+RANDOM_SEEDS = list(range(5))
+
+
+@pytest.fixture
+def vllm_model(vllm_runner):
+    vllm_model = vllm_runner(MODEL, dtype="half")
+    yield vllm_model
+    del vllm_model
+
+
+@pytest.mark.parametrize("seed", RANDOM_SEEDS)
+def test_random_sample_with_seed(
+    vllm_model,
+    example_prompts,
+    seed: int,
+) -> None:
+    set_random_seed(seed)
+
+    sampling_params = SamplingParams(
+        # Parameters to ensure sufficient randomness
+        temperature=2.0,
+        top_p=min(random.random() + 0.3, 1),
+        top_k=random.randint(5, 20),
+        n=random.randint(1, 10),
+        presence_penalty=random.randint(0, 1),
+        max_tokens=8,
+        ignore_eos=True,
+    )
+
+    sampling_params_seed_1 = copy.deepcopy(sampling_params)
+    sampling_params_seed_1.seed = 100
+    sampling_params_seed_2 = copy.deepcopy(sampling_params)
+    sampling_params_seed_2.seed = 200
+
+    llm = vllm_model.model
+
+    for prompt in example_prompts:
+        for params in (
+                sampling_params,
+                sampling_params_seed_1,
+                sampling_params_seed_2,
+                sampling_params,
+                sampling_params_seed_1,
+                sampling_params_seed_2,
+        ):
+            llm._add_request(
+                prompt=prompt,
+                prompt_token_ids=None,
+                sampling_params=params,
+            )
+
+    results = llm._run_engine(use_tqdm=False)
+    all_outputs = [[out.token_ids for out in output.outputs]
+                   for output in results]
+
+    for i in range(0, len(example_prompts), 6):
+        outputs = all_outputs[i:i + 6]
+
+        # verify all non-seeded requests differ
+        for output_a, output_b in combinations(
+            (outputs[0], outputs[1], outputs[2], outputs[3]),
+                2,
+        ):
+            assert output_a != output_b
+
+        # verify requests with the same seed match
+        assert outputs[1] == outputs[4]
+        assert outputs[2] == outputs[5]

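The same determinism guarantee can also be expressed through the public generate() API instead of the engine-internal _add_request/_run_engine helpers used above. A hedged sketch, reusing the vllm_model and example_prompts fixtures assumed by this test file; the seed values and sampling settings are illustrative, and it further assumes the two generate() calls produce reproducible forward passes for identical prompt batches:

```python
import pytest
from vllm import SamplingParams

@pytest.mark.parametrize("seed", [13, 47])
def test_seed_reproducibility_via_generate(vllm_model, example_prompts, seed):
    # Two independent generate() calls with the same per-request seed should
    # yield identical token sequences for every prompt.
    params = SamplingParams(temperature=1.0, top_p=0.9, max_tokens=8,
                            seed=seed, ignore_eos=True)
    llm = vllm_model.model
    first = llm.generate(example_prompts, params)
    second = llm.generate(example_prompts, params)
    for a, b in zip(first, second):
        assert [o.token_ids for o in a.outputs] == \
               [o.token_ids for o in b.outputs]
```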
vllm/core/scheduler.py

Lines changed: 1 addition & 0 deletions
@@ -387,6 +387,7 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
                 block_tables=block_tables,
                 lora_request=seq_group.lora_request,
                 prefix=seq_group.prefix,
+                state=seq_group.state,
             )
             seq_group_metadata_list.append(seq_group_metadata)
         return seq_group_metadata_list, scheduler_outputs

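This one-line scheduler change forwards seq_group.state into each SequenceGroupMetadata, so per-request sampling state persists across scheduling steps instead of being rebuilt every iteration. A rough sketch of the kind of state container this implies; the class and helper names below are illustrative, not the exact vLLM definitions:

```python
from dataclasses import dataclass
from typing import Optional
import torch

@dataclass
class SequenceGroupState:
    # Per-request generator; created once and then advanced on every decode step.
    generator: Optional[torch.Generator] = None

def get_generator(state: SequenceGroupState, seed: Optional[int],
                  device: str = "cuda") -> Optional[torch.Generator]:
    """Lazily create the per-request generator on first use (hypothetical helper)."""
    if seed is None:
        return None
    if state.generator is None:
        state.generator = torch.Generator(device=device).manual_seed(seed)
    return state.generator
```

If the scheduler did not pass the state object through, a seeded request would be re-seeded at every step and could repeat the same draw instead of continuing its random stream.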