@@ -1,10 +1,11 @@
 import random
-from typing import Tuple
+from typing import Tuple, List
 from unittest.mock import patch

 import pytest
 import torch
 from transformers import GenerationConfig, GenerationMixin
+from typing import Optional

 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.utils import set_random_seed
@@ -46,15 +47,13 @@ def _prepare_test(
 ]


-@pytest.mark.parametrize("seed", RANDOM_SEEDS)
-@pytest.mark.parametrize("device", CUDA_DEVICES)
-def test_sampler_all_greedy(seed: int, device: str):
-    set_random_seed(seed)
-    torch.set_default_device(device)
-    batch_size = random.randint(1, 256)
-    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
-        batch_size)
-
+def _do_sample(
+    batch_size: int,
+    input_tensor: torch.Tensor,
+    sampler: MockLogitsSampler,
+    model_runner: ModelRunner,
+    sampling_params: SamplingParams,
+):
     seq_group_metadata_list = []
     prompt_lens = []
     for i in range(batch_size):
@@ -63,17 +62,31 @@ def test_sampler_all_greedy(seed: int, device: str):
                 request_id=f"test_{i}",
                 is_prompt=True,
                 seq_data={0: SequenceData([1, 2, 3])},
-                sampling_params=SamplingParams(temperature=0, ),
+                sampling_params=sampling_params,
                 block_tables={0: [1]},
             ))
         prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

     sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
                                                      prompt_lens,
                                                      subquery_lens=prompt_lens)
-    sampler_output = sampler(embedding=None,
-                             hidden_states=input_tensor,
-                             sampling_metadata=sampling_metadata)
+    return sampler(embedding=None,
+                   hidden_states=input_tensor,
+                   sampling_metadata=sampling_metadata)
+
+
+@pytest.mark.parametrize("seed", RANDOM_SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+def test_sampler_all_greedy(seed: int, device: str):
+    set_random_seed(seed)
+    torch.set_default_device(device)
+    batch_size = random.randint(1, 256)
+    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
+        batch_size)
+
+    sampling_params = SamplingParams(temperature=0)
+    sampler_output = _do_sample(batch_size, input_tensor, sampler,
+                                model_runner, sampling_params)
     expected = torch.argmax(fake_logits, dim=-1)
     for i, sequence_output in enumerate(sampler_output):
         for nth_output in sequence_output.samples:
@@ -94,35 +107,72 @@ def test_sampler_all_random(seed: int, device: str):
     for i in range(batch_size):
         fake_logits[i, i] = 1e2

-    seq_group_metadata_list = []
-    prompt_lens = []
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        n=random.randint(1, 10),
+    )
+    sampler_output = _do_sample(batch_size, input_tensor, sampler,
+                                model_runner, sampling_params)
+
+    for i, sequence_output in enumerate(sampler_output):
+        for nth_output in sequence_output.samples:
+            assert nth_output.output_token == i
+
+    del model_runner
+
+
+@pytest.mark.parametrize("seed", RANDOM_SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+def test_sampler_all_random_seed(seed: int, device: str):
+    set_random_seed(seed)
+    torch.set_default_device(device)
+    batch_size = random.randint(1, 256)
+    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
+        batch_size)
+
     for i in range(batch_size):
-        seq_group_metadata_list.append(
-            SequenceGroupMetadata(
-                request_id=f"test_{i}",
-                is_prompt=True,
-                seq_data={0: SequenceData([1, 2, 3])},
-                sampling_params=SamplingParams(
-                    temperature=1.0,
-                    n=random.randint(1, 10),
-                ),
-                block_tables={0: [1]},
-            ))
-        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
+        fake_logits[i, i] = 1e2
+
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        n=random.randint(1, 10),
+        seed=random.randint(0, 10000),
+    )
+    sampler_output = _do_sample(batch_size, input_tensor, sampler,
+                                model_runner, sampling_params)

-    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
-                                                     prompt_lens,
-                                                     subquery_lens=prompt_lens)
-    sampler_output = sampler(embedding=None,
-                             hidden_states=input_tensor,
-                             sampling_metadata=sampling_metadata)
     for i, sequence_output in enumerate(sampler_output):
         for nth_output in sequence_output.samples:
             assert nth_output.output_token == i

     del model_runner


+@pytest.mark.parametrize("seed", RANDOM_SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+def test_sampler_all_random_seed_deterministic(seed: int, device: str):
+    set_random_seed(seed)
+    torch.set_default_device(device)
+    batch_size = random.randint(1, 256)
+    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
+        batch_size)
+
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        n=random.randint(1, 10),
+        seed=random.randint(0, 10000),
+    )
+    first_sampler_output = _do_sample(batch_size, input_tensor, sampler,
+                                      model_runner, sampling_params)
+
+    second_sampler_output = _do_sample(batch_size, input_tensor, sampler,
+                                       model_runner, sampling_params)
+
+    assert first_sampler_output == second_sampler_output
+
+    del model_runner
+
+
 @pytest.mark.parametrize("seed", RANDOM_SEEDS)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 def test_sampler_all_beam(seed: int, device: str):
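The two seeded tests added in the hunk above pin down the seeding contract: `SamplingParams.seed` must make sampling reproducible, so two identical invocations yield identical outputs. The property can be sketched in plain PyTorch with a per-call `torch.Generator`; the helper `sample_with_seed` below is illustrative only, not vLLM's implementation:

```python
import torch

def sample_with_seed(logits: torch.Tensor, n: int, seed: int) -> torch.Tensor:
    # A generator freshly seeded with the same value produces the same
    # draws on every call, which is what the deterministic test asserts.
    generator = torch.Generator(device=logits.device).manual_seed(seed)
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, n, replacement=True, generator=generator)

logits = torch.randn(1, 32000)
first = sample_with_seed(logits, n=4, seed=1234)
second = sample_with_seed(logits, n=4, seed=1234)
assert torch.equal(first, second)  # mirrors test_sampler_all_random_seed_deterministic
```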
@@ -131,29 +181,13 @@ def test_sampler_all_beam(seed: int, device: str):
     batch_size = random.randint(1, 256)
     input_tensor, _, sampler, model_runner = _prepare_test(batch_size)

-    seq_group_metadata_list = []
-    prompt_lens = []
-    for i in range(batch_size):
-        seq_group_metadata_list.append(
-            SequenceGroupMetadata(
-                request_id=f"test_{i}",
-                is_prompt=True,
-                seq_data={0: SequenceData([1, 2, 3])},
-                sampling_params=SamplingParams(
-                    temperature=0,
-                    best_of=2,
-                    use_beam_search=True,
-                ),
-                block_tables={0: [1]},
-            ))
-        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
-
-    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
-                                                     prompt_lens,
-                                                     subquery_lens=prompt_lens)
-    sampler(embedding=None,
-            hidden_states=input_tensor,
-            sampling_metadata=sampling_metadata)
+    sampling_params = SamplingParams(
+        temperature=0,
+        best_of=2,
+        use_beam_search=True,
+    )
+    _do_sample(batch_size, input_tensor, sampler, model_runner,
+               sampling_params)
     # no assertion here as I am not sure how to determine whether
     # the outputs are expected - in other words, this just tests
     # whether there are no exceptions in the sampler
@@ -171,14 +205,15 @@ def test_sampler_mixed(seed: int, device: str):
         batch_size)

     seq_group_metadata_list = []
-    expected_tokens = []
+    expected_tokens: List[Optional[List[int]]] = []
     prompt_lens = []
     for i in range(batch_size):
-        n = 1
-        sampling_type = random.randint(0, 2)
+        expected: Optional[List[int]] = None
+        sampling_type = random.randint(0, 3)
         if sampling_type == 0:
             sampling_params = SamplingParams(temperature=0)
-        elif sampling_type == 1:
+            expected = [torch.argmax(fake_logits[i], dim=-1).item()]
+        elif sampling_type in (1, 2):
             n = random.randint(1, 10)
             sampling_params = SamplingParams(
                 temperature=random.random() + 0.1,
@@ -187,13 +222,17 @@ def test_sampler_mixed(seed: int, device: str):
                 n=n,
                 presence_penalty=random.randint(0, 1),
             )
+            if sampling_type == 2:
+                sampling_params.seed = random.randint(0, 10000)
+            else:
+                for idx in range(n):
+                    fake_logits[i, i + idx] = 1e2
+                expected = list(range(i, i + n))
         else:
             sampling_params = SamplingParams(temperature=0,
                                              use_beam_search=True,
                                              best_of=2)
-        for idx in range(n):
-            fake_logits[i, i + idx] = 1e2
-            expected_tokens.append(i + idx)
+        expected_tokens.append(expected)
         seq_group_metadata_list.append(
             SequenceGroupMetadata(
                 request_id=f"test_{i}",
@@ -204,17 +243,50 @@ def test_sampler_mixed(seed: int, device: str):
             ))
         prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

-    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
-                                                     prompt_lens,
-                                                     subquery_lens=prompt_lens)
-    sampler_output = sampler(embedding=None,
-                             hidden_states=input_tensor,
-                             sampling_metadata=sampling_metadata)
-    for i, sequence_output in enumerate(sampler_output):
-        if seq_group_metadata_list[i].sampling_params.use_beam_search:
-            continue
-        for nth_output in sequence_output.samples:
-            assert nth_output.output_token in expected_tokens
+    def test_sampling(model_runner: ModelRunner):
+        sampling_metadata = model_runner._prepare_sample(
+            seq_group_metadata_list, prompt_lens, subquery_lens=prompt_lens)
+        sampler_output = sampler(embedding=None,
+                                 hidden_states=input_tensor,
+                                 sampling_metadata=sampling_metadata)
+
+        for i, (sequence_output, metadata) in enumerate(
+                zip(sampler_output, seq_group_metadata_list)):
+            if metadata.sampling_params.use_beam_search:
+                continue
+
+            if metadata.sampling_params.seed is not None \
+                    and expected_tokens[i] is None:
+                # Record the seeded random result to compare with the second invocation
+                expected_tokens[i] = [
+                    nth_output.output_token
+                    for nth_output in sequence_output.samples
+                ]
+                continue
+
+            for n, nth_output in enumerate(sequence_output.samples):
+                if metadata.sampling_params.temperature == 0 or metadata.sampling_params.seed is not None:
+                    # Ensure exact matches for greedy or seeded random sampling
+                    assert nth_output.output_token == expected_tokens[i][n]
+                else:
+                    # For unseeded random sampling, check that one of the high-logit tokens was chosen
+                    assert nth_output.output_token in expected_tokens[i]
+
+    # Test batch
+    test_sampling(model_runner)
+
+    # Shuffle the batch and resample
+    target_index = list(range(batch_size))
+    for list_to_shuffle in (target_index, seq_group_metadata_list,
+                            expected_tokens, prompt_lens):
+        random.Random(seed).shuffle(list_to_shuffle)
+    target_index = torch.tensor(target_index)
+    input_tensor.data = input_tensor.index_select(0, target_index)
+    fake_logits.data = fake_logits.index_select(0, target_index)
+
+    # This time, the results of the seeded random samples are compared with
+    # the corresponding samples in the pre-shuffled batch
+    test_sampling(model_runner)

     del model_runner

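The shuffle-and-resample step above encodes a stronger guarantee: a seeded request's samples depend only on its own seed and logits, not on its position in the batch. A minimal sketch of why per-request generators give that order-independence, in plain PyTorch (`sample_batch` is a hypothetical helper, not vLLM code):

```python
import torch

def sample_batch(logits: torch.Tensor, seeds: list) -> list:
    # One generator per request: each row's draw depends only on that row's
    # logits and seed, so permuting the rows permutes the outputs exactly.
    tokens = []
    for row, seed in zip(logits, seeds):
        generator = torch.Generator().manual_seed(seed)
        probs = torch.softmax(row, dim=-1)
        tokens.append(torch.multinomial(probs, 1, generator=generator).item())
    return tokens

logits = torch.randn(4, 100)
seeds = [7, 11, 13, 17]
baseline = sample_batch(logits, seeds)

perm = [2, 0, 3, 1]  # shuffle the batch, as test_sampler_mixed does
shuffled = sample_batch(logits[perm], [seeds[p] for p in perm])
assert shuffled == [baseline[p] for p in perm]
```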