
Commit cb3b2b9

Authored by varun-sundar-rabindranath (Varun Sundar Rabindranath) and Varun Sundar Rabindranath
[Bugfix] Fix incorrect updates to num_computed_tokens in multi-step scheduling (#9038)
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
1 parent fdf59d3 commit cb3b2b9

File tree: 6 files changed, +179 −110 lines
New test file
Lines changed: 81 additions & 0 deletions
@@ -0,0 +1,81 @@
+import pytest
+
+from tests.conftest import VllmRunner
+from tests.core.utils import create_dummy_prompt
+from vllm.engine.llm_engine import LLMEngine
+from vllm.platforms import current_platform
+from vllm.sequence import SequenceGroup
+
+MODEL = "JackFram/llama-160m"
+
+
+def add_seq_group_to_engine(engine: LLMEngine, seq_group: SequenceGroup):
+    scheduler = engine.scheduler[0]
+    scheduler.add_seq_group(seq_group)
+
+
+@pytest.mark.parametrize("num_scheduler_steps", [1, 8])
+@pytest.mark.parametrize("enable_chunked_prefill", [False, True])
+@pytest.mark.parametrize("enforce_eager", [False, True])
+def test_num_computed_tokens_update(num_scheduler_steps: int,
+                                    enable_chunked_prefill: bool,
+                                    enforce_eager: bool):
+
+    is_multi_step = num_scheduler_steps > 1
+    is_multi_step_chunked_prefill = is_multi_step and enable_chunked_prefill
+
+    if is_multi_step_chunked_prefill and current_platform.is_rocm():
+        pytest.skip("Multi-step with Chunked-Prefill does not support "
+                    "rocm_flash_attn backend")
+
+    # Make a vllm engine
+    runner = VllmRunner(model_name=MODEL,
+                        gpu_memory_utilization=0.7,
+                        use_v2_block_manager=True,
+                        num_scheduler_steps=num_scheduler_steps,
+                        enable_chunked_prefill=enable_chunked_prefill,
+                        enforce_eager=enforce_eager)
+    engine: LLMEngine = runner.model.llm_engine
+
+    # In multi-step + chunked-prefill there is no separate single prompt step.
+    # What is scheduled will run for num_scheduler_steps always.
+    num_prompt_steps = num_scheduler_steps \
+        if is_multi_step_chunked_prefill else 1
+
+    num_output_tokens_list = [4, 8, 12, 15, 16, 17]
+
+    # Create sequence and add to engine
+    prompt_len = 10
+
+    for req_idx, num_output_tokens in enumerate(num_output_tokens_list):
+        seq, seq_group = create_dummy_prompt(request_id=str(req_idx),
+                                             prompt_length=prompt_len,
+                                             min_tokens=num_output_tokens,
+                                             max_tokens=num_output_tokens)
+        add_seq_group_to_engine(engine, seq_group)
+
+        assert seq.data.get_num_computed_tokens() == 0
+
+        for _ in range(num_prompt_steps):
+            # prompt steps
+            engine.step()
+
+        if not seq.is_finished():
+            prompt_num_computed_tokens = seq.data.get_num_computed_tokens()
+            # Test correctness of num_computed_tokens after the prompt steps
+            assert prompt_num_computed_tokens == \
+                prompt_len + num_prompt_steps - 1
+
+            decode_step_counter = 0
+            while not seq.is_finished():
+                # Test correctness of num_computed_tokens after the decode steps
+                assert seq.data.get_num_computed_tokens(
+                ) == prompt_num_computed_tokens + decode_step_counter
+                for _ in range(num_scheduler_steps):
+                    # decode step
+                    engine.step()
+                decode_step_counter += 1
+
+        # Test correctness of num_computed_tokens after the sequence finish.
+        assert seq.data.get_num_computed_tokens(
+        ) == prompt_len + num_output_tokens - 1
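
For readers skimming the diff, the arithmetic the test asserts can be summarized in a small standalone sketch (illustrative only; the helper below is not part of vLLM, and the numbers simply mirror the assertions above):

    def expected_num_computed_tokens(prompt_len: int, num_generated: int) -> int:
        # Every prompt token plus every generated token except the most
        # recent one has been through a forward pass, hence the "- 1".
        return prompt_len + num_generated - 1

    # Mirrors the test's final assertion for prompt_len=10, num_output_tokens=4.
    assert expected_num_computed_tokens(10, 4) == 13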

tests/core/utils.py

Lines changed: 5 additions & 1 deletion
@@ -16,6 +16,8 @@ def create_dummy_prompt(
     use_beam_search: bool = False,
     best_of: int = 1,
     prompt_tokens: Optional[List[int]] = None,
+    min_tokens: int = 0,
+    max_tokens: int = 16,
 ) -> Tuple[Sequence, SequenceGroup]:
     if not block_size:
         block_size = prompt_length
@@ -36,7 +38,9 @@ def create_dummy_prompt(
                               arrival_time=time.time(),
                               sampling_params=SamplingParams(
                                   use_beam_search=use_beam_search,
-                                  best_of=best_of),
+                                  best_of=best_of,
+                                  max_tokens=max_tokens,
+                                  min_tokens=min_tokens),
                               lora_request=lora_request)

    return prompt, seq_group
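
The new keyword arguments let tests pin the exact number of generated tokens. A hedged usage sketch, mirroring how the new test above calls the helper (not additional repository code):

    from tests.core.utils import create_dummy_prompt

    # Force exactly 8 output tokens by setting min_tokens == max_tokens.
    seq, seq_group = create_dummy_prompt(request_id="0",
                                         prompt_length=10,
                                         min_tokens=8,
                                         max_tokens=8)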

vllm/attention/backends/rocm_flash_attn.py

Lines changed: 12 additions & 2 deletions
@@ -191,12 +191,22 @@ def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
         )
         return self._cached_decode_metadata

-    def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata",
+    def advance_step(self,
+                     model_input: "ModelInputForGPUWithSamplingMetadata",
                      sampled_token_ids: Optional[torch.Tensor],
-                     block_size: int, num_seqs: int, num_queries: int):
+                     block_size: int,
+                     num_seqs: int,
+                     num_queries: int,
+                     turn_prefills_into_decodes: bool = False):
         """
         Update metadata in-place to advance one decode step.
         """
+
+        assert not turn_prefills_into_decodes, \
+            ("Chunked prefill is not supported with rocm_flash_attn yet."
+             "turn_prefills_into_decodes is a Multi-Step + Chunked-Prefill "
+             "specific parameter.")
+
         # When using cudagraph, the num_seqs is padded to the next captured
         # batch sized, but num_queries tracks the actual number of requests in
         # the batch. For --enforce-eager mode, num_seqs == num_queries
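
The ROCm backend accepts the new turn_prefills_into_decodes keyword, presumably so advance_step keeps a uniform signature across backends, but it rejects a True value outright. A minimal standalone mirror of the guard (illustrative function name, not vLLM code):

    def advance_step_guard(turn_prefills_into_decodes: bool = False) -> None:
        # Multi-Step + Chunked-Prefill is not supported by rocm_flash_attn,
        # so the flag must stay at its default.
        assert not turn_prefills_into_decodes, (
            "turn_prefills_into_decodes is a Multi-Step + Chunked-Prefill "
            "specific parameter and is unsupported on rocm_flash_attn.")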

vllm/engine/llm_engine.py

Lines changed: 66 additions & 90 deletions
@@ -962,6 +962,45 @@ def _process_sequence_group_outputs(

         return

+    def _update_num_computed_tokens_for_multi_step_prefill(
+            self, seq_group: SequenceGroup,
+            seq_group_meta: SequenceGroupMetadata,
+            is_first_step_output: Optional[bool]):
+        """
+        This function updates num_computed_tokens for prompt sequences
+        when Multi-Step is enabled.
+
+        seq_group: SequenceGroup to update the num_computed_tokens for.
+        seq_group_meta: Metadata of the given SequenceGroup.
+        is_first_step_output: Optional[bool] -
+            When available, is_first_step_output indicates if the appended
+            output token is the output of the first-step in multi-step.
+            A value of None indicates that outputs from all steps in
+            multi-step are submitted in a single burst.
+        """
+
+        assert self.scheduler_config.is_multi_step
+
+        if not seq_group_meta.is_prompt:
+            # num_computed_token updates for multi-step decodes happen after
+            # the tokens are appended to the sequence.
+            return
+
+        do_update: bool = False
+        if self.scheduler_config.chunked_prefill_enabled:
+            # In multi-step + chunked-prefill case, the prompt sequences
+            # that are scheduled are fully processed in the first step.
+            do_update = is_first_step_output is None or is_first_step_output
+        else:
+            # Normal multi-step decoding case. In this case prompt-sequences
+            # are actually single-stepped. Always update in this case.
+            assert seq_group.state.num_steps == 1
+            do_update = True
+
+        if do_update:
+            seq_group.update_num_computed_tokens(
+                seq_group_meta.token_chunk_size)
+
     def _process_model_outputs(self,
                                ctx: SchedulerContext,
                                request_id: Optional[str] = None) -> None:
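
The branch logic of the new helper can be read as a small decision table. A pure-Python mirror for readability (illustrative names only, not part of the engine):

    from typing import Optional

    def should_update_prompt_tokens(is_prompt: bool,
                                    chunked_prefill_enabled: bool,
                                    is_first_step_output: Optional[bool]) -> bool:
        if not is_prompt:
            # Decode updates happen elsewhere, after tokens are appended.
            return False
        if chunked_prefill_enabled:
            # Scheduled prompts are fully processed in the first step; None
            # means outputs from all steps arrived in a single burst.
            return is_first_step_output is None or is_first_step_output
        # Plain multi-step: prompts are single-stepped, so always update.
        return True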
@@ -972,64 +1011,6 @@ def _process_model_outputs(self,
         request_id: If provided, then only this request is going to be processed
         """

-        def update_prefill_num_computed_tokens(
-                seq_group: SequenceGroup,
-                seq_group_meta: SequenceGroupMetadata, num_outputs: int,
-                is_first_step_output: Optional[bool]) -> None:
-            """
-            When multi-step and chunked-prefill are enabled together, the
-            prefill sequence scheduled for multi-step execution turn into
-            decodes in the first step itself. This function accounts
-            for that conversion.
-
-            seq_group: SequenceGroup - A prefill seq_group
-            seq_group_meta: SequenceGroupMetadata - Metadata of the given
-                prefill seq_group
-            num_outputs: int - number of output tokens being processed for the
-                given seq_group
-            is_first_step_output: Optional[bool] -
-                If multi-step is enabled and num_outputs is 1, this value
-                indicates if this outputs belongs to the first step in the
-                multi-step.
-                If multi-step is enabled and num_outputs > 1, this value
-                must be None, as num_outputs > 1 indicates that outputs from
-                all the steps in multi-step are submitted in a single burst.
-                When multi-step is disabled, this value is always True.
-            """
-
-            assert seq_group_meta.is_prompt
-
-            token_chunk_size = seq_group_meta.token_chunk_size
-
-            if num_outputs == 1:
-                assert is_first_step_output is not None
-
-                if seq_group_meta.state.num_steps == 1:
-                    assert is_first_step_output is True
-                    seq_group.update_num_computed_tokens(token_chunk_size)
-                    return
-
-                # multi-step prefill is only supported when multi-step is
-                # enabled with chunked prefill
-                assert self.scheduler_config.is_multi_step and \
-                    self.scheduler_config.chunked_prefill_enabled
-                if is_first_step_output is True:
-                    # This sequence is a prompt during the first step only.
-                    seq_group.update_num_computed_tokens(token_chunk_size)
-                return
-
-            assert is_first_step_output is None
-
-            # multi-step prefill is only supported when multi-step is
-            # enabled with chunked prefill. Outputs from all the steps are
-            # submitted in a single burst.
-            assert self.scheduler_config.is_multi_step and \
-                self.scheduler_config.chunked_prefill_enabled
-            assert num_outputs == seq_group_meta.state.num_steps, \
-                f"#outputs {len(outputs)} - num steps {seq_group_meta.state.num_steps}" #noqa
-            # This sequence is a prompt during the first step only.
-            seq_group.update_num_computed_tokens(token_chunk_size)
-
         now = time.time()

         if len(ctx.output_queue) == 0:
@@ -1090,7 +1071,7 @@ def update_prefill_num_computed_tokens(
             seq_group_meta = seq_group_metadata_list[i]
             scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]

-            seq_group = scheduled_seq_group.seq_group
+            seq_group: SequenceGroup = scheduled_seq_group.seq_group

             if seq_group.is_finished():
                 finished_before.append(i)
@@ -1101,14 +1082,14 @@ def update_prefill_num_computed_tokens(
             else:
                 output = [outputs_by_sequence_group[0][i]]

-            if not is_async and seq_group_meta.is_prompt:
-                # Updates for all decodes happen when we actually append the
-                # token ids to the seq in process_outputs.
-                update_prefill_num_computed_tokens(seq_group, seq_group_meta,
-                                                   len(output),
-                                                   is_first_step_output)
-            elif not is_async:
-                seq_group.update_num_computed_tokens(1)
+            if not is_async:
+                if self.scheduler_config.is_multi_step:
+                    # Updates happen only if the sequence is prefill
+                    self._update_num_computed_tokens_for_multi_step_prefill(
+                        seq_group, seq_group_meta, is_first_step_output)
+                else:
+                    seq_group.update_num_computed_tokens(
+                        seq_group_meta.token_chunk_size)

             if outputs:
                 for o in outputs:
@@ -1132,16 +1113,8 @@ def update_prefill_num_computed_tokens(
             else:
                 self.output_processor.process_prompt_logprob(seq_group, output)
                 if seq_group_meta.do_sample:
-                    output_token_num = self.output_processor.process_outputs(
+                    self.output_processor.process_outputs(
                         seq_group, output, is_async)
-                    if self.speculative_config:
-                        # We -1 here because we always
-                        # (w/o speculative decoding) add the number of
-                        # computed tokens by one in the decoding phase.
-                        # Therefore, we remove that one token that
-                        # is already added.
-                        seq_group.update_num_computed_tokens(output_token_num -
-                                                             1)

             if seq_group.is_finished():
                 finished_now.append(i)
@@ -1250,20 +1223,15 @@ def _advance_to_next_step(
             if seq_group.is_finished():
                 continue

-            if seq_group_metadata.is_prompt:
-                if self.scheduler_config.is_multi_step and \
-                    self.scheduler_config.chunked_prefill_enabled:
-                    # Prompts are scheduled in multi-step only when
-                    # chunking is enabled. These prompts turn into
-                    # decodes after the very first step. Therefore,
-                    # we skip the update to the num_computed_tokens
-                    # here.
-                    seq_group.update_num_computed_tokens(1)
-                else:
-                    seq_group.update_num_computed_tokens(
-                        seq_group_metadata.token_chunk_size)
+            if self.scheduler_config.is_multi_step:
+                # Updates happen only if the sequence is prefill
+                self._update_num_computed_tokens_for_multi_step_prefill(
+                    seq_group, seq_group_metadata,
+                    seq_group.state.num_steps == 1)
             else:
-                seq_group.update_num_computed_tokens(1)
+                seq_group.update_num_computed_tokens(
+                    seq_group_metadata.token_chunk_size)
+
             if seq_group_metadata.do_sample:
                 assert len(sequence_group_outputs.samples) == 1, (
                     "Async output processor expects a single sample"
@@ -1273,7 +1241,15 @@ def _advance_to_next_step(

                 assert len(seq_group.seqs) == 1
                 seq = seq_group.seqs[0]
-                seq.append_token_id(sample.output_token, sample.logprobs)
+
+                if self.scheduler_config.is_multi_step:
+                    is_prefill_append = seq.data.get_num_uncomputed_tokens(
+                    ) == 0
+                    seq.append_token_id(sample.output_token, sample.logprobs)
+                    if not is_prefill_append:
+                        seq_group.update_num_computed_tokens(1)
+                else:
+                    seq.append_token_id(sample.output_token, sample.logprobs)

     def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
         """Performs one decoding iteration and returns newly generated results.

vllm/engine/output_processor/interfaces.py

Lines changed: 2 additions & 6 deletions
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Callable, List, Optional
+from typing import Callable, List

 from vllm.config import SchedulerConfig
 from vllm.core.scheduler import Scheduler
@@ -58,14 +58,10 @@ def create_output_processor(
     @abstractmethod
     def process_outputs(self, sequence_group: SequenceGroup,
                         outputs: List[SequenceGroupOutput],
-                        is_async: bool) -> Optional[int]:
+                        is_async: bool) -> None:
        """Process new token ids for the sequence group. Handles logic such as
        detokenization, stop checking, and freeing/forking sequences in the
        scheduler.
-
-        Return the number of new tokens generated in the sequence group.
-        The returned value is optional because it is only used for
-        speculative decoding mqa scorer.
        """
        pass
