
Commit c2ec430

Authored by Varun Sundar Rabindranath (varun-sundar-rabindranath) and Varun Sundar Rabindranath
[Core] Multi-Step + Single Step Prefills via Chunked Prefill code path (#8378)
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
1 parent c5d5535 commit c2ec430

Showing 19 changed files with 514 additions and 109 deletions.

csrc/prepare_inputs/advance_step.cu

Lines changed: 1 addition & 1 deletion
@@ -52,7 +52,7 @@ __global__ void advance_step_flashattn_kernel(
     slot_mapping_ptr[cur_query_id] = slot_num;
 }

-inline void verify_tensor(std::string const& name, torch::Tensor& t,
+inline void verify_tensor(std::string const& name, torch::Tensor const& t,
                           int64_t const size_0, int64_t const size_1,
                           c10::ScalarType const type) {
   bool size_0_cond = true;

tests/multi_step/test_correctness_async_llm.py

Lines changed: 9 additions & 0 deletions
@@ -37,6 +37,7 @@
 @pytest.mark.parametrize("num_logprobs", [5])
 @pytest.mark.parametrize("is_async", [True])
 @pytest.mark.parametrize("attention_backend", ["FLASHINFER", "FLASH_ATTN"])
+@pytest.mark.parametrize("enable_chunked_prefill", [True, False])
 @pytest.mark.asyncio
 async def test_multi_step(
     example_prompts,
@@ -49,6 +50,7 @@ async def test_multi_step(
     is_async: bool,
     num_logprobs: Optional[int],
     attention_backend: str,
+    enable_chunked_prefill: bool,
     monkeypatch,
 ) -> None:
     """Test vLLM engine with multi-step scheduling in an OpenAI-protocol
@@ -74,6 +76,10 @@ async def test_multi_step(
       num_logprobs: corresponds to the `logprobs` argument to the OpenAI
         completions endpoint; `None` -> no logprobs
     """
+    if enable_chunked_prefill and \
+        (pp_size > 1 or attention_backend != "FLASH_ATTN"):
+        pytest.skip("Multi-step with Chunked-Prefill only supports "
+                    "PP=1 and FLASH_ATTN backend")

     override_backend_env_variable(monkeypatch, attention_backend)

@@ -93,6 +99,9 @@ async def test_multi_step(
     if eager_mode:
         ms_server_args.append("--enforce-eager")

+    if enable_chunked_prefill:
+        ms_server_args.append("--enable-chunked-prefill")
+
     distributed_args = [
         "--tensor-parallel-size",
         str(tp_size),
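
The flag combination this test exercises can be reproduced by hand against the OpenAI-compatible server. Below is a minimal sketch of the relevant flags; the model name and step count are illustrative placeholders (not values taken from the test), and the flags are assumed to be present in this vLLM version's CLI.

# Illustrative only: server flags for multi-step scheduling with chunked prefill.
# Model name and step count are placeholders.
ms_server_args = [
    "--model", "facebook/opt-125m",
    "--use-v2-block-manager",
    "--num-scheduler-steps", "8",    # multi-step: several GPU-side steps per scheduler call
    "--enable-chunked-prefill",      # route prefills through the chunked-prefill path
]
# e.g. passed to: python -m vllm.entrypoints.openai.api_server <ms_server_args...>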

tests/multi_step/test_correctness_llm.py

Lines changed: 4 additions & 0 deletions
@@ -16,6 +16,7 @@
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("tp_size", [1])
+@pytest.mark.parametrize("enable_chunked_prefill", [False, True])
 @pytest.mark.parametrize("max_tokens", [5])
 @pytest.mark.parametrize("enforce_eager", [True])
 @pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
@@ -28,6 +29,7 @@ def test_multi_step_llm(
     model: str,
     dtype: str,
     tp_size: int,
+    enable_chunked_prefill: bool,
     max_tokens: int,
     enforce_eager: int,
     num_scheduler_steps: int,
@@ -51,6 +53,7 @@ def test_multi_step_llm(
       model: model under test (same for single- and multi-step engines)
       dtype: tensor datatype for engine to utilize
       tp_size: degree of tensor-parallelism
+      enable_chunked_prefill: chunked-prefill on/off
       max_tokens: the maximum number of tokens to generate
       enforce_eager
       num_scheduler_steps: for multi-step scheduling, GPU-side steps per
@@ -73,6 +76,7 @@ def test_multi_step_llm(
             gpu_memory_utilization=0.7,
             tensor_parallel_size=tp_size,
             use_v2_block_manager=True,
+            enable_chunked_prefill=enable_chunked_prefill,
             num_scheduler_steps=num_scheduler_steps,
     ) as vllm_model:
         vllm_outputs = (vllm_model.generate_greedy(prompts, max_tokens)
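
The same engine configuration can also be exercised offline. The following is a hedged sketch assuming the LLM constructor forwards enable_chunked_prefill, num_scheduler_steps, and use_v2_block_manager to the engine the same way the test's vllm_runner does; the model, prompt, and step count are placeholders.

from vllm import LLM, SamplingParams

# Sketch only: multi-step decoding combined with chunked prefill.
llm = LLM(
    model="facebook/opt-125m",
    enforce_eager=True,
    use_v2_block_manager=True,
    enable_chunked_prefill=True,
    num_scheduler_steps=8,
)
outputs = llm.generate(["The future of AI is"],
                       SamplingParams(temperature=0.0, max_tokens=5))
print(outputs[0].outputs[0].text)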

vllm/attention/backends/flash_attn.py

Lines changed: 27 additions & 5 deletions
@@ -342,9 +342,13 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]:
         )
         return self._cached_decode_metadata

-    def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata",
+    def advance_step(self,
+                     model_input: "ModelInputForGPUWithSamplingMetadata",
                      sampled_token_ids: Optional[torch.Tensor],
-                     block_size: int, num_seqs: int, num_queries: int):
+                     block_size: int,
+                     num_seqs: int,
+                     num_queries: int,
+                     turn_prefills_into_decodes: bool = False):
         """
         Update metadata in-place to advance one decode step.
         """
@@ -355,6 +359,23 @@ def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata",
         assert num_seqs > num_queries
         assert self.use_cuda_graph

+        if turn_prefills_into_decodes:
+            # When Multi-Step is enabled with Chunked-Prefill, prefills and
+            # decodes are scheduled together. In the first step, all the
+            # prefills turn into decodes. This update reflects that
+            # conversion.
+            assert self.num_decode_tokens + self.num_prefills == num_seqs
+            self.num_decode_tokens += self.num_prefills
+            self.num_prefills = 0
+            self.num_prefill_tokens = 0
+            self.max_prefill_seq_len = 0
+            self.max_query_len = 1
+
+            self.slot_mapping = self.slot_mapping[:num_seqs]
+        else:
+            assert self.seq_lens is not None
+            assert self.max_decode_seq_len == max(self.seq_lens)
+
         assert self.num_prefills == 0
         assert self.num_prefill_tokens == 0
         assert self.num_decode_tokens == num_seqs
@@ -366,7 +387,6 @@ def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata",
         assert self.seq_lens_tensor.shape == (num_seqs, )
         assert self.max_query_len == 1
         assert self.max_prefill_seq_len == 0
-        assert self.max_decode_seq_len == max(self.seq_lens)

         assert self.query_start_loc is not None
         assert self.query_start_loc.shape == (num_queries + 1, )
@@ -706,8 +726,10 @@ def forward(

         num_prefill_tokens = attn_metadata.num_prefill_tokens
         num_decode_tokens = attn_metadata.num_decode_tokens
-        assert key.shape[0] == num_prefill_tokens + num_decode_tokens
-        assert value.shape[0] == num_prefill_tokens + num_decode_tokens
+        assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \
+            f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}"  # noqa
+        assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \
+            f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}"  # noqa

         # Query for decode. KV is not needed because it is already cached.
         decode_query = query[num_prefill_tokens:]
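
The counter updates guarded by turn_prefills_into_decodes can be illustrated in isolation. The toy class below is not the real FlashAttentionMetadata; it only mirrors the fields the diff touches, to show how, after the first multi-step iteration, every scheduled prefill is folded into the decode counts.

from dataclasses import dataclass

@dataclass
class ToyMetadata:
    # Simplified stand-in for the counters advance_step() updates.
    num_prefills: int
    num_prefill_tokens: int
    num_decode_tokens: int
    max_prefill_seq_len: int
    max_query_len: int

def fold_prefills_into_decodes(meta: ToyMetadata, num_seqs: int) -> None:
    # Each prefill produced one token in the first step and continues as a
    # decode, so the prefill counters collapse and every query has length 1.
    assert meta.num_decode_tokens + meta.num_prefills == num_seqs
    meta.num_decode_tokens += meta.num_prefills
    meta.num_prefills = 0
    meta.num_prefill_tokens = 0
    meta.max_prefill_seq_len = 0
    meta.max_query_len = 1

meta = ToyMetadata(num_prefills=2, num_prefill_tokens=37,
                   num_decode_tokens=3, max_prefill_seq_len=20,
                   max_query_len=20)
fold_prefills_into_decodes(meta, num_seqs=5)
assert meta.num_decode_tokens == 5 and meta.num_prefills == 0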

vllm/attention/backends/flashinfer.py

Lines changed: 12 additions & 8 deletions
@@ -410,18 +410,22 @@ def decode_metadata(self) -> Optional["FlashInferMetadata"]:

         return self

-    def advance_step(
-        self,
-        model_input: "ModelInputForGPUWithSamplingMetadata",
-        sampled_token_ids: Optional[torch.Tensor],
-        block_size: int,
-        num_seqs: int,
-        num_queries: int,
-    ):
+    def advance_step(self,
+                     model_input: "ModelInputForGPUWithSamplingMetadata",
+                     sampled_token_ids: Optional[torch.Tensor],
+                     block_size: int,
+                     num_seqs: int,
+                     num_queries: int,
+                     turn_prefills_into_decodes: bool = False):
         """
         Update metadata in-place to advance one decode step.
         """

+        assert not turn_prefills_into_decodes, \
+            ("Chunked prefill is not supported with flashinfer yet. "
+             "turn_prefills_into_decodes is a Multi-Step + Chunked-Prefill "
+             "specific parameter.")
+
         assert num_seqs > 0
         assert num_queries > 0
         assert model_input.attn_metadata is not None

vllm/config.py

Lines changed: 10 additions & 3 deletions
@@ -983,9 +983,16 @@ def __init__(self,
                  policy: str = "fcfs") -> None:
         if max_num_batched_tokens is None:
             if enable_chunked_prefill:
-                # It is the values that have the best balance between ITL
-                # and TTFT on A100. Note it is not optimized for throughput.
-                max_num_batched_tokens = 512
+                if num_scheduler_steps > 1:
+                    # Multi-step Chunked-Prefill doesn't allow prompt-chunking
+                    # for now. Have max_num_batched_tokens set to max_model_len
+                    # so we don't reject sequences on account of a short
+                    # max_num_batched_tokens.
+                    max_num_batched_tokens = max(max_model_len, 2048)
+                else:
+                    # This is the value that has the best balance between ITL
+                    # and TTFT on A100. Note it is not optimized for throughput.
+                    max_num_batched_tokens = 512
             else:
                 # If max_model_len is too short, use 2048 as the default value
                 # for higher throughput.
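
A standalone restatement of the new defaulting rule (not the actual SchedulerConfig code) makes the branch easier to see; the model-length values used in the asserts are examples only.

def default_max_num_batched_tokens(enable_chunked_prefill: bool,
                                   num_scheduler_steps: int,
                                   max_model_len: int) -> int:
    # Mirrors the defaulting logic in the diff above, for illustration only.
    if enable_chunked_prefill:
        if num_scheduler_steps > 1:
            # Multi-step + chunked-prefill does no prompt chunking, so the
            # token budget must be able to fit a whole prompt.
            return max(max_model_len, 2048)
        # Chunked-prefill-only default, tuned for ITL/TTFT balance on A100.
        return 512
    # Branch unchanged by this commit: favor throughput.
    return max(max_model_len, 2048)

assert default_max_num_batched_tokens(True, 8, 4096) == 4096
assert default_max_num_batched_tokens(True, 1, 4096) == 512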

vllm/core/block/block_table.py

Lines changed: 9 additions & 4 deletions
@@ -55,9 +55,12 @@ def __init__(
         self._num_full_slots = self._get_num_token_ids()

     @staticmethod
-    def get_num_required_blocks(token_ids: List[int], block_size: int) -> int:
+    def get_num_required_blocks(token_ids: List[int],
+                                block_size: int,
+                                num_lookahead_slots: int = 0) -> int:
         """Calculates the minimum number of blocks required to store a given
-        sequence of token IDs.
+        sequence of token IDs along with any look-ahead slots that may be
+        required (like in multi-step + chunked-prefill).

         This assumes worst-case scenario, where every block requires a new
         allocation (e.g. ignoring prefix caching).
@@ -66,12 +69,14 @@ def get_num_required_blocks(token_ids: List[int], block_size: int) -> int:
             token_ids (List[int]): The sequence of token IDs to be stored.
             block_size (int): The maximum number of tokens that can be stored in
                 a single block.
+            num_lookahead_slots (int): look-ahead slots that the sequence may
+                require.

         Returns:
             int: The minimum number of blocks required to store the given
-                sequence of token IDs.
+                sequence of token IDs along with any required look-ahead slots.
         """
-        return cdiv(len(token_ids), block_size)
+        return cdiv(len(token_ids) + num_lookahead_slots, block_size)

     def allocate(self,
                  token_ids: List[int],
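
The behavioral change here is one line of arithmetic: look-ahead slots are counted as if they were tokens before rounding up to whole blocks. A small self-contained sketch of that ceiling-division math (cdiv is reimplemented locally for illustration):

def cdiv(a: int, b: int) -> int:
    # Ceiling division, as used by get_num_required_blocks().
    return -(a // -b)

def num_required_blocks(num_tokens: int, block_size: int,
                        num_lookahead_slots: int = 0) -> int:
    # Worst case: every block is freshly allocated (no prefix caching).
    return cdiv(num_tokens + num_lookahead_slots, block_size)

# 30 prompt tokens in 16-token blocks fit in 2 blocks on their own, but with
# 8 look-ahead slots (e.g. for multi-step decoding) the allocation must cover
# 38 slots, i.e. 3 blocks.
assert num_required_blocks(30, 16) == 2
assert num_required_blocks(30, 16, num_lookahead_slots=8) == 3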

vllm/core/block_manager_v1.py

Lines changed: 6 additions & 1 deletion
@@ -281,10 +281,15 @@ def __init__(
     def _get_seq_num_required_blocks(self, seq: Optional[Sequence]) -> int:
         return 0 if seq is None else seq.n_blocks

-    def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
+    def can_allocate(self,
+                     seq_group: SequenceGroup,
+                     num_lookahead_slots: int = 0) -> AllocStatus:
         # FIXME(woosuk): Here we assume that all sequences in the group share
         # the same prompt. This may not be true for preempted sequences.

+        assert (num_lookahead_slots == 0
+                ), "lookahead allocation not supported in BlockSpaceManagerV1"
+
         check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group)

         self_num_required_blocks = self._get_seq_num_required_blocks(

vllm/core/block_manager_v2.py

Lines changed: 4 additions & 1 deletion
@@ -107,7 +107,9 @@ def __init__(
         self._last_access_blocks_tracker = LastAccessBlocksTracker(
             self.block_allocator)

-    def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
+    def can_allocate(self,
+                     seq_group: SequenceGroup,
+                     num_lookahead_slots: int = 0) -> AllocStatus:
         # FIXME(woosuk): Here we assume that all sequences in the group share
         # the same prompt. This may not be true for preempted sequences.

@@ -117,6 +119,7 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
         num_required_blocks = BlockTable.get_num_required_blocks(
             seq.get_token_ids(),
             block_size=self.block_size,
+            num_lookahead_slots=num_lookahead_slots,
         )

         if seq_group.is_encoder_decoder():
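
To see how the extra argument changes admission decisions, here is a hypothetical, heavily simplified version of the v2 check (the field names and watermark logic are paraphrased for illustration, not copied from the real BlockSpaceManagerV2):

from enum import Enum

class AllocStatus(Enum):
    OK = 0
    LATER = 1
    NEVER = 2

def can_allocate_sketch(num_prompt_tokens: int, num_lookahead_slots: int,
                        block_size: int, num_free_gpu_blocks: int,
                        num_total_gpu_blocks: int,
                        watermark_blocks: int) -> AllocStatus:
    # Look-ahead slots now count toward the blocks a new sequence group
    # needs before it can be admitted.
    required = -((num_prompt_tokens + num_lookahead_slots) // -block_size)
    if num_total_gpu_blocks - required < watermark_blocks:
        return AllocStatus.NEVER    # will never fit on this GPU
    if num_free_gpu_blocks - required >= watermark_blocks:
        return AllocStatus.OK       # fits right now
    return AllocStatus.LATER        # may fit once blocks are freed

# A 96-token prompt needs 6 blocks of size 16 on its own, but with 7
# look-ahead slots it needs 7, which can tip the decision from OK to LATER.
assert can_allocate_sketch(96, 0, 16, num_free_gpu_blocks=6,
                           num_total_gpu_blocks=100,
                           watermark_blocks=0) is AllocStatus.OK
assert can_allocate_sketch(96, 7, 16, num_free_gpu_blocks=6,
                           num_total_gpu_blocks=100,
                           watermark_blocks=0) is AllocStatus.LATER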

vllm/core/embedding_model_block_manager.py

Lines changed: 3 additions & 1 deletion
@@ -21,7 +21,9 @@ def __init__(
     ) -> None:
         pass

-    def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
+    def can_allocate(self,
+                     seq_group: SequenceGroup,
+                     num_lookahead_slots: int = 0) -> AllocStatus:
         # Always return OK for dummy purposes
         return AllocStatus.OK

vllm/core/interfaces.py

Lines changed: 3 additions & 1 deletion
@@ -44,7 +44,9 @@ def get_block_space_manager_class(version: str):
         raise ValueError(f"Unknown version {version=}")

     @abstractmethod
-    def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
+    def can_allocate(self,
+                     seq_group: SequenceGroup,
+                     num_lookahead_slots: int = 0) -> AllocStatus:
         pass

     @abstractmethod
