import asyncio
+ import os
+ import uuid
from asyncio import CancelledError
+ from copy import copy
from dataclasses import dataclass
- from typing import Optional
+ from typing import List, Optional

import pytest
import pytest_asyncio
from vllm.config import ParallelConfig
from vllm.engine.async_llm_engine import AsyncEngineArgs, AsyncLLMEngine
from vllm.outputs import RequestOutput as RealRequestOutput
+ from vllm.sampling_params import RequestOutputKind

from ..conftest import cleanup
from ..utils import wait_for_gpu_memory_to_clear
@@ -122,8 +126,17 @@ def start_engine():
        timeout_s=60,
    )

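+     # Read the scheduler step count from the environment so the same tests
+     # can run in single-step (default) and multistep scheduling modes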
+     num_scheduler_steps = int(os.getenv("NUM_SCHEDULER_STEPS", "1"))
+     print(f"Starting engine with num_scheduler_steps={num_scheduler_steps}")
+
    return AsyncLLMEngine.from_engine_args(
-         AsyncEngineArgs(model="facebook/opt-125m", enforce_eager=True))
+         AsyncEngineArgs(model="facebook/opt-125m",
+                         enforce_eager=True,
+                         num_scheduler_steps=num_scheduler_steps))
+
+
+ def uid() -> str:
+     return str(uuid.uuid4())


@pytest_asyncio.fixture(scope="module")
@@ -148,57 +161,177 @@ def should_do_global_cleanup_after_test(request) -> bool:
@pytest.mark.asyncio(scope="module")
async def test_asyncio_run(async_engine):

+     scheduler_config = await async_engine.get_scheduler_config()
+     num_scheduler_steps = scheduler_config.num_scheduler_steps
+
    async def run(prompt: str):
        sampling_params = SamplingParams(
            temperature=0,
            max_tokens=32,
+             min_tokens=32,
        )

+         output_count = 0
+         final_output = None
        async for output in async_engine.generate(prompt,
                                                  sampling_params,
-                                                   request_id=prompt):
+                                                   request_id=uid()):
+             output_count += 1
            final_output = output
-         return final_output
+         return final_output, output_count

    results = await asyncio.gather(
        run("test0"),
-         run("test1 "),
+         run("test0 "),
    )
    assert len(results) == 2
+     first, second = results
+
+     # remove nondeterministic fields for comparison
+     first[0].metrics = None
+     second[0].metrics = None
+     first[0].request_id = None
+     second[0].request_id = None
+
+     assert str(first) == str(second)
+
+     output_count = results[0][1]
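+     # One RequestOutput message per generated token in single-step mode;
+     # multistep scheduling batches several tokens into each message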
+     if num_scheduler_steps == 1:
+         assert output_count == 32
+     else:
+         assert 1 < output_count < 32
+
+
+ @pytest.mark.asyncio(scope="module")
+ async def test_output_kinds(async_engine):
+     """Test that output_kind works as expected and that
+     results are equivalent across different kinds."""
+
+     scheduler_config = await async_engine.get_scheduler_config()
+     num_scheduler_steps = scheduler_config.num_scheduler_steps
+
+     sampling_params = SamplingParams(
+         temperature=0,
+         max_tokens=32,
+         min_tokens=32,
+     )
+
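+     # CUMULATIVE: each message carries the full output so far;
+     # FINAL_ONLY: a single message at the end; DELTA: only the newly generated tokens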
+     async def run(prompt: str, kind: RequestOutputKind):
+         params = copy(sampling_params)
+         params.output_kind = kind
+
+         output_count = 0
+         final_output = None
+         async for output in async_engine.generate(prompt,
+                                                   params,
+                                                   request_id=uid()):
+             output_count += 1
+             final_output = output
+
+         assert final_output is not None
+         return (final_output.prompt_token_ids,
+                 final_output.outputs[0].token_ids,
+                 final_output.outputs[0].text, output_count)
+
+     async def run_deltas(prompt: str):
+         params = copy(sampling_params)
+         params.output_kind = RequestOutputKind.DELTA
+
+         prompt_tokens = None
+         output_tokens: List[int] = []
+         output_text = ""
+         output_count = 0
+         async for output in async_engine.generate(prompt,
+                                                   params,
+                                                   request_id=uid()):
+             token_ids = output.outputs[0].token_ids
+             text = output.outputs[0].text
+
+             # Ensure we get prompt ids iff we haven't yet received output tokens
+             if output_tokens:
+                 assert 1 <= len(token_ids) <= num_scheduler_steps
+                 assert text
+                 assert not output.prompt_token_ids
+             else:
+                 assert output.prompt_token_ids
+                 prompt_tokens = output.prompt_token_ids
+
+             output_tokens.extend(token_ids)
+             output_text += text
+
+             output_count += 1
+         return prompt_tokens, output_tokens, output_text, output_count
+
+     results = await asyncio.gather(
+         run("common input prompt", RequestOutputKind.CUMULATIVE),
+         run("common input prompt", RequestOutputKind.FINAL_ONLY),
+         run_deltas("common input prompt"))
+
+     # Make sure outputs are the same
+     prompt_set = set(tuple(prompt_ids) for prompt_ids, _, _, _ in results)
+     assert len(prompt_set) == 1
+
+     text_set = set(text for _, _, text, _ in results)
+     assert len(text_set) == 1
+
+     tokens_set = set(tuple(ids) for _, ids, _, _ in results)
+     assert len(tokens_set) == 1
+
+     cumulative, final, deltas = results
+
+     # output message counts
+     assert cumulative[3] == deltas[3]
+
+     if num_scheduler_steps == 1:
+         assert cumulative[3] == 32
+     else:
+         assert 1 < cumulative[3] < 32
+
+     assert final[3] == 1


@pytest.mark.asyncio(scope="module")
async def test_cancellation(async_engine):
+     scheduler_config = await async_engine.get_scheduler_config()
+     num_scheduler_steps = scheduler_config.num_scheduler_steps
+
    sampling_params = SamplingParams(
        temperature=0,
-         min_tokens=10,
-         max_tokens=10,
+         min_tokens=13,
+         max_tokens=13,
    )

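+     # With multistep scheduling fewer output messages arrive (one per
+     # multi-token step), so abort after the first message to make sure the
+     # request is cancelled before it finishes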
+     stop_at = 5 if num_scheduler_steps == 1 else 1
+
+     request_id = uid()
+
    i = 0
    with pytest.raises(CancelledError):
        async for output in async_engine.generate("test2",
                                                  sampling_params,
-                                                   request_id="test2"):
+                                                   request_id=request_id):
            assert not output.finished
            i += 1
-             if i == 5:
-                 await async_engine.abort("test2")
+             if i == stop_at:
+                 await async_engine.abort(request_id)

-     assert i == 5
+     assert i == stop_at


@pytest.mark.asyncio(scope="module")
async def test_delayed_generator(async_engine):
+     scheduler_config = await async_engine.get_scheduler_config()
+
+     if scheduler_config.num_scheduler_steps != 1:
+         pytest.skip("no need to test this one with multistep")
+
    sampling_params = SamplingParams(
        temperature=0,
        min_tokens=10,
        max_tokens=10,
    )

-     stream = async_engine.generate("test3",
-                                    sampling_params,
-                                    request_id="test3")
+     stream = async_engine.generate("test3", sampling_params, request_id=uid())
    i = 0
    final_output: Optional[RealRequestOutput] = None
    async for output in stream: