Commit 1b925bb

Author: Varun Sundar Rabindranath
Commit message: Bugfix : Multi-step + PP
1 parent: c2ec430
5 files changed: +130, -15 lines

tests/multi_step/test_correctness_async_llm.py

Lines changed: 82 additions & 0 deletions
@@ -142,3 +142,85 @@ async def test_multi_step(
         name_0="hf",
         name_1="vllm",
     )
+
+
+@pytest.mark.parametrize(("tp_size, pp_size"), [
+    (1, 2),
+])
+@pytest.mark.asyncio
+async def test_multi_step_pp_smoke(
+    tp_size: int,
+    pp_size: int,
+    monkeypatch,
+) -> None:
+    """
+    Smoke test for the vLLM engine with multi-step scheduling in an
+    OpenAI-protocol client/server environment.
+
+    This test compares the outputs of multi-step scheduling and
+    single-step scheduling. Notably, it lets the engines generate more
+    tokens than the default (5) and tests for an exact match over all
+    the tokens.
+
+    Args:
+      tp_size: degree of tensor-parallelism
+      pp_size: degree of pipeline-parallelism
+    """
+
+    model = "JackFram/llama-160m"
+    num_scheduler_steps = 8
+    attention_backend = "FLASH_ATTN"
+    max_num_seqs = 3
+
+    override_backend_env_variable(monkeypatch, attention_backend)
+
+    # Prompt from the ShareGPT dataset
+    prompts = [
+        "Do you know the book Traction by Gino Wickman",
+        "Do you know the book Traction by Gino Wickman",
+        "Do you know the book Traction by Gino Wickman",
+        "Do you know the book Traction by Gino Wickman",
+    ]
+    # Use varying max_tokens to introduce scheduling randomness.
+    max_tokens = [10 * i for i in range(1, len(prompts) + 1)]
+    assert len(prompts) == len(max_tokens)
+
+    test_args = [
+        "--tensor-parallel-size",
+        str(tp_size), "--pipeline-parallel-size",
+        str(pp_size), "--max-num-seqs",
+        str(max_num_seqs)
+    ]
+
+    server_args = DEFAULT_SERVER_ARGS + test_args
+    ms_server_args = DEFAULT_SERVER_ARGS + \
+        ["--num-scheduler-steps", f"{num_scheduler_steps}"] + \
+        test_args
+
+    # Spin up client/server & issue completion API requests.
+    # The default `max_wait_seconds` is 240, but it was empirically
+    # raised 3x to 720 *just for this test* due to observed timeouts
+    # in GHA CI.
+    ref_completions = await completions_with_server_args(
+        prompts=prompts,
+        model_name=model,
+        server_cli_args=server_args,
+        num_logprobs=None,
+        max_wait_seconds=5 * 240,
+        max_tokens=max_tokens)
+
+    test_completions = await completions_with_server_args(
+        prompts=prompts,
+        model_name=model,
+        server_cli_args=ms_server_args,
+        num_logprobs=None,
+        max_wait_seconds=5 * 240,
+        max_tokens=max_tokens)
+
+    # Assert multi-step scheduling produces identical tokens
+    # to single-step scheduling.
+    ref_generations = get_client_text_generations(ref_completions)
+    test_generations = get_client_text_generations(test_completions)
+
+    assert ref_generations == test_generations
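
For anyone who wants to exercise only the new smoke test locally, the following is one illustrative way to invoke it programmatically; it is not part of the commit and assumes the vLLM repository's test dependencies (pytest, pytest-asyncio) are installed, the repo root is the working directory, and enough GPUs are available for the (tp=1, pp=2) configuration.

# Illustrative sketch only: run the new PP smoke test via pytest's API.
import sys

import pytest

if __name__ == "__main__":
    # "-k" selects only the test added in this commit.
    sys.exit(
        pytest.main([
            "tests/multi_step/test_correctness_async_llm.py",
            "-k", "test_multi_step_pp_smoke",
            "-v",
        ]))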

tests/utils.py

Lines changed: 26 additions & 12 deletions
@@ -1,3 +1,4 @@
+import asyncio
 import functools
 import os
 import signal
@@ -7,7 +8,7 @@
 import warnings
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import openai
 import pytest
@@ -476,7 +477,8 @@ async def completions_with_server_args(
     server_cli_args: List[str],
     num_logprobs: Optional[int],
     max_wait_seconds: int = 240,
-) -> Completion:
+    max_tokens: Union[int, list] = 5,
+) -> List[Completion]:
     '''Construct a remote OpenAI server, obtain an async client to the
     server & invoke the completions API to obtain completions.
 
@@ -487,37 +489,49 @@ async def completions_with_server_args(
       num_logprobs: Number of logprobs to report (or `None`)
       max_wait_seconds: timeout interval for bringing up server.
                         Default: 240sec
+      max_tokens: max_tokens value for each of the given input prompts.
+                  If only one max_tokens value is given, the same value
+                  is used for all the prompts.
 
     Returns:
-      OpenAI Completion instance
+      List of OpenAI Completion instances, one per prompt
     '''
 
+    if isinstance(max_tokens, int):
+        max_tokens = [max_tokens] * len(prompts)
+
+    assert len(max_tokens) == len(prompts)
+
     outputs = None
     max_wait_seconds = 240 * 3  # 240 is default
     with RemoteOpenAIServer(model_name,
                             server_cli_args,
                             max_wait_seconds=max_wait_seconds) as server:
         client = server.get_async_client()
-        outputs = await client.completions.create(model=model_name,
-                                                   prompt=prompts,
-                                                   temperature=0,
-                                                   stream=False,
-                                                   max_tokens=5,
-                                                   logprobs=num_logprobs)
+        outputs = [
+            client.completions.create(model=model_name,
+                                      prompt=[p],
+                                      temperature=0,
+                                      stream=False,
+                                      max_tokens=max_tok,
+                                      logprobs=num_logprobs)
+            for p, max_tok in zip(prompts, max_tokens)
+        ]
+        outputs = await asyncio.gather(*outputs)
+
     assert outputs is not None, "Completion API call failed."
 
     return outputs
 
 
-def get_client_text_generations(completions: Completion) -> List[str]:
+def get_client_text_generations(completions: List[Completion]) -> List[str]:
     '''Extract generated tokens from the output of a
     request made to an Open-AI-protocol completions endpoint.
     '''
-    return [x.text for x in completions.choices]
+    assert all([len(x.choices) == 1 for x in completions])
+    return [x.choices[0].text for x in completions]
 
 
 def get_client_text_logprob_generations(
-        completions: Completion) -> List[TextTextLogprobs]:
+        completions: List[Completion]) -> List[TextTextLogprobs]:
     '''Operates on the output of a request made to an Open-AI-protocol
     completions endpoint; obtains top-rank logprobs for each token in
     each :class:`SequenceGroup`
@@ -526,4 +540,4 @@ def get_client_text_logprob_generations(
     text = ''.join(text_generations)
     return [(text_generations, text,
              (None if x.logprobs is None else x.logprobs.top_logprobs))
-            for x in completions.choices]
+            for completion in completions for x in completion.choices]
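
A minimal usage sketch of the updated helper follows; it is illustrative only and assumes the vLLM test environment (an event loop from pytest-asyncio, a servable model, and that the `tests` package is importable from the repo root). It shows that `max_tokens` now accepts either a per-prompt list or a single int that is broadcast to all prompts, and that the return value is now a list with one Completion per prompt.

# Illustrative usage sketch of the updated helper (assumptions noted above).
from tests.utils import (completions_with_server_args,
                         get_client_text_generations)


async def smoke(prompts, server_cli_args):
    # One request per prompt, each with its own token budget.
    per_prompt = await completions_with_server_args(
        prompts=prompts,
        model_name="JackFram/llama-160m",
        server_cli_args=server_cli_args,
        num_logprobs=None,
        max_tokens=[10 * (i + 1) for i in range(len(prompts))])

    # Still supported: a single int is broadcast to all prompts.
    uniform = await completions_with_server_args(
        prompts=prompts,
        model_name="JackFram/llama-160m",
        server_cli_args=server_cli_args,
        num_logprobs=None,
        max_tokens=5)

    # Each call returns List[Completion]; the accessor returns one
    # generated text per prompt.
    return (get_client_text_generations(per_prompt),
            get_client_text_generations(uniform))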

vllm/engine/output_processor/multi_step.py

Lines changed: 3 additions & 0 deletions
@@ -97,6 +97,9 @@ def process_outputs(self,
         assert len(seqs) == 1, (
             "Beam search not supported in multi-step decoding.")
         seq = seqs[0]
+        seq_id = seq.seq_id
+        assert all(
+            [seq_id == output.samples[0].parent_seq_id for output in outputs])
 
         if is_async:
             # Async case: We process tokens one by one. Here, we know the token

vllm/worker/model_runner.py

Lines changed: 9 additions & 1 deletion
@@ -1007,8 +1007,16 @@ def __init__(
 
         # Used to cache python objects
         self.inter_data_cache: Dict[int, PyObjectCache] = {}
+
+        # Using the SamplingMetadataCache with Pipeline-Parallel clobbers
+        # cached SequenceGroupToSample objects. With Pipeline-Parallel we
+        # have more than one Scheduler, which can lead to back-to-back
+        # prepare_model_inputs() calls; because the cache is reset on
+        # every prepare_model_inputs() call, the second call clobbers the
+        # SequenceGroupToSample objects the first call is still using.
         self.sampling_metadata_cache: SamplingMetadataCache = \
-            SamplingMetadataCache()
+            SamplingMetadataCache() \
+            if self.parallel_config.pipeline_parallel_size == 1 else None
 
     def load_model(self) -> None:
         logger.info("Starting to load model %s...", self.model_config.model)
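
To make the failure mode in the comment above concrete, here is a small self-contained sketch, plain Python rather than vLLM code, with hypothetical names (`TinyObjectCache`, `prepare_model_inputs_like`): a cache that is reset on every prepare call hands the same object to back-to-back callers, clobbering data the first caller still holds.

# Illustrative-only sketch of the reset-on-every-prepare hazard.
class TinyObjectCache:
    """Toy stand-in for a reset-on-reuse object cache (not vLLM's)."""

    def __init__(self):
        self._pool = [{"seq_ids": None}]  # a single reusable slot
        self._idx = 0

    def reset(self):
        self._idx = 0

    def get(self):
        obj = self._pool[self._idx % len(self._pool)]
        self._idx += 1
        return obj


cache = TinyObjectCache()


def prepare_model_inputs_like(seq_ids):
    cache.reset()              # reset happens on *every* prepare call
    obj = cache.get()
    obj["seq_ids"] = seq_ids
    return obj


in_flight = prepare_model_inputs_like([0, 1])  # scheduler 0, still executing
second = prepare_model_inputs_like([2, 3])     # scheduler 1, back-to-back call
assert in_flight is second                     # the same cached object
print(in_flight["seq_ids"])                    # [2, 3]: scheduler 0's data is gone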

vllm/worker/multi_step_model_runner.py

Lines changed: 10 additions & 2 deletions
@@ -326,7 +326,14 @@ def __init__(self, base_model_runner: GPUModelRunnerBase, *args, **kwargs):
         self.is_multi_step = self.scheduler_config.is_multi_step
         self.pinned_sampled_token_ids: Optional[torch.Tensor] = None
 
-        self.pythonization_cache = PythonizationCache()
+        # Using the PythonizationCache with Pipeline-Parallel clobbers the
+        # cached SequenceOutput and CompletionSequenceGroupOutput objects.
+        # When the cache is reset at the last step of a multi-step
+        # execution, other single-step/multi-step executions may still be
+        # in flight, and the current caching implementation does not check
+        # for this.
+        self.pythonization_cache = PythonizationCache() \
+            if self.parallel_config.pipeline_parallel_size == 1 else None
 
     @functools.cached_property
     def _copy_stream(self):
@@ -577,7 +584,8 @@ def execute_model(
         if model_input.is_last_step:
             outputs = self._final_process_outputs(
                 model_input, model_input.base_output_proc_callback)
-            self.pythonization_cache.reset()
+            if self.pythonization_cache:
+                self.pythonization_cache.reset()
             return outputs
 
         # should be [SamplerOutput]
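
The PythonizationCache change guards against the same class of bug as the sketch above, but triggered from the other direction: a reset issued at the last step of one execution while another execution still holds cached output objects. A compressed, illustrative-only sketch of that interleaving (again plain Python with hypothetical names, not vLLM code):

# Illustrative-only sketch of the reset-at-last-step hazard.
class TinySlotCache:
    """Toy append-style cache; reset() just rewinds the cursor."""

    def __init__(self):
        self._pool, self._idx = [], 0

    def get(self):
        if self._idx == len(self._pool):
            self._pool.append({"token": None})
        obj = self._pool[self._idx]
        self._idx += 1
        return obj

    def reset(self):
        self._idx = 0


cache = TinySlotCache()

ve0_output = cache.get()          # virtual engine 0 pythonizes a step
ve0_output["token"] = 111

cache.reset()                     # virtual engine 1 finishes its last step

ve1_output = cache.get()          # the next execution reuses the same slot
ve1_output["token"] = 222

assert ve0_output is ve1_output   # ve0's still-referenced output is clobbered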
