
Commit c6c8b91

njhill authored and Alvant committed

[Core] Add engine option to return only deltas or final output (vllm-project#7381)

1 parent bd0c782 · commit c6c8b91

File tree

10 files changed: +371 -137 lines changed
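
This commit introduces a RequestOutputKind option on SamplingParams with CUMULATIVE, DELTA, and FINAL_ONLY values, which the per-file diffs below wire through the engine, the offline LLM entrypoint, and the OpenAI server. As a quick orientation, here is a minimal illustrative sketch (not part of the commit) of requesting final-only output from the async engine; it uses only APIs that appear in the tests below, and the model name and prompt are placeholders:

import asyncio

from vllm.engine.async_llm_engine import AsyncEngineArgs, AsyncLLMEngine
from vllm.sampling_params import RequestOutputKind, SamplingParams


async def main() -> None:
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m", enforce_eager=True))
    params = SamplingParams(
        temperature=0,
        max_tokens=32,
        # New in this change: emit a single RequestOutput, only when finished.
        output_kind=RequestOutputKind.FINAL_ONLY,
    )
    async for output in engine.generate("Hello, my name is",
                                        params,
                                        request_id="example-0"):
        # With FINAL_ONLY, exactly one finished output is yielded.
        print(output.outputs[0].text)


if __name__ == "__main__":
    asyncio.run(main())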

.buildkite/test-pipeline.yaml

Lines changed: 1 addition & 0 deletions
@@ -50,6 +50,7 @@ steps:
   - tests/worker
   commands:
   - pytest -v -s async_engine # Async Engine
+  - NUM_SCHEDULER_STEPS=4 pytest -v -s async_engine/test_async_llm_engine.py
   - pytest -v -s test_inputs.py
   - pytest -v -s multimodal
   - pytest -v -s test_utils.py # Utils

tests/async_engine/test_async_llm_engine.py

Lines changed: 147 additions & 14 deletions
@@ -1,7 +1,10 @@
 import asyncio
+import os
+import uuid
 from asyncio import CancelledError
+from copy import copy
 from dataclasses import dataclass
-from typing import Optional
+from typing import List, Optional

 import pytest
 import pytest_asyncio
@@ -11,6 +14,7 @@
 from vllm.config import ParallelConfig
 from vllm.engine.async_llm_engine import AsyncEngineArgs, AsyncLLMEngine
 from vllm.outputs import RequestOutput as RealRequestOutput
+from vllm.sampling_params import RequestOutputKind

 from ..conftest import cleanup
 from ..utils import wait_for_gpu_memory_to_clear
@@ -122,8 +126,17 @@ def start_engine():
         timeout_s=60,
     )

+    num_scheduler_steps = int(os.getenv("NUM_SCHEDULER_STEPS", "1"))
+    print(f"Starting engine with num_scheduler_steps={num_scheduler_steps}")
+
     return AsyncLLMEngine.from_engine_args(
-        AsyncEngineArgs(model="facebook/opt-125m", enforce_eager=True))
+        AsyncEngineArgs(model="facebook/opt-125m",
+                        enforce_eager=True,
+                        num_scheduler_steps=num_scheduler_steps))
+
+
+def uid() -> str:
+    return str(uuid.uuid4())


 @pytest_asyncio.fixture(scope="module")
@@ -148,57 +161,177 @@ def should_do_global_cleanup_after_test(request) -> bool:
 @pytest.mark.asyncio(scope="module")
 async def test_asyncio_run(async_engine):

+    scheduler_config = await async_engine.get_scheduler_config()
+    num_scheduler_steps = scheduler_config.num_scheduler_steps
+
     async def run(prompt: str):
         sampling_params = SamplingParams(
             temperature=0,
             max_tokens=32,
+            min_tokens=32,
         )

+        output_count = 0
+        final_output = None
         async for output in async_engine.generate(prompt,
                                                   sampling_params,
-                                                  request_id=prompt):
+                                                  request_id=uid()):
+            output_count += 1
             final_output = output
-        return final_output
+        return final_output, output_count

     results = await asyncio.gather(
         run("test0"),
-        run("test1"),
+        run("test0"),
     )
     assert len(results) == 2
+    first, second = results
+
+    # remove nondeterministic fields for comparison
+    first[0].metrics = None
+    second[0].metrics = None
+    first[0].request_id = None
+    second[0].request_id = None
+
+    assert str(first) == str(second)
+
+    output_count = results[0][1]
+    if num_scheduler_steps == 1:
+        assert output_count == 32
+    else:
+        assert 1 < output_count < 32
+
+
+@pytest.mark.asyncio(scope="module")
+async def test_output_kinds(async_engine):
+    """Test that output_kind works as expected and that
+    results are equivalent across different kinds."""
+
+    scheduler_config = await async_engine.get_scheduler_config()
+    num_scheduler_steps = scheduler_config.num_scheduler_steps
+
+    sampling_params = SamplingParams(
+        temperature=0,
+        max_tokens=32,
+        min_tokens=32,
+    )
+
+    async def run(prompt: str, kind: RequestOutputKind):
+        params = copy(sampling_params)
+        params.output_kind = kind
+
+        output_count = 0
+        final_output = None
+        async for output in async_engine.generate(prompt,
+                                                  params,
+                                                  request_id=uid()):
+            output_count += 1
+            final_output = output
+
+        assert final_output is not None
+        return (final_output.prompt_token_ids,
+                final_output.outputs[0].token_ids,
+                final_output.outputs[0].text, output_count)
+
+    async def run_deltas(prompt: str):
+        params = copy(sampling_params)
+        params.output_kind = RequestOutputKind.DELTA
+
+        prompt_tokens = None
+        output_tokens: List[int] = []
+        output_text = ""
+        output_count = 0
+        async for output in async_engine.generate(prompt,
+                                                  params,
+                                                  request_id=uid()):
+            token_ids = output.outputs[0].token_ids
+            text = output.outputs[0].text
+
+            # Ensure we get prompt ids iff we haven't yet received output tokens
+            if output_tokens:
+                assert 1 <= len(token_ids) <= num_scheduler_steps
+                assert text
+                assert not output.prompt_token_ids
+            else:
+                assert output.prompt_token_ids
+                prompt_tokens = output.prompt_token_ids
+
+            output_tokens.extend(token_ids)
+            output_text += text
+
+            output_count += 1
+        return prompt_tokens, output_tokens, output_text, output_count
+
+    results = await asyncio.gather(
+        run("common input prompt", RequestOutputKind.CUMULATIVE),
+        run("common input prompt", RequestOutputKind.FINAL_ONLY),
+        run_deltas("common input prompt"))
+
+    # Make sure outputs are the same
+    prompt_set = set(tuple(prompt_ids) for prompt_ids, _, _, _ in results)
+    assert len(prompt_set) == 1
+
+    text_set = set(text for _, _, text, _ in results)
+    assert len(text_set) == 1
+
+    tokens_set = set(tuple(ids) for _, ids, _, _ in results)
+    assert len(tokens_set) == 1
+
+    cumulative, final, deltas = results
+
+    # output message counts
+    assert cumulative[3] == deltas[3]
+
+    if num_scheduler_steps == 1:
+        assert cumulative[3] == 32
+    else:
+        assert 1 < cumulative[3] < 32
+
+    assert final[3] == 1


 @pytest.mark.asyncio(scope="module")
 async def test_cancellation(async_engine):
+    scheduler_config = await async_engine.get_scheduler_config()
+    num_scheduler_steps = scheduler_config.num_scheduler_steps
+
     sampling_params = SamplingParams(
         temperature=0,
-        min_tokens=10,
-        max_tokens=10,
+        min_tokens=13,
+        max_tokens=13,
     )

+    stop_at = 5 if num_scheduler_steps == 1 else 1
+
+    request_id = uid()
+
     i = 0
     with pytest.raises(CancelledError):
         async for output in async_engine.generate("test2",
                                                   sampling_params,
-                                                  request_id="test2"):
+                                                  request_id=request_id):
             assert not output.finished
             i += 1
-            if i == 5:
-                await async_engine.abort("test2")
+            if i == stop_at:
+                await async_engine.abort(request_id)

-    assert i == 5
+    assert i == stop_at


 @pytest.mark.asyncio(scope="module")
 async def test_delayed_generator(async_engine):
+    scheduler_config = await async_engine.get_scheduler_config()
+
+    if scheduler_config.num_scheduler_steps != 1:
+        pytest.skip("no need to test this one with multistep")
+
     sampling_params = SamplingParams(
         temperature=0,
         min_tokens=10,
         max_tokens=10,
     )

-    stream = async_engine.generate("test3",
-                                   sampling_params,
-                                   request_id="test3")
+    stream = async_engine.generate("test3", sampling_params, request_id=uid())
     i = 0
     final_output: Optional[RealRequestOutput] = None
     async for output in stream:
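
The new test_output_kinds test above checks that CUMULATIVE, DELTA, and FINAL_ONLY requests all produce the same prompt tokens, output tokens, and text. As a small companion sketch of client-side delta handling, mirroring run_deltas above (the helper name collect_deltas is illustrative, not from the commit): with DELTA, each message carries only the tokens and text generated since the previous message, so the caller concatenates them.

from vllm.sampling_params import RequestOutputKind, SamplingParams


async def collect_deltas(async_engine, prompt: str, request_id: str) -> str:
    # Request per-step deltas rather than cumulative outputs.
    params = SamplingParams(temperature=0,
                            max_tokens=32,
                            output_kind=RequestOutputKind.DELTA)
    text = ""
    async for output in async_engine.generate(prompt, params,
                                              request_id=request_id):
        # outputs[0].text holds only the text produced since the last message.
        text += output.outputs[0].text
    return text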

vllm/engine/llm_engine.py

Lines changed: 13 additions & 11 deletions
@@ -39,7 +39,7 @@
                           RequestOutputFactory)
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import RequestOutputKind, SamplingParams
 from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
                            Sequence, SequenceGroup, SequenceGroupMetadata,
                            SequenceStatus)
@@ -225,9 +225,6 @@ def __init__(
         usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
         stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
         input_registry: InputRegistry = INPUT_REGISTRY,
-        # To improve performance, only final requests outputs may be required.
-        # If this set to true, then no intermediate outputs will be returned.
-        step_return_finished_only: bool = False,
     ) -> None:
         logger.info(
             "Initializing an LLM engine (v%s) with config: "
@@ -295,7 +292,6 @@ def __init__(
         self.observability_config = observability_config or ObservabilityConfig(
         )
         self.log_stats = log_stats
-        self.step_return_finished_only = step_return_finished_only

         if not self.model_config.skip_tokenizer_init:
             self.tokenizer = self._init_tokenizer()
@@ -1317,7 +1313,7 @@ def _process_model_outputs(self,

            ctx: The virtual engine context to work on
            request_id: If provided, then only this request is going to be processed
-
+
        """
        now = time.time()

@@ -1422,7 +1418,8 @@ def _process_model_outputs(self,
             seq_group = scheduled_seq_group.seq_group
             seq_group.maybe_set_first_token_time(now)
             request_output = RequestOutputFactory.create(seq_group)
-            ctx.request_outputs.append(request_output)
+            if request_output:
+                ctx.request_outputs.append(request_output)

         # When we process a single request, we skip it for the next time,
         # and invoke the request output callback (if there was final output)
@@ -1459,14 +1456,19 @@ def _process_model_outputs(self,

             seq_group = scheduled_seq_group.seq_group
             seq_group.maybe_set_first_token_time(now)
-            if (seq_group.is_finished()
-                    if self.step_return_finished_only else True):
-                request_output = RequestOutputFactory.create(seq_group)
+            request_output = RequestOutputFactory.create(seq_group)
+            if request_output:
                 ctx.request_outputs.append(request_output)

         for seq_group in scheduler_outputs.ignored_seq_groups:
+            params = seq_group.sampling_params
+            if params is not None and params.output_kind == (
+                    RequestOutputKind.DELTA) and not seq_group.is_finished():
+                continue
+
             request_output = RequestOutputFactory.create(seq_group)
-            ctx.request_outputs.append(request_output)
+            if request_output:
+                ctx.request_outputs.append(request_output)

         # Immediately process request outputs here (if callback is given)
         if (ctx.request_outputs
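
The guards added above imply that RequestOutputFactory.create can now return None, for example when a request asked for FINAL_ONLY or DELTA output and there is nothing to emit for this step, so the engine only appends truthy results. A hypothetical helper sketching that gating decision (should_emit is not a function in this commit; the real decision happens inside the output-creation path):

from vllm.sampling_params import RequestOutputKind


def should_emit(kind: RequestOutputKind, finished: bool,
                has_new_tokens: bool) -> bool:
    """Illustrative only: when a RequestOutput would be produced for a step."""
    if kind == RequestOutputKind.FINAL_ONLY:
        return finished  # one output, only at the end
    if kind == RequestOutputKind.DELTA:
        return finished or has_new_tokens  # skip empty delta messages
    return True  # CUMULATIVE: emit every step, as before this change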

vllm/entrypoints/llm.py

Lines changed: 8 additions & 15 deletions
@@ -19,7 +19,7 @@
 from vllm.outputs import EmbeddingRequestOutput, RequestOutput
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import RequestOutputKind, SamplingParams
 from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
                                                get_cached_tokenizer)
 from vllm.transformers_utils.tokenizer_group import TokenizerGroup
@@ -642,14 +642,12 @@ def _validate_and_add_requests(
             raise ValueError("The lengths of prompts and lora_request "
                              "must be the same.")

-        if isinstance(params, list):
-            params = [
-                self._add_guided_processor(param, guided_options)
-                if isinstance(param, SamplingParams) else param
-                for param in params
-            ]
-        elif isinstance(params, SamplingParams):
-            params = self._add_guided_processor(params, guided_options)
+        for sp in params if isinstance(params, list) else (params, ):
+            if isinstance(sp, SamplingParams):
+                self._add_guided_processor(sp, guided_options)
+
+                # We only care about the final output
+                sp.output_kind = RequestOutputKind.FINAL_ONLY

         # Add requests to the engine.
         for i, request_inputs in enumerate(inputs):
@@ -709,9 +707,6 @@ def _run_engine(
                 f"output: {0:.2f} toks/s"),
             )

-        # In the loop below, only finished outputs are used
-        self.llm_engine.step_return_finished_only = True
-
         # Run the engine.
         outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = []
         total_in_toks = 0
@@ -724,6 +719,7 @@
                 if use_tqdm:
                     if isinstance(output, RequestOutput):
                         # Calculate tokens only for RequestOutput
+                        assert output.prompt_token_ids is not None
                         total_in_toks += len(output.prompt_token_ids)
                         in_spd = total_in_toks / pbar.format_dict["elapsed"]
                         total_out_toks += sum(
@@ -735,9 +731,6 @@
                             f"output: {out_spd:.2f} toks/s")
                     pbar.update(1)

-        # Restore original behavior
-        self.llm_engine.step_return_finished_only = False
-
         if use_tqdm:
             pbar.close()
         # Sort the outputs by request ID.
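
With this change the offline LLM entrypoint no longer toggles a step_return_finished_only flag on the engine; instead _validate_and_add_requests forces output_kind = RequestOutputKind.FINAL_ONLY on each SamplingParams, since LLM.generate only consumes finished outputs. A minimal usage sketch (model name and prompt are placeholders); callers need no code changes:

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m", enforce_eager=True)
params = SamplingParams(temperature=0, max_tokens=32)
# output_kind is forced to FINAL_ONLY internally, so each element of
# `outputs` is a single, finished RequestOutput per prompt.
outputs = llm.generate(["Hello, my name is"], params)
for out in outputs:
    print(out.outputs[0].text)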

vllm/entrypoints/openai/protocol.py

Lines changed: 6 additions & 1 deletion
@@ -12,7 +12,8 @@
 from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
 from vllm.entrypoints.openai.logits_processors import get_logits_processors
 from vllm.pooling_params import PoolingParams
-from vllm.sampling_params import LogitsProcessor, SamplingParams
+from vllm.sampling_params import (LogitsProcessor, RequestOutputKind,
+                                  SamplingParams)
 from vllm.sequence import Logprob
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils import random_uuid
@@ -316,6 +317,8 @@ def to_sampling_params(
             length_penalty=self.length_penalty,
             logits_processors=logits_processors,
             truncate_prompt_tokens=self.truncate_prompt_tokens,
+            output_kind=RequestOutputKind.DELTA if self.stream \
+            else RequestOutputKind.FINAL_ONLY,
         )

     @model_validator(mode="before")
@@ -559,6 +562,8 @@ def to_sampling_params(
             length_penalty=self.length_penalty,
             logits_processors=logits_processors,
             truncate_prompt_tokens=self.truncate_prompt_tokens,
+            output_kind=RequestOutputKind.DELTA if self.stream \
+            else RequestOutputKind.FINAL_ONLY,
         )

     @model_validator(mode="before")
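
The OpenAI-compatible server now picks the output kind from the request's stream flag: streaming requests get per-message deltas, while non-streaming requests get a single final output. A small sketch of the same mapping in isolation (params_for is illustrative, not part of the commit):

from vllm.sampling_params import RequestOutputKind, SamplingParams


def params_for(stream: bool) -> SamplingParams:
    # stream=True  -> DELTA (incremental messages, e.g. for SSE streaming);
    # stream=False -> FINAL_ONLY (one message when the request finishes).
    return SamplingParams(
        max_tokens=16,
        output_kind=RequestOutputKind.DELTA
        if stream else RequestOutputKind.FINAL_ONLY,
    )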
