From f4a28868a116c9d2761bcb6aafe9554fa41df7ec Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Thu, 5 Sep 2024 19:35:09 +0000 Subject: [PATCH 01/10] Multi-Step Chunked-Prefill Support --- csrc/prepare_inputs/advance_step.cu | 6 +- .../multi_step/test_correctness_async_llm.py | 9 ++ tests/multi_step/test_correctness_llm.py | 4 + vllm/attention/backends/flash_attn.py | 32 +++- vllm/attention/backends/flashinfer.py | 20 ++- vllm/config.py | 13 +- vllm/core/block/block_table.py | 13 +- vllm/core/block_manager_v1.py | 7 +- vllm/core/block_manager_v2.py | 5 +- vllm/core/embedding_model_block_manager.py | 4 +- vllm/core/interfaces.py | 4 +- vllm/core/scheduler.py | 127 +++++++++++---- vllm/engine/arg_utils.py | 10 +- vllm/engine/async_llm_engine.py | 9 +- vllm/engine/llm_engine.py | 130 +++++++++++++-- vllm/engine/output_processor/multi_step.py | 1 + vllm/model_executor/sampling_metadata.py | 49 ++++++ vllm/sequence.py | 46 +++++- vllm/worker/model_runner_base.py | 6 + vllm/worker/multi_step_model_runner.py | 148 +++++++++++++++--- vllm/worker/multi_step_worker.py | 5 +- 21 files changed, 542 insertions(+), 106 deletions(-) diff --git a/csrc/prepare_inputs/advance_step.cu b/csrc/prepare_inputs/advance_step.cu index a9d08ca0dc14..dd03968f950d 100644 --- a/csrc/prepare_inputs/advance_step.cu +++ b/csrc/prepare_inputs/advance_step.cu @@ -52,7 +52,7 @@ __global__ void advance_step_flashattn_kernel( slot_mapping_ptr[cur_query_id] = slot_num; } -inline void verify_tensor(std::string const& name, torch::Tensor& t, +inline void verify_tensor(std::string const& name, torch::Tensor const& t, int64_t const size_0, int64_t const size_1, c10::ScalarType const type) { bool size_0_cond = true; @@ -211,7 +211,7 @@ void advance_step_flashinfer( printf(" num_seqs = %d\n", num_seqs); printf(" num_queries = %d\n", num_queries); printf(" block_size = %d\n", block_size); - printf(" block_tables.stride(0) = %d\n", block_tables.stride(0)); + printf(" block_tables.stride(0) = %ld\n", block_tables.stride(0)); } // Verify all tensors verify_tensor("input_tokens", input_tokens, num_seqs, -1, at::kLong); @@ -303,4 +303,4 @@ void advance_step_flashinfer( num_seqs, num_queries, block_size, input_tokens, sampled_token_ids, input_positions, seq_lens, slot_mapping, block_tables, paged_kv_indices, paged_kv_indptr, paged_kv_last_page_len, block_table_bound); -} \ No newline at end of file +} diff --git a/tests/multi_step/test_correctness_async_llm.py b/tests/multi_step/test_correctness_async_llm.py index a75a671e57f7..615549f2134a 100644 --- a/tests/multi_step/test_correctness_async_llm.py +++ b/tests/multi_step/test_correctness_async_llm.py @@ -37,6 +37,7 @@ @pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("is_async", [True]) @pytest.mark.parametrize("attention_backend", ["FLASHINFER", "FLASH_ATTN"]) +@pytest.mark.parametrize("enable_chunked_prefill", [True, False]) @pytest.mark.asyncio async def test_multi_step( example_prompts, @@ -49,6 +50,7 @@ async def test_multi_step( is_async: bool, num_logprobs: Optional[int], attention_backend: str, + enable_chunked_prefill: bool, monkeypatch, ) -> None: """Test vLLM engine with multi-step scheduling in an OpenAI-protocol @@ -74,6 +76,10 @@ async def test_multi_step( num_logprobs: corresponds to the `logprobs` argument to the OpenAI completions endpoint; `None` -> no logprobs """ + if enable_chunked_prefill and \ + (pp_size > 1 or attention_backend != "FLASH_ATTN"): + pytest.skip("Multi-step with Chunked-Prefill only supports" + "PP=1 and FLASH_ATTN 
backend") override_backend_env_variable(monkeypatch, attention_backend) @@ -93,6 +99,9 @@ async def test_multi_step( if eager_mode: ms_server_args.append("--enforce-eager") + if enable_chunked_prefill: + ms_server_args.append("--enable-chunked-prefill") + distributed_args = [ "--tensor-parallel-size", str(tp_size), diff --git a/tests/multi_step/test_correctness_llm.py b/tests/multi_step/test_correctness_llm.py index c5dc81cc2562..ff413e8e2da3 100644 --- a/tests/multi_step/test_correctness_llm.py +++ b/tests/multi_step/test_correctness_llm.py @@ -16,6 +16,7 @@ @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("tp_size", [1]) +@pytest.mark.parametrize("enable_chunked_prefill", [False, True]) @pytest.mark.parametrize("max_tokens", [5]) @pytest.mark.parametrize("enforce_eager", [True]) @pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS) @@ -28,6 +29,7 @@ def test_multi_step_llm( model: str, dtype: str, tp_size: int, + enable_chunked_prefill: bool, max_tokens: int, enforce_eager: int, num_scheduler_steps: int, @@ -51,6 +53,7 @@ def test_multi_step_llm( model: model under test (same for single- and multi-step engines) dtype: tensor datatype for engine to utilize tp_size: degree of tensor-parallelism + enable_chunked_prefill: chunked-prefill on/off max_tokens: the maximum number of tokens to generate enforce_eager num_scheduler_steps: for multi-step scheduling, GPU-side steps per @@ -73,6 +76,7 @@ def test_multi_step_llm( gpu_memory_utilization=0.7, tensor_parallel_size=tp_size, use_v2_block_manager=True, + enable_chunked_prefill=enable_chunked_prefill, num_scheduler_steps=num_scheduler_steps, ) as vllm_model: vllm_outputs = (vllm_model.generate_greedy(prompts, max_tokens) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 084e8113cd42..3a7a8cd62ba5 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -342,9 +342,13 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]: ) return self._cached_decode_metadata - def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata", + def advance_step(self, + model_input: "ModelInputForGPUWithSamplingMetadata", sampled_token_ids: Optional[torch.Tensor], - block_size: int, num_seqs: int, num_queries: int): + block_size: int, + num_seqs: int, + num_queries: int, + turn_prefills_into_decodes: bool = False): """ Update metadata in-place to advance one decode step. """ @@ -355,6 +359,23 @@ def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata", assert num_seqs > num_queries assert self.use_cuda_graph + if turn_prefills_into_decodes: + # When Mutli-Step is enabled with Chunked-Prefill, prefills and + # decodes are scheduled together. In the first step, all the + # prefills turn into decodes. This update reflects that + # conversion. 
+ assert self.num_decode_tokens + self.num_prefills == num_seqs + self.num_decode_tokens += self.num_prefills + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.max_prefill_seq_len = 0 + self.max_query_len = 1 + + self.slot_mapping = self.slot_mapping[:num_seqs] + else: + assert self.seq_lens is not None + assert self.max_decode_seq_len == max(self.seq_lens) + assert self.num_prefills == 0 assert self.num_prefill_tokens == 0 assert self.num_decode_tokens == num_seqs @@ -366,7 +387,6 @@ def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata", assert self.seq_lens_tensor.shape == (num_seqs, ) assert self.max_query_len == 1 assert self.max_prefill_seq_len == 0 - assert self.max_decode_seq_len == max(self.seq_lens) assert self.query_start_loc is not None assert self.query_start_loc.shape == (num_queries + 1, ) @@ -704,8 +724,10 @@ def forward( num_prefill_tokens = attn_metadata.num_prefill_tokens num_decode_tokens = attn_metadata.num_decode_tokens - assert key.shape[0] == num_prefill_tokens + num_decode_tokens - assert value.shape[0] == num_prefill_tokens + num_decode_tokens + assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \ + f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa + assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \ + f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa # Query for decode. KV is not needed because it is already cached. decode_query = query[num_prefill_tokens:] diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 3a602fbfbbc0..fa2f70dde9f2 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -410,18 +410,22 @@ def decode_metadata(self) -> Optional["FlashInferMetadata"]: return self - def advance_step( - self, - model_input: "ModelInputForGPUWithSamplingMetadata", - sampled_token_ids: Optional[torch.Tensor], - block_size: int, - num_seqs: int, - num_queries: int, - ): + def advance_step(self, + model_input: "ModelInputForGPUWithSamplingMetadata", + sampled_token_ids: Optional[torch.Tensor], + block_size: int, + num_seqs: int, + num_queries: int, + turn_prefills_into_decodes: bool = False): """ Update metadata in-place to advance one decode step. """ + assert not turn_prefills_into_decodes, \ + ("Chunked prefill is not supported with flashinfer yet." + "turn_prefills_into_decodes is a Multi-Step + Chunked-Prefill " + "specific parameter.") + assert num_seqs > 0 assert num_queries > 0 assert model_input.attn_metadata is not None diff --git a/vllm/config.py b/vllm/config.py index 8c65d99c4465..2d65ea2aff2b 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -964,9 +964,16 @@ def __init__(self, send_delta_data: bool = False) -> None: if max_num_batched_tokens is None: if enable_chunked_prefill: - # It is the values that have the best balance between ITL - # and TTFT on A100. Note it is not optimized for throughput. - max_num_batched_tokens = 512 + if num_scheduler_steps > 1: + # Multi-step Chunked-Prefill doesn't allow prompt-chunking + # for now. Have max_num_batched_tokens set to max_model_len + # so we don't reject sequences on account of a short + # max_num_batched_tokens. + max_num_batched_tokens = max(max_model_len, 2048) + else: + # It is the values that have the best balance between ITL + # and TTFT on A100. Note it is not optimized for throughput. 
+ max_num_batched_tokens = 512 else: # If max_model_len is too short, use 2048 as the default value # for higher throughput. diff --git a/vllm/core/block/block_table.py b/vllm/core/block/block_table.py index c002dd1397f9..a9f4bd871dfd 100644 --- a/vllm/core/block/block_table.py +++ b/vllm/core/block/block_table.py @@ -55,9 +55,12 @@ def __init__( self._num_full_slots = self._get_num_token_ids() @staticmethod - def get_num_required_blocks(token_ids: List[int], block_size: int) -> int: + def get_num_required_blocks(token_ids: List[int], + block_size: int, + num_lookahead_slots: int = 0) -> int: """Calculates the minimum number of blocks required to store a given - sequence of token IDs. + sequence of token IDs along with any look-ahead slots that may be + required (like in multi-step + chunked-prefill). This assumes worst-case scenario, where every block requires a new allocation (e.g. ignoring prefix caching). @@ -66,12 +69,14 @@ def get_num_required_blocks(token_ids: List[int], block_size: int) -> int: token_ids (List[int]): The sequence of token IDs to be stored. block_size (int): The maximum number of tokens that can be stored in a single block. + num_lookahead_slots (int): look-ahead slots that the sequence may + require. Returns: int: The minimum number of blocks required to store the given - sequence of token IDs. + sequence of token IDs along with any required look-ahead slots. """ - return cdiv(len(token_ids), block_size) + return cdiv(len(token_ids) + num_lookahead_slots, block_size) def allocate(self, token_ids: List[int], diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index 24ab9eb66194..a1f96707a6b5 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -281,10 +281,15 @@ def __init__( def _get_seq_num_required_blocks(self, seq: Optional[Sequence]) -> int: return 0 if seq is None else seq.n_blocks - def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: + def can_allocate(self, + seq_group: SequenceGroup, + num_lookahead_slots: int = 0) -> AllocStatus: # FIXME(woosuk): Here we assume that all sequences in the group share # the same prompt. This may not be true for preempted sequences. + assert (num_lookahead_slots == 0 + ), "lookahead allocation not supported in BlockSpaceManagerV1" + check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group) self_num_required_blocks = self._get_seq_num_required_blocks( diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index 54818c7e3e9a..bb78b1e1c913 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -107,7 +107,9 @@ def __init__( self._last_access_blocks_tracker = LastAccessBlocksTracker( self.block_allocator) - def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: + def can_allocate(self, + seq_group: SequenceGroup, + num_lookahead_slots: int = 0) -> AllocStatus: # FIXME(woosuk): Here we assume that all sequences in the group share # the same prompt. This may not be true for preempted sequences. 
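For intuition, a minimal standalone sketch of the lookahead-aware block count that `BlockTable.get_num_required_blocks` computes above; the helper name and the sample numbers below are illustrative, not taken from the patch:

```python
from math import ceil

def num_required_blocks(num_tokens: int, block_size: int,
                        num_lookahead_slots: int = 0) -> int:
    # Mirrors cdiv(len(token_ids) + num_lookahead_slots, block_size):
    # lookahead slots reserve room for the decode tokens that multi-step
    # appends after the prefill finishes in its first step.
    return ceil((num_tokens + num_lookahead_slots) / block_size)

# A 110-token prompt with block_size=16 fits in 7 blocks on its own, but
# needs 8 blocks once 8 lookahead slots (e.g. --num-scheduler-steps 8)
# must be reserved at allocation time.
assert num_required_blocks(110, 16) == 7
assert num_required_blocks(110, 16, num_lookahead_slots=8) == 8
```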
@@ -117,6 +119,7 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: num_required_blocks = BlockTable.get_num_required_blocks( seq.get_token_ids(), block_size=self.block_size, + num_lookahead_slots=num_lookahead_slots, ) if seq_group.is_encoder_decoder(): diff --git a/vllm/core/embedding_model_block_manager.py b/vllm/core/embedding_model_block_manager.py index c47d7d8dfb07..476e043ecc52 100644 --- a/vllm/core/embedding_model_block_manager.py +++ b/vllm/core/embedding_model_block_manager.py @@ -21,7 +21,9 @@ def __init__( ) -> None: pass - def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: + def can_allocate(self, + seq_group: SequenceGroup, + num_lookahead_slots: int = 0) -> AllocStatus: # Always return OK for dummy purposes return AllocStatus.OK diff --git a/vllm/core/interfaces.py b/vllm/core/interfaces.py index 96f8dd851b2f..634671158730 100644 --- a/vllm/core/interfaces.py +++ b/vllm/core/interfaces.py @@ -44,7 +44,9 @@ def get_block_space_manager_class(version: str): raise ValueError(f"Unknown version {version=}") @abstractmethod - def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: + def can_allocate(self, + seq_group: SequenceGroup, + num_lookahead_slots: int = 0) -> AllocStatus: pass @abstractmethod diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index c3fa95f57b73..b3ef8b3f92a7 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -522,7 +522,7 @@ def _schedule_running( ret.swapped_out.clear() ret.num_lookahead_slots = self._get_num_lookahead_slots( - is_prefill=False) + is_prefill=False, enable_chunking=enable_chunking) ret.decode_seq_groups_list.clear() ret.prefill_seq_groups_list.clear() @@ -561,7 +561,7 @@ def _schedule_running( # NOTE(woosuk): Preemption happens only when there is no available # slot to keep all the sequence groups in the RUNNING state. - while not self._can_append_slots(seq_group): + while not self._can_append_slots(seq_group, enable_chunking): budget.subtract_num_batched_tokens(seq_group.request_id, num_running_tokens) num_running_seqs = seq_group.get_max_num_running_seqs() @@ -611,7 +611,7 @@ def _schedule_running( if not cont_loop: break else: - self._append_slots(seq_group, blocks_to_copy) + self._append_slots(seq_group, blocks_to_copy, enable_chunking) is_prefill = seq_group.is_prefill() scheduled_seq_group: ScheduledSequenceGroup = \ @@ -684,7 +684,8 @@ def _schedule_swapped( # If the sequence group cannot be swapped in, stop. 
is_prefill = seq_group.is_prefill() alloc_status = self.block_manager.can_swap_in( - seq_group, self._get_num_lookahead_slots(is_prefill)) + seq_group, + self._get_num_lookahead_slots(is_prefill, enable_chunking)) if alloc_status == AllocStatus.LATER: break elif alloc_status == AllocStatus.NEVER: @@ -727,7 +728,7 @@ def _schedule_swapped( curr_loras.add(lora_int_id) swapped_queue.popleft() self._swap_in(seq_group, blocks_to_swap_in) - self._append_slots(seq_group, blocks_to_copy) + self._append_slots(seq_group, blocks_to_copy, enable_chunking) is_prefill = seq_group.is_prefill() if is_prefill: prefill_seq_groups.append( @@ -747,12 +748,13 @@ def _schedule_swapped( blocks_to_swap_in=blocks_to_swap_in, blocks_to_copy=blocks_to_copy, num_lookahead_slots=self._get_num_lookahead_slots( - is_prefill=False), + is_prefill=False, enable_chunking=enable_chunking), infeasible_seq_groups=infeasible_seq_groups, ) def _get_prompt_limit(self, seq_group: SequenceGroup) -> int: - if self.scheduler_config.chunked_prefill_enabled: + if self.scheduler_config.chunked_prefill_enabled and \ + not self.scheduler_config.is_multi_step: prompt_limit = self.scheduler_config.max_model_len else: prompt_limit = min(self.scheduler_config.max_model_len, @@ -826,13 +828,20 @@ def _schedule_prefills( waiting_queue.popleft() continue + num_lookahead_slots: int = 0 + if self.scheduler_config.is_multi_step and enable_chunking: + num_lookahead_slots = self._get_num_lookahead_slots( + True, enable_chunking) + # If the sequence group cannot be allocated, stop. - can_allocate = self.block_manager.can_allocate(seq_group) + can_allocate = self.block_manager.can_allocate( + seq_group, num_lookahead_slots=num_lookahead_slots) if can_allocate == AllocStatus.LATER: break elif can_allocate == AllocStatus.NEVER: logger.warning( - "Input prompt (%d tokens) is too long" + "Input prompt (%d tokens) + lookahead slots " + "({num_lookahead_slots}) is too long" " and exceeds the capacity of block_manager", num_new_tokens) for seq in waiting_seqs: @@ -866,9 +875,24 @@ def _schedule_prefills( curr_loras.add(lora_int_id) waiting_queue.popleft() self._allocate_and_set_running(seq_group) - seq_group.init_multi_step( - num_scheduler_steps=self._get_num_lookahead_slots( - is_prefill=True) + 1) + + if enable_chunking and self.scheduler_config.is_multi_step: + blocks_to_copy: List[Tuple[int, int]] = [] + # init_multi_step_from_lookahead_slots happens in append_slots + self._append_slots(seq_group, blocks_to_copy, enable_chunking) + # This assert will trip when a copy-on-write happens. This is + # not a concern as the very first sequence-group block + # allocation happens above. Still, we have the assert to + # catch any edge-cases. + assert not blocks_to_copy + else: + seq_group.init_multi_step_from_lookahead_slots( + num_lookahead_slots, + num_scheduler_steps=self.scheduler_config. + num_scheduler_steps, + is_multi_step=self.scheduler_config.is_multi_step, + enable_chunking=enable_chunking) + seq_groups.append( ScheduledSequenceGroup(seq_group=seq_group, token_chunk_size=num_new_tokens)) @@ -883,7 +907,8 @@ def _schedule_prefills( return SchedulerPrefillOutputs( seq_groups=seq_groups, ignored_seq_groups=ignored_seq_groups, - num_lookahead_slots=self._get_num_lookahead_slots(is_prefill=True)) + num_lookahead_slots=self._get_num_lookahead_slots( + is_prefill=True, enable_chunking=enable_chunking)) def _schedule_default(self) -> SchedulerOutputs: """Schedule queued requests. 
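The `_get_num_lookahead_slots(is_prefill=True, enable_chunking=...)` calls above follow the rule defined later in this patch; a rough standalone sketch of that rule, with example values only:

```python
def get_num_lookahead_slots(is_prefill: bool, enable_chunking: bool,
                            is_multi_step: bool,
                            num_lookahead_slots: int) -> int:
    if is_prefill:
        # Under multi-step + chunked-prefill a scheduled prefill turns into
        # a decode after the first step, so it needs one slot for that step
        # plus the configured lookahead for the remaining steps.
        if is_multi_step and enable_chunking:
            return num_lookahead_slots + 1
        return 0
    return num_lookahead_slots

# With --num-scheduler-steps 8 the scheduler config carries
# num_lookahead_slots = 7:
assert get_num_lookahead_slots(True, True, True, 7) == 8    # chunked prefill
assert get_num_lookahead_slots(True, False, True, 7) == 0   # ordinary prefill
assert get_num_lookahead_slots(False, True, True, 7) == 7   # decode
```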
@@ -1076,7 +1101,8 @@ def _schedule(self) -> SchedulerOutputs: else: return self._schedule_default() - def _can_append_slots(self, seq_group: SequenceGroup) -> bool: + def _can_append_slots(self, seq_group: SequenceGroup, + enable_chunking: bool) -> bool: """Determine whether or not we have enough space in the KV cache to continue generation of the sequence group. """ @@ -1087,12 +1113,16 @@ def _can_append_slots(self, seq_group: SequenceGroup) -> bool: self.artificial_preempt_cnt -= 1 return False - # Appending slots only occurs in decoding. - is_prefill = False + is_prefill = seq_group.is_prefill() + + # Appending prefill slots only happens chunked prefill is enabled. + assert self.scheduler_config.chunked_prefill_enabled or \ + not is_prefill return self.block_manager.can_append_slots( seq_group=seq_group, - num_lookahead_slots=self._get_num_lookahead_slots(is_prefill), + num_lookahead_slots=self._get_num_lookahead_slots( + is_prefill, enable_chunking), ) def _allow_async_output_proc(self, seq_group: SequenceGroup) -> bool: @@ -1109,7 +1139,7 @@ def schedule( # such as self.running, self.swapped, and self.waiting. scheduler_start_time = time.perf_counter() - scheduler_outputs = self._schedule() + scheduler_outputs: SchedulerOutputs = self._schedule() now = time.time() if not self.cache_config.enable_prefix_caching: @@ -1306,11 +1336,10 @@ def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None: for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): seq.status = SequenceStatus.RUNNING - def _append_slots( - self, - seq_group: SequenceGroup, - blocks_to_copy: List[Tuple[int, int]], - ) -> None: + def _append_slots(self, + seq_group: SequenceGroup, + blocks_to_copy: List[Tuple[int, int]], + enable_chunking: bool = False) -> None: """Appends new slots to the sequences in the given sequence group. Args: @@ -1321,11 +1350,25 @@ def _append_slots( int is the destination block index. This list is updated with the new source and destination block indices for the appended slots. + enable_chunking (bool): True if chunked prefill is enabled. """ - num_lookahead_slots = self._get_num_lookahead_slots(is_prefill=False) - seq_group.init_multi_step(num_scheduler_steps=num_lookahead_slots + 1) - - for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): + is_prefill: bool = seq_group.is_prefill() + num_lookahead_slots: int = self._get_num_lookahead_slots( + is_prefill, enable_chunking) + + seq_group.init_multi_step_from_lookahead_slots( + num_lookahead_slots, + num_scheduler_steps=self.scheduler_config.num_scheduler_steps, + is_multi_step=self.scheduler_config.is_multi_step, + enable_chunking=enable_chunking) + + seq_status: Optional[SequenceStatus] = SequenceStatus.RUNNING + if self.scheduler_config.is_multi_step and enable_chunking: + # In multi-step chunked-prefill any sequence type can have + # slots appended. + seq_status = None + + for seq in seq_group.get_seqs(status=seq_status): cows = self.block_manager.append_slots(seq, num_lookahead_slots) if len(cows) > 0: blocks_to_copy.extend(cows) @@ -1436,16 +1479,32 @@ def _passed_delay(self, now: float) -> bool: passed_delay = True return passed_delay - def _get_num_lookahead_slots(self, is_prefill: bool) -> int: + def _get_num_lookahead_slots(self, is_prefill: bool, + enable_chunking: bool) -> int: """The number of slots to allocate per sequence per step, beyond known token ids. Speculative decoding uses these slots to store KV activations of tokens which may or may not be accepted. 
Speculative decoding does not yet support prefill, so we do not perform lookahead allocation for prefill. + + When chunking is enabled with multi-step, we allocate lookahead slots + for the prefills for when the prefills turn into decodes in the first + step. """ if is_prefill: - return 0 + if self.scheduler_config.is_multi_step and enable_chunking: + # num_lookahead_slots was introduced in the context of decodes, + # in Speculative Decoding. + # When the num_scheduler_steps is 8, say, then the + # num_lookahead_slots is 7. Meaning, we are doing a 1-step of + # decode anyways and we wish to do 7 more. + # + # "lookaheads" for prefills, is introduced in support for + # Chunked-Prefill in Multi-Step. + return self.scheduler_config.num_lookahead_slots + 1 + else: + return 0 return self.scheduler_config.num_lookahead_slots @@ -1488,6 +1547,16 @@ def _get_num_new_tokens(self, seq_group: SequenceGroup, if remaining_token_budget < num_new_tokens: num_new_tokens = (remaining_token_budget // block_size) * block_size + elif self.scheduler_config.is_multi_step: + if num_new_tokens > self._get_prompt_limit(seq_group): + # If the seq_group is in prompt-stage, pass the + # num_new_tokens as-is so the caller can ignore + # the sequence. + pass + else: + num_new_tokens = 0 \ + if num_new_tokens > remaining_token_budget \ + else num_new_tokens else: num_new_tokens = min(num_new_tokens, remaining_token_budget) return num_new_tokens diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 0d4559e37742..0efb0cbbf8be 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -980,9 +980,13 @@ def create_engine_config(self) -> EngineConfig: if speculative_config is not None: raise ValueError("Speculative decoding is not supported with " "multi-step (--num-scheduler-steps > 1)") - if self.enable_chunked_prefill: - raise ValueError("Chunked prefill is not supported with " - "multi-step (--num-scheduler-steps > 1)") + if self.enable_chunked_prefill and self.enable_prefix_caching: + raise ValueError("Multi-Step is not supported with " + "both Chunked-Prefill and Prefix-Caching " + "enabled together.") + if self.enable_chunked_prefill and self.pipeline_parallel_size > 1: + raise ValueError("Multi-Step Chunked-Prefill is not supported " + "for pipeline-parallel-size > 1") # make sure num_lookahead_slots is set the higher value depending on # if we are using speculative decoding or multi-step diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 34e7e05341f0..36ffb8fb781c 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -363,11 +363,18 @@ async def step_async( self.cached_scheduler_outputs[ virtual_engine] = SchedulerOutputState() + # is_first_step_output is True only when the num_steps of all + # the sequences are 1. When the num_steps > 1, + # multi_step_model_runner does the first-step output append. 
+ is_first_step_output: bool = False if not seq_group_metadata_list \ + else seq_group_metadata_list[0].state.num_steps == 1 + ctx.append_output(outputs=outputs, seq_group_metadata_list=seq_group_metadata_list, scheduler_outputs=scheduler_outputs, is_async=allow_async_output_proc, - is_last_step=True) + is_last_step=True, + is_first_step_output=is_first_step_output) if outputs and allow_async_output_proc: assert len( diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index bd7b3250e31a..e0ea41e6de4f 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -90,6 +90,12 @@ class OutputData(NamedTuple): scheduler_outputs: SchedulerOutputs is_async: bool is_last_step: bool + # Indicates if this output is from the first step of the + # multi-step. When multi-step is disabled, this is always + # set to True. + # is_first_step_output is invalid when `outputs` has + # outputs from multiple steps. + is_first_step_output: Optional[bool] skip: List[int] @@ -108,13 +114,15 @@ def __init__(self, multi_step_stream_outputs: bool = False): def append_output(self, outputs: List[SamplerOutput], seq_group_metadata_list: List[SequenceGroupMetadata], scheduler_outputs: SchedulerOutputs, is_async: bool, - is_last_step: bool): + is_last_step: bool, + is_first_step_output: Optional[bool]): self.output_queue.append( OutputData(outputs=outputs, seq_group_metadata_list=seq_group_metadata_list, scheduler_outputs=scheduler_outputs, is_async=is_async, is_last_step=is_last_step, + is_first_step_output=is_first_step_output, skip=[])) @@ -237,9 +245,10 @@ def __init__( "quantization_param_path=%s, device_config=%s, " "decoding_config=%r, observability_config=%r, " "seed=%d, served_model_name=%s, use_v2_block_manager=%s, " - "num_scheduler_steps=%d, multi_step_stream_outputs=%s, " - "enable_prefix_caching=%s, use_async_output_proc=%s, " - "use_cached_outputs=%s, mm_processor_kwargs=%s)", + "num_scheduler_steps=%d, chunked_prefill_enabled=%s " + "multi_step_stream_outputs=%s, enable_prefix_caching=%s, " + "use_async_output_proc=%s, use_cached_outputs=%s, " + "mm_processor_kwargs=%s)", VLLM_VERSION, model_config.model, speculative_config, @@ -270,6 +279,7 @@ def __init__( model_config.served_model_name, scheduler_config.use_v2_block_manager, scheduler_config.num_scheduler_steps, + scheduler_config.chunked_prefill_enabled, scheduler_config.multi_step_stream_outputs, cache_config.enable_prefix_caching, model_config.use_async_output_proc, @@ -903,8 +913,66 @@ def _process_model_outputs(self, ctx: The virtual engine context to work on request_id: If provided, then only this request is going to be processed - """ + + def update_prefill_num_computed_tokens( + seq_group: SequenceGroup, + seq_group_meta: SequenceGroupMetadata, num_outputs: int, + is_first_step_output: Optional[bool]) -> None: + """ + seq_group: SequenceGroup - A prefill seq_group + seq_group_meta: SequenceGroupMetadata - Metadata of the given + prefill seq_group + num_outputs: int - number of output tokens being processed for the + given seq_group + is_first_step_output: Optional[bool] - + If multi-step is enabled and num_outputs is 1, this value + indicates if this outputs belongs to the first step in the + multi-step. + If multi-step is enabled and num_outputs > 1, this value + must be None, as num_outputs > 1 indicates that outputs from + all the steps in multi-step are submitted in a single burst. + When multi-step is disabled, this value is always True. 
+ + When multi-step and chunked-prefill are enabled together, the + prefill sequence scheduled for multi-step execution turn into + decodes in the first step itself. This function accounts + for that conversion. + """ + + assert seq_group_meta.is_prompt + + token_chunk_size = seq_group_meta.token_chunk_size + + if num_outputs == 1: + assert is_first_step_output is not None + + if seq_group_meta.state.num_steps == 1: + assert is_first_step_output is True + seq_group.update_num_computed_tokens(token_chunk_size) + return + + # multi-step prefill is only supported when multi-step is + # enabled with chunked prefill + assert self.scheduler_config.is_multi_step and \ + self.scheduler_config.chunked_prefill_enabled + if is_first_step_output is True: + # This sequence is a prompt during the first step only. + seq_group.update_num_computed_tokens(token_chunk_size) + return + + assert is_first_step_output is None + + # multi-step prefill is only supported when multi-step is + # enabled with chunked prefill. Outputs from all the steps are + # submitted in a single burst. + assert self.scheduler_config.is_multi_step and \ + self.scheduler_config.chunked_prefill_enabled + assert num_outputs == seq_group_meta.state.num_steps, \ + f"#outputs {len(outputs)} - num steps {seq_group_meta.state.num_steps}" #noqa + # This sequence is a prompt during the first step only. + seq_group.update_num_computed_tokens(token_chunk_size) + now = time.time() if len(ctx.output_queue) == 0: @@ -915,20 +983,27 @@ def _process_model_outputs(self, # When we process only one request, no pop is required # (since later we will process all of the rest) (outputs, seq_group_metadata_list, scheduler_outputs, is_async, - is_last_step, skip) = ctx.output_queue[0] + is_last_step, is_first_step_output, skip) = ctx.output_queue[0] else: (outputs, seq_group_metadata_list, scheduler_outputs, is_async, - is_last_step, skip) = ctx.output_queue.popleft() + is_last_step, is_first_step_output, + skip) = ctx.output_queue.popleft() # Sanity check assert len(seq_group_metadata_list) == len( scheduler_outputs.scheduled_seq_groups) - # Organize outputs by [step][sequence group] instead of - # [sequence group][step]. - if len(outputs) > 1: + has_multiple_outputs: bool = len(outputs) > 1 + if has_multiple_outputs: + assert self.scheduler_config.is_multi_step or \ + self.speculative_config + # Organize outputs by [step][sequence group] instead of + # [sequence group][step]. outputs_by_sequence_group = create_output_by_sequence_group( outputs, num_seq_groups=len(seq_group_metadata_list)) + # We have outputs for multiple steps submitted in a single burst, + # so invalidate is_first_step_output. + is_first_step_output = None else: outputs_by_sequence_group = outputs @@ -964,14 +1039,17 @@ def _process_model_outputs(self, finished_before.append(i) continue - if len(outputs) > 1: + if has_multiple_outputs: output = outputs_by_sequence_group[i] else: output = [outputs_by_sequence_group[0][i]] - if not is_async: - seq_group.update_num_computed_tokens( - scheduled_seq_group.token_chunk_size) + if not is_async and seq_group_meta.is_prompt: + # Updates for all decodes happen when we actually append the + # token ids to the seq in process_outputs. 
+ update_prefill_num_computed_tokens(seq_group, seq_group_meta, + len(output), + is_first_step_output) if outputs: for o in outputs: @@ -1105,8 +1183,18 @@ def _advance_to_next_step( if seq_group.is_finished(): continue - seq_group.update_num_computed_tokens( - seq_group_metadata.token_chunk_size) + if seq_group_metadata.is_prompt: + if self.scheduler_config.is_multi_step and \ + self.scheduler_config.chunked_prefill_enabled: + # Prompts are scheduled in multi-step only when + # chunking is enabled. These prompts turn into + # decodes after the very first step. Therefore, + # we skip the update to the num_computed_tokens + # here. + pass + else: + seq_group.update_num_computed_tokens( + seq_group_metadata.token_chunk_size) if seq_group_metadata.do_sample: assert len(sequence_group_outputs.samples) == 1, ( @@ -1118,6 +1206,7 @@ def _advance_to_next_step( assert len(seq_group.seqs) == 1 seq = seq_group.seqs[0] seq.append_token_id(sample.output_token, sample.logprobs) + seq_group.update_num_computed_tokens(1) def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: """Performs one decoding iteration and returns newly generated results. @@ -1270,12 +1359,19 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: if self.scheduler_config.is_multi_step: self.cached_scheduler_outputs[0] = SchedulerOutputState() + # is_first_step_output is True only when the num_steps of all + # the sequences are 1. When the num_steps > 1, + # multi_step_model_runner does the first-step output append. + is_first_step_output: bool = False if not seq_group_metadata_list \ + else seq_group_metadata_list[0].state.num_steps == 1 + # Add results to the output_queue ctx.append_output(outputs=outputs, seq_group_metadata_list=seq_group_metadata_list, scheduler_outputs=scheduler_outputs, is_async=allow_async_output_proc, - is_last_step=True) + is_last_step=True, + is_first_step_output=is_first_step_output) if outputs and allow_async_output_proc: assert len(outputs) == 1, ( diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py index c73db765fc3b..eb955cecd942 100644 --- a/vllm/engine/output_processor/multi_step.py +++ b/vllm/engine/output_processor/multi_step.py @@ -169,6 +169,7 @@ def _process_seq_outputs(self, seq: Sequence, token_id=output_token_id, logprobs=output_logprob, ) + seq.data.update_num_computed_tokens(1) self._process_decode_and_stop(seq, sampling_params) diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index 97d36d31f2b1..0ecc4f7157c7 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -134,6 +134,9 @@ def __init__( num_prompts: int, skip_sampler_cpu_output: bool = False, reuse_sampling_tensors: bool = False, + # Used when multi-step is enabled with chunked-prefill. Refer to + # the comment in prepare_multistep_tensors. + selected_token_indices_multistep: Optional[torch.Tensor] = None ) -> None: self.seq_groups = seq_groups self.selected_token_indices = selected_token_indices @@ -141,6 +144,52 @@ def __init__( self.num_prompts = num_prompts self.skip_sampler_cpu_output = skip_sampler_cpu_output self.reuse_sampling_tensors = reuse_sampling_tensors + self.selected_token_indices_multistep = selected_token_indices_multistep + + def prepare_multistep_tensors(self, num_queries: int, device: str, + pin_memory: bool): + """ + Invoked when Multi-Step is enabled with Chunked-Prefill. 
+ When Multi-Step is enabled with Chunked-Prefill, the prompts and + decodes are scheduled together. + self.selected_token_indices is constructed for the first-step in + Multi-Step. However, the scheduled prompts, are fully processed + in the first-step and are processed as decodes in the rest of the steps. + This function prepares a "selected_token_indices" to be used + in the rest of the steps. + + Example: + Let 2 prompts and 2 decodes be scheduled together. Let the + num-tokens to process for the 2 prompts be 5 and 8 resply. + + In that case, self.sampled_token_indices will be, + [4, 12, 13, 14] as it is constructed for the first-step in + multi-step. + However, the prompts turns to decodes after the first-step + and the num-tokens for the previously-prompt sequences will + be 1 and 1 as they are decodes now. The self.sampled_token_indices + must be updated to [0,1,2,3]. + prepare_multistep_tensors prepares the "selected_token_indices" + to be used in steps 2-N. + """ + selected_token_indices_multistep = list(range(num_queries)) + self.selected_token_indices_multistep = \ + async_tensor_h2d(selected_token_indices_multistep, + dtype=torch.long, + target_device=device, + pin_memory=pin_memory) + + def advance_step(self): + """ + Invoked when Multi-Step and Chunked-Prefill are enabled together. + The prefills that may have been scheduled, are fully processed in + the very first step and have turned into decodes. + Updated selected_token_indices to reflect that. Please refer to + the prepare_multistep_tensors docstring for an example. + """ + if self.selected_token_indices_multistep is not None: + # Swap to account for Single Step Prompts becoming Decodes + self.selected_token_indices = self.selected_token_indices_multistep @staticmethod def prepare( diff --git a/vllm/sequence.py b/vllm/sequence.py index 79e8a1f6244d..24e54bac1b44 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -729,10 +729,35 @@ def prompt_adapter_num_virtual_tokens(self) -> int: return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens\ if self.prompt_adapter_request else 0 - def init_multi_step(self, num_scheduler_steps: int) -> None: - self.state.num_steps = num_scheduler_steps + def init_multi_step(self, num_steps: int) -> None: + self.state.num_steps = num_steps self.state.current_step = 0 + def init_multi_step_from_lookahead_slots(self, num_lookahead_slots: int, + num_scheduler_steps: int, + is_multi_step: bool, + enable_chunking: bool) -> None: + + if not is_multi_step: + self.init_multi_step(num_steps=num_scheduler_steps) + return + + # Multi-Step case + is_prefill = self.is_prefill() + + # The asserts below reflect the expectations of the current system. + if is_prefill and enable_chunking: + assert num_lookahead_slots == num_scheduler_steps + self.init_multi_step(num_steps=num_lookahead_slots) + else: + is_decode: bool = not is_prefill + # If it is a prefill, num_lookahead_slots must be 0 + assert num_lookahead_slots == 0 or is_decode + # If it is a decode, num_lookahead_slots + 1 must match + # the scheduler steps. + assert num_lookahead_slots + 1 == num_scheduler_steps or is_prefill + self.init_multi_step(num_steps=num_lookahead_slots + 1) + def get_last_latency(self, now: float) -> Optional[float]: """Sets the last token time for Request level timings.""" # If still in prefill phase, raise Error. 
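A self-contained rendering of the worked example from the `prepare_multistep_tensors` docstring above, using plain Python lists in place of the device tensors:

```python
# Step 1: two chunked prefills (5 and 8 tokens) scheduled together with
# two decodes (1 token each).
query_lens_step1 = [5, 8, 1, 1]

# Sampling reads the last token of each query in the flattened batch.
selected_token_indices, offset = [], 0
for qlen in query_lens_step1:
    offset += qlen
    selected_token_indices.append(offset - 1)
assert selected_token_indices == [4, 12, 13, 14]

# From step 2 onward the former prefills are decodes too, so every query
# contributes exactly one token and the indices collapse to one-per-query.
num_queries = len(query_lens_step1)
selected_token_indices_multistep = list(range(num_queries))
assert selected_token_indices_multistep == [0, 1, 2, 3]
```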
@@ -996,6 +1021,20 @@ def prompt_adapter_num_virtual_tokens(self) -> int: return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens \ if self.prompt_adapter_request else 0 + # Multi-Step Chunked-Prefill property + @property + def is_single_step_prompt(self) -> bool: + # do_sample is true, only when the token_chunk_size matches the + # num_uncomputed_tokens of the sequence. This indicates that + # the prompt will finish processing in a single `execute_model` + # step. + return self.is_prompt and self.do_sample + + def get_first_seq_id(self) -> int: + # This is an efficient way of fetching the seq_id when + # we know this SequenceGroup has only one sequence. + return next(iter(self.seq_data)) + def apply_delta(self, sequence_group_metadata_delta: SequenceGroupMetadataDelta): for id, delta in sequence_group_metadata_delta.seq_data_delta.items(): @@ -1008,7 +1047,8 @@ def apply_delta(self, def finish_step(self) -> None: assert self.state is not None - assert self.state.current_step < self.state.num_steps + assert self.state.current_step < self.state.num_steps, \ + f"current step {self.state.current_step}, num_steps {self.state.num_steps}" # noqa self.state.current_step += 1 diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index 86883cf15244..1bb6a848390d 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -64,6 +64,8 @@ def _init_sampling_metadata_from_tensor_dict( # type: ignore from vllm.model_executor import SamplingMetadata selected_token_indices = tensor_dict.pop("selected_token_indices", None) + selected_token_indices_multistep = tensor_dict.pop( + "selected_token_indices_multistep", None) # An empty SamplingMetadata to signal that the worker should skip # sampling. if selected_token_indices is not None: @@ -72,6 +74,7 @@ def _init_sampling_metadata_from_tensor_dict( # type: ignore selected_token_indices=selected_token_indices, categorized_sample_indices=None, num_prompts=0, + selected_token_indices_multistep=selected_token_indices_multistep, ) return tensor_dict @@ -86,6 +89,9 @@ def _add_sampling_metadata_broadcastable_dict( if sampling_metadata is not None: tensor_dict["selected_token_indices"] = ( sampling_metadata.selected_token_indices) + if sampling_metadata.selected_token_indices_multistep is not None: + tensor_dict["selected_token_indices_multistep"] = ( + sampling_metadata.selected_token_indices_multistep) def _init_frozen_model_input_from_tensor_dict( diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py index c7295f872f70..01f0fc2d569b 100644 --- a/vllm/worker/multi_step_model_runner.py +++ b/vllm/worker/multi_step_model_runner.py @@ -30,6 +30,14 @@ logger = init_logger(__name__) MULTI_STEP_ATTENTION_BACKENDS = ["flash-attn", "rocm-flash-attn", "flashinfer"] +MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS = ["flash-attn"] + +def _get_supported_attention_backends(chunked_prefill_enabled: bool) \ + -> List[str]: + if chunked_prefill_enabled: + return MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS + else: + return MULTI_STEP_ATTENTION_BACKENDS def seq_output_builder(): @@ -144,11 +152,13 @@ class StatefulModelInput(BroadcastableModelInput): is_multi_step: bool = True is_last_step: bool = False is_first_multi_step: bool = False + base_output_proc_callback: Optional[Callable] = None # ping-pong data structures for multi-step to wait on the previous step step_cuda_events: List[torch.cuda.Event] = field( default_factory=lambda: [torch.cuda.Event(blocking=True)] * 2) 
num_seqs: int = -1 num_queries: int = -1 + num_single_step_prefills: int = 0 def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: assert self.frozen_model_input is not None @@ -161,6 +171,7 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: 'is_first_multi_step': self.is_first_multi_step, 'num_seqs': self.num_seqs, 'num_queries': self.num_queries, + 'num_single_step_prefills': self.num_single_step_prefills, } tensor_dict.update(new_tensor_dict) return tensor_dict @@ -209,6 +220,50 @@ def add_sampler_output(self, sampled_token_ids=sampled_token_ids, pythonized=False)) + def maybe_advance_frozen_model_input(self): + """ + Advancing the datastructures of StatefulModelInput::frozen_model_input + is only required when prefills are scheduled with decodes to run in + multi-step. This advancement/correction is required to account for + the conversion of Prefills to Decodes after the first multi-step. + """ + if self.current_step != 1 or self.num_single_step_prefills == 0: + return + + assert self.frozen_model_input is not None + fmi = self.frozen_model_input + + # Truncate input_tokens + assert fmi.input_tokens is not None + assert fmi.input_tokens.shape[0] >= self.num_seqs + fmi_new_input_tokens: torch.Tensor = fmi.input_tokens[:self.num_seqs] + + # Update frozen_model_input::input_positons. + assert fmi.input_positions is not None + assert fmi.input_positions.shape[0] >= self.num_seqs + fmi_new_input_positions: torch.Tensor = fmi.input_positions[:self. + num_seqs] + + # Assert unsupported + assert fmi.lora_mapping is None + assert fmi.lora_requests is not None + assert len(fmi.lora_requests) == 0 + assert fmi.attn_metadata is not None + assert fmi.prompt_adapter_mapping is None + assert fmi.prompt_adapter_requests is not None + assert len(fmi.prompt_adapter_requests) == 0 + assert fmi.multi_modal_kwargs is not None + assert len(fmi.multi_modal_kwargs) == 0 + + self.frozen_model_input = dataclasses.replace( + self.frozen_model_input, + input_tokens=fmi_new_input_tokens, + input_positions=fmi_new_input_positions) + + if get_pp_group().is_last_rank: + assert self.frozen_model_input.sampling_metadata is not None + self.frozen_model_input.sampling_metadata.advance_step() + # MutableModelInputForGPUWithMultiStepMetadata is not subclass of # ModelInputForGPU but it wraps the actual input dataclass and adds multi-step @@ -220,6 +275,19 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]): def __init__(self, base_model_runner: GPUModelRunnerBase, *args, **kwargs): super().__init__(*args, **kwargs) + # Check attention backend support. + supported_attention_backends: List[str] = \ + _get_supported_attention_backends( + self.scheduler_config.chunked_prefill_enabled) + if self.attn_backend.get_name() not in supported_attention_backends: + ms_config_str: str = "Multi-Step + Chunked-Prefill" \ + if self.scheduler_config.chunked_prefill_enabled \ + else "Multi-Step" + raise ValueError( + f"{ms_config_str} not supported for attention backend: " + f"{self.attn_backend.get_name()}. 
Set VLLM_ATTENTION_BACKEND " + f"to a value from {supported_attention_backends}.") + # uses the base model runner to execute the model and wraps it with # multi-step logic self._base_model_runner: GPUModelRunnerBase = base_model_runner @@ -248,14 +316,32 @@ def prepare_model_input( virtual_engine: int = 0, finished_requests_ids: Optional[List[str]] = None ) -> StatefulModelInput: - frozen_model_input = self._base_model_runner.prepare_model_input( - seq_group_metadata_list, virtual_engine, finished_requests_ids) + frozen_model_input: ModelInputForGPUWithSamplingMetadata = \ + self._base_model_runner.prepare_model_input( + seq_group_metadata_list, + virtual_engine, + finished_requests_ids) + + assert frozen_model_input.query_lens is not None + assert frozen_model_input.seq_lens is not None + assert frozen_model_input.attn_metadata is not None + num_queries = len(frozen_model_input.query_lens) + num_seqs = len(frozen_model_input.seq_lens) + num_single_step_prefills = frozen_model_input.attn_metadata.num_prefills + + if get_pp_group().is_last_rank and num_single_step_prefills > 0: + assert frozen_model_input.sampling_metadata is not None + frozen_model_input.sampling_metadata.prepare_multistep_tensors( + num_queries=num_queries, + device=self.device, + pin_memory=self.pin_memory) model_input = StatefulModelInput( frozen_model_input=frozen_model_input, - num_seqs=len(frozen_model_input.seq_lens), - num_queries=len(frozen_model_input.query_lens), - ) + num_seqs=num_seqs, + num_queries=num_queries, + num_single_step_prefills=num_single_step_prefills) + return model_input def _async_process_outputs(self, model_input: StatefulModelInput, @@ -265,7 +351,7 @@ def _async_process_outputs(self, model_input: StatefulModelInput, output_proc_callback() cont = True - for model_output in model_input.cached_outputs: + for step_num, model_output in enumerate(model_input.cached_outputs): if not model_output.pythonized: model_output.maybe_pythonize(model_input, self._copy_stream, self.pinned_sampled_token_ids) @@ -276,7 +362,8 @@ def _async_process_outputs(self, model_input: StatefulModelInput, seq_group_metadata_list=ctx.seq_group_metadata_list, scheduler_outputs=ctx.scheduler_outputs, is_async=False, - is_last_step=False) + is_last_step=False, + is_first_step_output=step_num == 0) output_proc_callback() else: @@ -292,9 +379,8 @@ def _final_process_outputs(self, model_input: StatefulModelInput, has_async_callback = output_proc_callback is not None outputs = [] - for output_id in range(len(model_input.cached_outputs)): - output = model_input.cached_outputs[output_id] - is_last_step = output_id == len(model_input.cached_outputs) - 1 + for step_num, output in enumerate(model_input.cached_outputs): + is_last_step = step_num == len(model_input.cached_outputs) - 1 # For non-async case: # -- We simply add the outputs @@ -323,7 +409,8 @@ def _final_process_outputs(self, model_input: StatefulModelInput, seq_group_metadata_list, scheduler_outputs=ctx.scheduler_outputs, is_async=False, - is_last_step=False) + is_last_step=False, + is_first_step_output=step_num == 0) else: outputs.append(output.sampler_output) else: @@ -389,18 +476,27 @@ def execute_model( model_input = self._advance_step( model_input, model_input.cached_outputs[-1].sampler_output) - output_proc_callback = None + # frozen_model_input may have been updated + frozen_model_input = model_input.frozen_model_input + assert frozen_model_input is not None + + if model_input.base_output_proc_callback is None: + assert frozen_model_input is not None + 
model_input.base_output_proc_callback = \ + frozen_model_input.async_callback + if frozen_model_input.async_callback is not None: - output_proc_callback = frozen_model_input.async_callback - assert output_proc_callback is not None + assert model_input.base_output_proc_callback is not None async_callback = functools.partial( self._async_process_outputs, model_input=model_input, - output_proc_callback=output_proc_callback) + output_proc_callback=model_input.base_output_proc_callback) - frozen_model_input = dataclasses.replace( # type: ignore + model_input.frozen_model_input = dataclasses.replace( # type: ignore model_input.frozen_model_input, async_callback=async_callback) + # Update the local instance + frozen_model_input = model_input.frozen_model_input assert frozen_model_input is not None # Execute the model @@ -455,8 +551,8 @@ def execute_model( # Pythonize the output and block if needed since it is the last step if model_input.is_last_step: - outputs = self._final_process_outputs(model_input, - output_proc_callback) + outputs = self._final_process_outputs( + model_input, model_input.base_output_proc_callback) self.pythonization_cache.reset() return outputs @@ -484,11 +580,13 @@ def _update_sampling_metadata(self, sampling_metadata, num_seqs, def _advance_step(self, model_input: StatefulModelInput, out: SamplerOutput) -> StatefulModelInput: - if self.attn_backend.get_name() not in MULTI_STEP_ATTENTION_BACKENDS: - raise ValueError( - f"Multi-step not supported for attention backend: " - f"{self.attn_backend.get_name()}. Set VLLM_ATTENTION_BACKEND " - f"to a value from {MULTI_STEP_ATTENTION_BACKENDS}.") + + model_input.maybe_advance_frozen_model_input() + frozen_model_input = model_input.frozen_model_input + assert frozen_model_input is not None + assert frozen_model_input.input_tokens is not None + assert frozen_model_input.input_tokens.shape[0] == model_input.num_seqs + assert frozen_model_input.attn_metadata is not None sampled_token_ids = model_input.cached_outputs[-1].sampled_token_ids num_seqs = model_input.num_seqs @@ -498,13 +596,15 @@ def _advance_step(self, model_input: StatefulModelInput, attn_metadata = frozen_model_input.attn_metadata assert attn_metadata is not None + turn_prefills_into_decodes: bool = model_input.current_step == 1 and \ + model_input.num_single_step_prefills != 0 attn_metadata.advance_step( frozen_model_input, sampled_token_ids, self.block_size, num_seqs, num_queries, - ) + turn_prefills_into_decodes=turn_prefills_into_decodes) return model_input diff --git a/vllm/worker/multi_step_worker.py b/vllm/worker/multi_step_worker.py index 562285f828cc..bf66f32d7d24 100644 --- a/vllm/worker/multi_step_worker.py +++ b/vllm/worker/multi_step_worker.py @@ -76,8 +76,9 @@ def _get_driver_input_and_broadcast( frozen_model_input = model_input.frozen_model_input assert frozen_model_input is not None assert frozen_model_input.attn_metadata is not None - # clear the cached decode metadata so that it can be recomputed on - # the workers + # clear the cached metadata so that it can be recomputed on + # the workers. 
+ frozen_model_input.attn_metadata._cached_prefill_metadata = None frozen_model_input.attn_metadata._cached_decode_metadata = None model_input.is_first_multi_step = is_first_multi_step From 147a172f4dfc2d91d7c8a86d1bba49af88df994c Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 26 Sep 2024 09:07:35 -0400 Subject: [PATCH 02/10] removing assert --- vllm/engine/arg_utils.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 0efb0cbbf8be..dd51fb3972f2 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -980,10 +980,6 @@ def create_engine_config(self) -> EngineConfig: if speculative_config is not None: raise ValueError("Speculative decoding is not supported with " "multi-step (--num-scheduler-steps > 1)") - if self.enable_chunked_prefill and self.enable_prefix_caching: - raise ValueError("Multi-Step is not supported with " - "both Chunked-Prefill and Prefix-Caching " - "enabled together.") if self.enable_chunked_prefill and self.pipeline_parallel_size > 1: raise ValueError("Multi-Step Chunked-Prefill is not supported " "for pipeline-parallel-size > 1") From ef3680dcba954deb21558ef292397f986fc64827 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Fri, 27 Sep 2024 07:17:08 -0400 Subject: [PATCH 03/10] wip --- tests/multi_step/test_correctness_llm.py | 107 +++++++++++++++++++++++ 1 file changed, 107 insertions(+) diff --git a/tests/multi_step/test_correctness_llm.py b/tests/multi_step/test_correctness_llm.py index ff413e8e2da3..134dabc407be 100644 --- a/tests/multi_step/test_correctness_llm.py +++ b/tests/multi_step/test_correctness_llm.py @@ -196,3 +196,110 @@ def test_multi_step_llm_w_prompt_logprobs( name_0="hf", name_1="vllm", ) + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("tp_size", [1]) +@pytest.mark.parametrize("enable_chunked_prefill", [True]) +@pytest.mark.parametrize("enable_prefix_caching", [True]) +@pytest.mark.parametrize("max_tokens", [5]) +@pytest.mark.parametrize("enforce_eager", [True]) +@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS) +@pytest.mark.parametrize("num_prompts", NUM_PROMPTS) +@pytest.mark.parametrize("num_logprobs", [None, 5]) +def test_multi_step_llm_chunked_prefill_prefix_cache( + vllm_runner, + example_prompts, + model: str, + dtype: str, + tp_size: int, + enable_chunked_prefill: bool, + enable_prefix_caching: bool, + max_tokens: int, + enforce_eager: int, + num_scheduler_steps: int, + num_prompts: int, + num_logprobs: Optional[int], +) -> None: + """Test vLLM engine with multi-step scheduling via sync LLM Engine. + + Set up a HuggingFace (HF) transformers model as a ground-truth reference. + + Prompt them with the same example prompts. 
+ + Validate: + * Generated tokens match + * Generated logprobs are all very close + + Args: + hf_runner: HF transformers model runner fixture + vllm_runner: vLLM model runner fixture + example_prompts: test fixture providing example prompts + model: model under test (same for single- and multi-step engines) + dtype: tensor datatype for engine to utilize + tp_size: degree of tensor-parallelism + enable_chunked_prefill: chunked-prefill on/off + max_tokens: the maximum number of tokens to generate + enforce_eager + num_scheduler_steps: for multi-step scheduling, GPU-side steps per + GPU -> CPU output transfer + num_prompts: number of example prompts under test + num_logprobs: corresponds to the `logprobs` argument to the OpenAI + completions endpoint; `None` -> 1 logprob returned. + """ + + prompts = example_prompts + if len(prompts) < num_prompts: + prompts = prompts * ((num_prompts // len(prompts)) + 1) + prompts = prompts[:num_prompts] + assert len(prompts) == num_prompts + + with vllm_runner( + model, + dtype=dtype, + enforce_eager=enforce_eager, + gpu_memory_utilization=0.7, + tensor_parallel_size=tp_size, + use_v2_block_manager=True, + num_scheduler_steps=num_scheduler_steps, + ) as vllm_model: + outputs_baseline = (vllm_model.generate_greedy(prompts, max_tokens) + if num_logprobs is None else + vllm_model.generate_greedy_logprobs( + prompts, max_tokens, num_logprobs)) + + outputs_baseline = (vllm_model.generate_greedy(prompts, max_tokens) + if num_logprobs is None else + vllm_model.generate_greedy_logprobs( + prompts, max_tokens, num_logprobs)) + + with vllm_runner( + model, + dtype=dtype, + enforce_eager=enforce_eager, + gpu_memory_utilization=0.7, + tensor_parallel_size=tp_size, + use_v2_block_manager=True, + enable_chunked_prefill=enable_chunked_prefill, + enable_prefix_caching=enable_prefix_caching, + num_scheduler_steps=num_scheduler_steps, + ) as vllm_model: + outputs_w_features = (vllm_model.generate_greedy(prompts, max_tokens) + if num_logprobs is None else + vllm_model.generate_greedy_logprobs( + prompts, max_tokens, num_logprobs)) + + if num_logprobs is None: + check_outputs_equal( + outputs_0_lst=outputs_baseline, + outputs_1_lst=outputs_w_features, + name_0="multi-step", + name_1="multi-step+features", + ) + else: + check_logprobs_close( + outputs_0_lst=outputs_baseline, + outputs_1_lst=outputs_w_features, + name_0="multi-step", + name_1="multi-step+features", + ) \ No newline at end of file From e6021c81cd0013293fa650c519fa13fef768cff9 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Fri, 27 Sep 2024 08:04:34 -0400 Subject: [PATCH 04/10] wip --- tests/multi_step/test_correctness_llm.py | 141 ++++++++++++++++++++++- 1 file changed, 135 insertions(+), 6 deletions(-) diff --git a/tests/multi_step/test_correctness_llm.py b/tests/multi_step/test_correctness_llm.py index 134dabc407be..43dfb8867ed2 100644 --- a/tests/multi_step/test_correctness_llm.py +++ b/tests/multi_step/test_correctness_llm.py @@ -267,11 +267,6 @@ def test_multi_step_llm_chunked_prefill_prefix_cache( if num_logprobs is None else vllm_model.generate_greedy_logprobs( prompts, max_tokens, num_logprobs)) - - outputs_baseline = (vllm_model.generate_greedy(prompts, max_tokens) - if num_logprobs is None else - vllm_model.generate_greedy_logprobs( - prompts, max_tokens, num_logprobs)) with vllm_runner( model, @@ -302,4 +297,138 @@ def test_multi_step_llm_chunked_prefill_prefix_cache( outputs_1_lst=outputs_w_features, name_0="multi-step", name_1="multi-step+features", - ) \ No newline at end of file + ) + + +from 
typing import List, Optional + +import pytest + +from tests.kernels.utils import override_backend_env_variable + +from ..models.utils import check_logprobs_close +from ..utils import (completions_with_server_args, get_client_text_generations, + get_client_text_logprob_generations) + +DEFAULT_SERVER_ARGS: List[str] = [ + "--disable-log-requests", + "--use-v2-block-manager", + "--worker-use-ray", + "--gpu-memory-utilization", + "0.85", + "--swap-space", + "16", +] + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize(("tp_size, pp_size"), [ + (1, 1), + (2, 2), +]) +@pytest.mark.parametrize("eager_mode", [False, True]) +@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS) +@pytest.mark.parametrize("num_prompts", NUM_PROMPTS) +@pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.parametrize("is_async", [True]) +@pytest.mark.parametrize("attention_backend", ["FLASHINFER", "FLASH_ATTN"]) +@pytest.mark.parametrize("enable_chunked_prefill", [True]) +@pytest.mark.asyncio +async def test_multi_step_async( + example_prompts, + model: str, + tp_size: int, + pp_size: int, + eager_mode: int, + num_scheduler_steps: int, + num_prompts: int, + is_async: bool, + num_logprobs: Optional[int], + attention_backend: str, + enable_chunked_prefill: bool, + monkeypatch, +) -> None: + """Test vLLM engine with multi-step scheduling in an OpenAI-protocol + client/server environment. + + Set up an engine with single-step scheduling as a ground-truth reference. + + Send a completions API request to both engines with the same prompts. + + Validate: + * Generated tokens match + * Generated logprobs are all very close + + Args: + example_prompts: test fixture providing example prompts + model: model under test (same for single- and multi-step engines) + tp_size: degree of tensor-parallelism + pp_size: degree of pipeline-parallelism + eager_mode + num_scheduler_steps: for multi-step scheduling, GPU-side steps per + GPU -> CPU output transfer + num_prompts: number of example prompts under test + num_logprobs: corresponds to the `logprobs` argument to the OpenAI + completions endpoint; `None` -> no logprobs + """ + + override_backend_env_variable(monkeypatch, attention_backend) + + prompts = example_prompts + if len(prompts) < num_prompts: + prompts = prompts * ((num_prompts // len(prompts)) + 1) + prompts = prompts[:num_prompts] + assert len(prompts) == num_prompts + + server_args = DEFAULT_SERVER_ARGS + ["--enforce-eager"] + ms_server_args = DEFAULT_SERVER_ARGS + \ + ["--num-scheduler-steps", f"{num_scheduler_steps}"] + + if not is_async: + ms_server_args += ["--disable-async-output-proc"] + + if eager_mode: + ms_server_args.append("--enforce-eager") + + ms_server_args.append("--enable-chunked-prefill") + ms_server_args.append("--enable-prefix-caching") + + distributed_args = [ + "--tensor-parallel-size", + str(tp_size), + "--pipeline-parallel-size", + str(pp_size), + ] + + # Spin up client/server & issue completion API requests. + # Default `max_wait_seconds` is 240 but was empirically + # was raised 3x to 720 *just for this test* due to + # observed timeouts in GHA CI + ref_completions = await completions_with_server_args( + prompts, + model, + server_args + distributed_args, + num_logprobs, + max_wait_seconds=5 * 240) + test_completions = await completions_with_server_args( + prompts, + model, + ms_server_args + distributed_args, + num_logprobs, + max_wait_seconds=5 * 240) + + # Assert multi-step scheduling produces identical tokens + # to single-step scheduling. 
+ ref_generations = get_client_text_generations(ref_completions) + test_generations = get_client_text_generations(test_completions) + assert ref_generations == test_generations + + # Assert multi-step scheduling produces nearly-identical logprobs + # to single-step scheduling. + ref_text_logprobs = get_client_text_logprob_generations(ref_completions) + test_text_logprobs = get_client_text_logprob_generations(test_completions) + check_logprobs_close( + outputs_0_lst=ref_text_logprobs, + outputs_1_lst=test_text_logprobs, + name_0="hf", + name_1="vllm", + ) \ No newline at end of file From 624344ed3c78a04de91f3d7c845f593c3f4aae64 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 30 Sep 2024 00:24:38 -0400 Subject: [PATCH 05/10] wip --- tests/multi_step/test_correctness_async_llm.py | 14 ++++++++++++++ vllm/entrypoints/api_server.py | 1 - 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/tests/multi_step/test_correctness_async_llm.py b/tests/multi_step/test_correctness_async_llm.py index 615549f2134a..3ab6338dc80e 100644 --- a/tests/multi_step/test_correctness_async_llm.py +++ b/tests/multi_step/test_correctness_async_llm.py @@ -142,3 +142,17 @@ async def test_multi_step( name_0="hf", name_1="vllm", ) + +from vllm.scripts import main +import sys + +def test_multi_step_chunked_prefix(): + sys.argv = ['vllm', + 'serve', + 'JackFram/llama-160m', + '--num-scheduler-steps', + '8', + '--enable-prefix-caching', + '--enable-chunked-prefill'] + + main() \ No newline at end of file diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py index f3e80cab62a3..fcb31713f05a 100644 --- a/vllm/entrypoints/api_server.py +++ b/vllm/entrypoints/api_server.py @@ -159,5 +159,4 @@ async def run_server(args: Namespace, parser.add_argument("--log-level", type=str, default="debug") parser = AsyncEngineArgs.add_cli_args(parser) args = parser.parse_args() - asyncio.run(run_server(args)) From 8937c92e543a278d057a8508a0311da29bf814aa Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 30 Sep 2024 11:30:48 -0400 Subject: [PATCH 06/10] num_new_tokens fix --- .../multi_step/test_correctness_async_llm.py | 12 ++--- vllm/core/scheduler.py | 44 ++++++++++++------- 2 files changed, 35 insertions(+), 21 deletions(-) diff --git a/tests/multi_step/test_correctness_async_llm.py b/tests/multi_step/test_correctness_async_llm.py index b4794b1aa207..06ce859e32ce 100644 --- a/tests/multi_step/test_correctness_async_llm.py +++ b/tests/multi_step/test_correctness_async_llm.py @@ -231,14 +231,14 @@ async def test_multi_step_pp_smoke( def test_multi_step_chunked_prefix(): - # sys.argv = [ - # 'vllm', 'serve', 'JackFram/llama-160m', '--num-scheduler-steps', '8', - # '--enable-prefix-caching', '--enable-chunked-prefill' - # ] - sys.argv = [ 'vllm', 'serve', 'JackFram/llama-160m', '--num-scheduler-steps', '8', - '--enable-chunked-prefill' + '--enable-prefix-caching', '--enable-chunked-prefill' ] + # sys.argv = [ + # 'vllm', 'serve', 'JackFram/llama-160m', '--num-scheduler-steps', '8', + # '--enable-chunked-prefill' + # ] + main() diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 5b7587d15084..13a4bf1abc21 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1608,21 +1608,35 @@ def _get_num_new_tokens(self, seq_group: SequenceGroup, if enable_chunking and len(seqs) == 1: remaining_token_budget = budget.remaining_token_budget() if self.cache_config.enable_prefix_caching: - # When prefix caching is enabled, we always allocate - # the number of new tokens that is dividable 
by the block size - # to avoid partial block matching. - block_size = self.cache_config.block_size - remainder = budget.token_budget % block_size - if remainder != 0: - raise ValueError("When enabling chunked prefill and " - "prefix caching, max_num_batched_tokens " - "(chunk size) must be dividable by " - "block size, but got chunk_size " - f"({budget.token_budget}) % block_size " - f"({block_size}) = {remainder}") - if remaining_token_budget < num_new_tokens: - num_new_tokens = (remaining_token_budget // - block_size) * block_size + if self.scheduler_config.is_multi_step: + # The current multi-step + chunked prefill capability does + # not actually support chunking prompts. + if num_new_tokens > self._get_prompt_limit(seq_group): + # If the seq_group is in prompt-stage, pass the + # num_new_tokens as-is so the caller can ignore + # the sequence. + pass + else: + num_new_tokens = 0 \ + if num_new_tokens > remaining_token_budget \ + else num_new_tokens + else: + # When prefix caching is enabled, we always allocate + # the number of new tokens that is dividable by the block + # size to avoid partial block matching. + block_size = self.cache_config.block_size + remainder = budget.token_budget % block_size + if remainder != 0: + raise ValueError( + "When enabling chunked prefill and " + "prefix caching, max_num_batched_tokens " + "(chunk size) must be dividable by " + "block size, but got chunk_size " + f"({budget.token_budget}) % block_size " + f"({block_size}) = {remainder}") + if remaining_token_budget < num_new_tokens: + num_new_tokens = (remaining_token_budget // + block_size) * block_size elif self.scheduler_config.is_multi_step: if num_new_tokens > self._get_prompt_limit(seq_group): # If the seq_group is in prompt-stage, pass the From 846560ae92fe3bc16898cd3890295022c0ab1f25 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 2 Oct 2024 01:03:45 -0400 Subject: [PATCH 07/10] reverted to old test_correctness_async_llm.py; test_correctness_llm.py has contrived multistep+chunked pref+pref cache test --- .../multi_step/test_correctness_async_llm.py | 18 -- tests/multi_step/test_correctness_llm.py | 175 +++--------------- 2 files changed, 28 insertions(+), 165 deletions(-) diff --git a/tests/multi_step/test_correctness_async_llm.py b/tests/multi_step/test_correctness_async_llm.py index 06ce859e32ce..000c923ef3e6 100644 --- a/tests/multi_step/test_correctness_async_llm.py +++ b/tests/multi_step/test_correctness_async_llm.py @@ -224,21 +224,3 @@ async def test_multi_step_pp_smoke( test_generations = get_client_text_generations(test_completions) assert ref_generations == test_generations - - -from vllm.scripts import main -import sys - - -def test_multi_step_chunked_prefix(): - sys.argv = [ - 'vllm', 'serve', 'JackFram/llama-160m', '--num-scheduler-steps', '8', - '--enable-prefix-caching', '--enable-chunked-prefill' - ] - - # sys.argv = [ - # 'vllm', 'serve', 'JackFram/llama-160m', '--num-scheduler-steps', '8', - # '--enable-chunked-prefill' - # ] - - main() diff --git a/tests/multi_step/test_correctness_llm.py b/tests/multi_step/test_correctness_llm.py index a3c4ebd53ffd..178b85ac9025 100644 --- a/tests/multi_step/test_correctness_llm.py +++ b/tests/multi_step/test_correctness_llm.py @@ -249,11 +249,19 @@ def test_multi_step_llm_chunked_prefill_prefix_cache( completions endpoint; `None` -> 1 logprob returned. 
""" - prompts = example_prompts - if len(prompts) < num_prompts: - prompts = prompts * ((num_prompts // len(prompts)) + 1) - prompts = prompts[:num_prompts] - assert len(prompts) == num_prompts + assert len(example_prompts) >= 2 + example_prompts[0] = ('vLLM is a high-throughput and memory-efficient ' + 'inference and serving engine for LLMs.\n') # 24 tok + example_prompts[1] = ( + 'Briefly describe the major milestones in the ' + 'development of artificial intelligence from 1950 to 2020.\n' + ) # 30 tok + + if len(example_prompts) < num_prompts: + example_prompts = (example_prompts * + ((num_prompts // len(example_prompts)) + 1)) + example_prompts = example_prompts[:num_prompts] + assert len(example_prompts) == num_prompts with vllm_runner( model, @@ -263,11 +271,15 @@ def test_multi_step_llm_chunked_prefill_prefix_cache( tensor_parallel_size=tp_size, use_v2_block_manager=True, num_scheduler_steps=num_scheduler_steps, + max_model_len=48, + max_num_batched_tokens=48, + max_num_seqs=4, + block_size=16, ) as vllm_model: - outputs_baseline = (vllm_model.generate_greedy(prompts, max_tokens) - if num_logprobs is None else + outputs_baseline = (vllm_model.generate_greedy( + example_prompts, max_tokens) if num_logprobs is None else vllm_model.generate_greedy_logprobs( - prompts, max_tokens, num_logprobs)) + example_prompts, max_tokens, num_logprobs)) with vllm_runner( model, @@ -279,11 +291,15 @@ def test_multi_step_llm_chunked_prefill_prefix_cache( enable_chunked_prefill=enable_chunked_prefill, enable_prefix_caching=enable_prefix_caching, num_scheduler_steps=num_scheduler_steps, + max_model_len=48, + max_num_batched_tokens=48, + max_num_seqs=4, + block_size=16, ) as vllm_model: - outputs_w_features = (vllm_model.generate_greedy(prompts, max_tokens) - if num_logprobs is None else + outputs_w_features = (vllm_model.generate_greedy( + example_prompts, max_tokens) if num_logprobs is None else vllm_model.generate_greedy_logprobs( - prompts, max_tokens, num_logprobs)) + example_prompts, max_tokens, num_logprobs)) if num_logprobs is None: check_outputs_equal( @@ -298,139 +314,4 @@ def test_multi_step_llm_chunked_prefill_prefix_cache( outputs_1_lst=outputs_w_features, name_0="multi-step", name_1="multi-step+features", - ) - - -from typing import List, Optional - -import pytest - -from tests.kernels.utils import override_backend_env_variable - -from ..models.utils import check_logprobs_close -from ..utils import (completions_with_server_args, get_client_text_generations, - get_client_text_logprob_generations) - -DEFAULT_SERVER_ARGS: List[str] = [ - "--disable-log-requests", - "--use-v2-block-manager", - "--worker-use-ray", - "--gpu-memory-utilization", - "0.85", - "--swap-space", - "16", -] - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize(("tp_size, pp_size"), [ - (1, 1), - (2, 2), -]) -@pytest.mark.parametrize("eager_mode", [False, True]) -@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS) -@pytest.mark.parametrize("num_prompts", NUM_PROMPTS) -@pytest.mark.parametrize("num_logprobs", [5]) -@pytest.mark.parametrize("is_async", [True]) -@pytest.mark.parametrize("attention_backend", ["FLASHINFER", "FLASH_ATTN"]) -@pytest.mark.parametrize("enable_chunked_prefill", [True]) -@pytest.mark.asyncio -async def test_multi_step_async( - example_prompts, - model: str, - tp_size: int, - pp_size: int, - eager_mode: int, - num_scheduler_steps: int, - num_prompts: int, - is_async: bool, - num_logprobs: Optional[int], - attention_backend: str, - enable_chunked_prefill: bool, - 
monkeypatch, -) -> None: - """Test vLLM engine with multi-step scheduling in an OpenAI-protocol - client/server environment. - - Set up an engine with single-step scheduling as a ground-truth reference. - - Send a completions API request to both engines with the same prompts. - - Validate: - * Generated tokens match - * Generated logprobs are all very close - - Args: - example_prompts: test fixture providing example prompts - model: model under test (same for single- and multi-step engines) - tp_size: degree of tensor-parallelism - pp_size: degree of pipeline-parallelism - eager_mode - num_scheduler_steps: for multi-step scheduling, GPU-side steps per - GPU -> CPU output transfer - num_prompts: number of example prompts under test - num_logprobs: corresponds to the `logprobs` argument to the OpenAI - completions endpoint; `None` -> no logprobs - """ - - override_backend_env_variable(monkeypatch, attention_backend) - - prompts = example_prompts - if len(prompts) < num_prompts: - prompts = prompts * ((num_prompts // len(prompts)) + 1) - prompts = prompts[:num_prompts] - assert len(prompts) == num_prompts - - server_args = DEFAULT_SERVER_ARGS + ["--enforce-eager"] - ms_server_args = DEFAULT_SERVER_ARGS + \ - ["--num-scheduler-steps", f"{num_scheduler_steps}"] - - if not is_async: - ms_server_args += ["--disable-async-output-proc"] - - if eager_mode: - ms_server_args.append("--enforce-eager") - - ms_server_args.append("--enable-chunked-prefill") - ms_server_args.append("--enable-prefix-caching") - - distributed_args = [ - "--tensor-parallel-size", - str(tp_size), - "--pipeline-parallel-size", - str(pp_size), - ] - - # Spin up client/server & issue completion API requests. - # Default `max_wait_seconds` is 240 but was empirically - # was raised 3x to 720 *just for this test* due to - # observed timeouts in GHA CI - ref_completions = await completions_with_server_args( - prompts, - model, - server_args + distributed_args, - num_logprobs, - max_wait_seconds=5 * 240) - test_completions = await completions_with_server_args( - prompts, - model, - ms_server_args + distributed_args, - num_logprobs, - max_wait_seconds=5 * 240) - - # Assert multi-step scheduling produces identical tokens - # to single-step scheduling. - ref_generations = get_client_text_generations(ref_completions) - test_generations = get_client_text_generations(test_completions) - assert ref_generations == test_generations - - # Assert multi-step scheduling produces nearly-identical logprobs - # to single-step scheduling. 
- ref_text_logprobs = get_client_text_logprob_generations(ref_completions) - test_text_logprobs = get_client_text_logprob_generations(test_completions) - check_logprobs_close( - outputs_0_lst=ref_text_logprobs, - outputs_1_lst=test_text_logprobs, - name_0="hf", - name_1="vllm", - ) + ) \ No newline at end of file From 92d80f7b519245acb4043d4fe68d5cdb29bfe8d4 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 2 Oct 2024 01:10:45 -0400 Subject: [PATCH 08/10] pulling in some code from main --- vllm/entrypoints/api_server.py | 1 + vllm/model_executor/sampling_metadata.py | 49 ------------------------ vllm/worker/model_runner_base.py | 6 --- 3 files changed, 1 insertion(+), 55 deletions(-) diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py index fcb31713f05a..f3e80cab62a3 100644 --- a/vllm/entrypoints/api_server.py +++ b/vllm/entrypoints/api_server.py @@ -159,4 +159,5 @@ async def run_server(args: Namespace, parser.add_argument("--log-level", type=str, default="debug") parser = AsyncEngineArgs.add_cli_args(parser) args = parser.parse_args() + asyncio.run(run_server(args)) diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index 2f8ae16c31c6..ee02368bec8a 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -134,9 +134,6 @@ def __init__( num_prompts: int, skip_sampler_cpu_output: bool = False, reuse_sampling_tensors: bool = False, - # Used when multi-step is enabled with chunked-prefill. Refer to - # the comment in prepare_multistep_tensors. - selected_token_indices_multistep: Optional[torch.Tensor] = None ) -> None: self.seq_groups = seq_groups self.selected_token_indices = selected_token_indices @@ -144,52 +141,6 @@ def __init__( self.num_prompts = num_prompts self.skip_sampler_cpu_output = skip_sampler_cpu_output self.reuse_sampling_tensors = reuse_sampling_tensors - self.selected_token_indices_multistep = selected_token_indices_multistep - - def prepare_multistep_tensors(self, num_queries: int, device: str, - pin_memory: bool): - """ - Invoked when Multi-Step is enabled with Chunked-Prefill. - When Multi-Step is enabled with Chunked-Prefill, the prompts and - decodes are scheduled together. - self.selected_token_indices is constructed for the first-step in - Multi-Step. However, the scheduled prompts, are fully processed - in the first-step and are processed as decodes in the rest of the steps. - This function prepares a "selected_token_indices" to be used - in the rest of the steps. - - Example: - Let 2 prompts and 2 decodes be scheduled together. Let the - num-tokens to process for the 2 prompts be 5 and 8 resply. - - In that case, self.sampled_token_indices will be, - [4, 12, 13, 14] as it is constructed for the first-step in - multi-step. - However, the prompts turns to decodes after the first-step - and the num-tokens for the previously-prompt sequences will - be 1 and 1 as they are decodes now. The self.sampled_token_indices - must be updated to [0,1,2,3]. - prepare_multistep_tensors prepares the "selected_token_indices" - to be used in steps 2-N. - """ - selected_token_indices_multistep = list(range(num_queries)) - self.selected_token_indices_multistep = \ - async_tensor_h2d(selected_token_indices_multistep, - dtype=torch.long, - target_device=device, - pin_memory=pin_memory) - - def advance_step(self): - """ - Invoked when Multi-Step and Chunked-Prefill are enabled together. 
- The prefills that may have been scheduled, are fully processed in - the very first step and have turned into decodes. - Updated selected_token_indices to reflect that. Please refer to - the prepare_multistep_tensors docstring for an example. - """ - if self.selected_token_indices_multistep is not None: - # Swap to account for Single Step Prompts becoming Decodes - self.selected_token_indices = self.selected_token_indices_multistep @staticmethod def prepare( diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index 1bb6a848390d..86883cf15244 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -64,8 +64,6 @@ def _init_sampling_metadata_from_tensor_dict( # type: ignore from vllm.model_executor import SamplingMetadata selected_token_indices = tensor_dict.pop("selected_token_indices", None) - selected_token_indices_multistep = tensor_dict.pop( - "selected_token_indices_multistep", None) # An empty SamplingMetadata to signal that the worker should skip # sampling. if selected_token_indices is not None: @@ -74,7 +72,6 @@ def _init_sampling_metadata_from_tensor_dict( # type: ignore selected_token_indices=selected_token_indices, categorized_sample_indices=None, num_prompts=0, - selected_token_indices_multistep=selected_token_indices_multistep, ) return tensor_dict @@ -89,9 +86,6 @@ def _add_sampling_metadata_broadcastable_dict( if sampling_metadata is not None: tensor_dict["selected_token_indices"] = ( sampling_metadata.selected_token_indices) - if sampling_metadata.selected_token_indices_multistep is not None: - tensor_dict["selected_token_indices_multistep"] = ( - sampling_metadata.selected_token_indices_multistep) def _init_frozen_model_input_from_tensor_dict( From 12d09144c852f8950d366eb49a54c1ee067adf93 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 2 Oct 2024 01:42:47 -0400 Subject: [PATCH 09/10] sync test cleanup --- tests/multi_step/test_correctness_llm.py | 93 +++++++++++++++++------- 1 file changed, 66 insertions(+), 27 deletions(-) diff --git a/tests/multi_step/test_correctness_llm.py b/tests/multi_step/test_correctness_llm.py index 178b85ac9025..f45428675bde 100644 --- a/tests/multi_step/test_correctness_llm.py +++ b/tests/multi_step/test_correctness_llm.py @@ -1,5 +1,6 @@ # Test the LLMEngine with multi-step-decoding +import copy from typing import Optional import pytest @@ -201,8 +202,6 @@ def test_multi_step_llm_w_prompt_logprobs( @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("tp_size", [1]) -@pytest.mark.parametrize("enable_chunked_prefill", [True]) -@pytest.mark.parametrize("enable_prefix_caching", [True]) @pytest.mark.parametrize("max_tokens", [5]) @pytest.mark.parametrize("enforce_eager", [True]) @pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS) @@ -214,32 +213,45 @@ def test_multi_step_llm_chunked_prefill_prefix_cache( model: str, dtype: str, tp_size: int, - enable_chunked_prefill: bool, - enable_prefix_caching: bool, max_tokens: int, enforce_eager: int, num_scheduler_steps: int, num_prompts: int, num_logprobs: Optional[int], ) -> None: - """Test vLLM engine with multi-step scheduling via sync LLM Engine. + """Test vLLM engine with multi-step+"single-step chunked prefill"+APC. - Set up a HuggingFace (HF) transformers model as a ground-truth reference. 
+ Set up contrived scenario which tests for a possible failure mode of + scheduling with multi-step+"single-step chunked prefill"+APC - Prompt them with the same example prompts. + "single-step chunked prefill" here refers to the current vLLM multi-step+ + chunked-prefill implementation, which requires that a prefill may only + be scheduled in the same step as decodes if the prefill prompt fits in a + single chunk (note that "complete" multi-step+chunked-prefill would allow + a prefill to span multiple chunks & multiple steps but that is not yet + the case.) - Validate: - * Generated tokens match - * Generated logprobs are all very close + "APC" is short for "automatic prefix caching". + + This test creates a scenario where the scheduler must decide whether/how + to schedule a prefill with a prompt that exceeds the available token budget. + The correct behavior for multi-step+"single-step chunked prefill"+APC is to + put off scheduling the prefill until a future step. + + Validate that: + * Multi-step kernels do not raise an exception due to incorrect scheduler + behavior + * Generated tokens match between + multi-step+"single-step chunked prefill"+APC and + single-step scheduling. + * (If logprobs are enabled) check logprobs are close enough Args: - hf_runner: HF transformers model runner fixture vllm_runner: vLLM model runner fixture example_prompts: test fixture providing example prompts model: model under test (same for single- and multi-step engines) dtype: tensor datatype for engine to utilize tp_size: degree of tensor-parallelism - enable_chunked_prefill: chunked-prefill on/off max_tokens: the maximum number of tokens to generate enforce_eager num_scheduler_steps: for multi-step scheduling, GPU-side steps per @@ -249,20 +261,44 @@ def test_multi_step_llm_chunked_prefill_prefix_cache( completions endpoint; `None` -> 1 logprob returned. """ + # Set up contrived test for correct scheduling behavior with + # multi-step+"single-step chunked prefill"+APC. + # + # Assume block_size=16 + # + # Assume max_num_batched_tokens=48 + # => Per-step token budget=48 + # + # 1. Scheduler schedules 0th prompt (24 tokens) + # => Remaining token budget=24 + # 2. Scheduler attempts to schedule 1st prompt (30 tokens) + # * 30 tokens exceeds 24 token remaining budget + # * Correct behavior: do not schedule this prompt in this step + # * Incorrect behavior: schedule prompt chunk + # * `do_sample=False` for this prompt in this step + # * Chunk size = (remaining tokens // block size) * block size + # + # The Incorrect scheduling behavior - if it occurs - will cause an exception + # in the model runner resulting from `do_sample=False`. 
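# [Illustrative aside, added editorially; not part of the patch.] The scenario
# in the comment block above, restated as runnable arithmetic under the same
# assumptions (block_size=16, max_num_batched_tokens=48, prompts of 24 and 30
# tokens). The names below are local to this aside, not test variables.
block_size, token_budget = 16, 48
remaining_budget = token_budget - 24   # budget left after the 0th prompt (24 tok)
assert 30 > remaining_budget           # the 1st prompt (30 tok) cannot fit now
# Correct multi-step behavior: postpone the 1st prompt (schedule 0 new tokens).
# Incorrect behavior: schedule a partial chunk with do_sample=False, sized:
assert (remaining_budget // block_size) * block_size == 16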
assert len(example_prompts) >= 2 - example_prompts[0] = ('vLLM is a high-throughput and memory-efficient ' - 'inference and serving engine for LLMs.\n') # 24 tok - example_prompts[1] = ( + challenge_prompts = copy.deepcopy(example_prompts) + challenge_prompts[0] = ('vLLM is a high-throughput and memory-efficient ' + 'inference and serving engine for LLMs.\n' + ) # 24 tok + challenge_prompts[1] = ( 'Briefly describe the major milestones in the ' 'development of artificial intelligence from 1950 to 2020.\n' ) # 30 tok - if len(example_prompts) < num_prompts: - example_prompts = (example_prompts * - ((num_prompts // len(example_prompts)) + 1)) - example_prompts = example_prompts[:num_prompts] - assert len(example_prompts) == num_prompts + # If necessary, adjust the length of `challenge_prompts` to match + # `num_prompts` + if len(challenge_prompts) < num_prompts: + challenge_prompts = (challenge_prompts * + ((num_prompts // len(challenge_prompts)) + 1)) + challenge_prompts = challenge_prompts[:num_prompts] + assert len(challenge_prompts) == num_prompts + # Single-step scheduler baseline with vllm_runner( model, dtype=dtype, @@ -277,10 +313,11 @@ def test_multi_step_llm_chunked_prefill_prefix_cache( block_size=16, ) as vllm_model: outputs_baseline = (vllm_model.generate_greedy( - example_prompts, max_tokens) if num_logprobs is None else + challenge_prompts, max_tokens) if num_logprobs is None else vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs)) + challenge_prompts, max_tokens, num_logprobs)) + # multi-step+"single-step chunked prefill"+APC with vllm_runner( model, dtype=dtype, @@ -288,8 +325,8 @@ def test_multi_step_llm_chunked_prefill_prefix_cache( gpu_memory_utilization=0.7, tensor_parallel_size=tp_size, use_v2_block_manager=True, - enable_chunked_prefill=enable_chunked_prefill, - enable_prefix_caching=enable_prefix_caching, + enable_chunked_prefill=True, + enable_prefix_caching=True, num_scheduler_steps=num_scheduler_steps, max_model_len=48, max_num_batched_tokens=48, @@ -297,11 +334,12 @@ def test_multi_step_llm_chunked_prefill_prefix_cache( block_size=16, ) as vllm_model: outputs_w_features = (vllm_model.generate_greedy( - example_prompts, max_tokens) if num_logprobs is None else + challenge_prompts, max_tokens) if num_logprobs is None else vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs)) + challenge_prompts, max_tokens, num_logprobs)) if num_logprobs is None: + # No-logprobs test check_outputs_equal( outputs_0_lst=outputs_baseline, outputs_1_lst=outputs_w_features, @@ -309,9 +347,10 @@ def test_multi_step_llm_chunked_prefill_prefix_cache( name_1="multi-step+features", ) else: + # Yes-logprobs test check_logprobs_close( outputs_0_lst=outputs_baseline, outputs_1_lst=outputs_w_features, name_0="multi-step", name_1="multi-step+features", - ) \ No newline at end of file + ) From b5277c9f8b33e1e585d4c8b03d7722fb363fba46 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 2 Oct 2024 01:48:55 -0400 Subject: [PATCH 10/10] refactoring scheduler improvements --- vllm/core/scheduler.py | 57 +++++++++++++++++++----------------------- 1 file changed, 26 insertions(+), 31 deletions(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 13a4bf1abc21..f3a5016d0e62 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1607,37 +1607,16 @@ def _get_num_new_tokens(self, seq_group: SequenceGroup, # in a decode phase. Do not chunk. 
if enable_chunking and len(seqs) == 1: remaining_token_budget = budget.remaining_token_budget() - if self.cache_config.enable_prefix_caching: - if self.scheduler_config.is_multi_step: - # The current multi-step + chunked prefill capability does - # not actually support chunking prompts. - if num_new_tokens > self._get_prompt_limit(seq_group): - # If the seq_group is in prompt-stage, pass the - # num_new_tokens as-is so the caller can ignore - # the sequence. - pass - else: - num_new_tokens = 0 \ - if num_new_tokens > remaining_token_budget \ - else num_new_tokens - else: - # When prefix caching is enabled, we always allocate - # the number of new tokens that is dividable by the block - # size to avoid partial block matching. - block_size = self.cache_config.block_size - remainder = budget.token_budget % block_size - if remainder != 0: - raise ValueError( - "When enabling chunked prefill and " - "prefix caching, max_num_batched_tokens " - "(chunk size) must be dividable by " - "block size, but got chunk_size " - f"({budget.token_budget}) % block_size " - f"({block_size}) = {remainder}") - if remaining_token_budget < num_new_tokens: - num_new_tokens = (remaining_token_budget // - block_size) * block_size - elif self.scheduler_config.is_multi_step: + if self.scheduler_config.is_multi_step: + # The current multi-step + chunked prefill capability does + # not actually support chunking prompts. + # + # Therefore, `num_new_tokens` is computed in the same fashion + # for both multi-step+chunked-prefill & + # multi-step+chunked-prefill+APC + # + # Prompts with more tokens than the current remaining budget + # are postponed to future scheduler steps if num_new_tokens > self._get_prompt_limit(seq_group): # If the seq_group is in prompt-stage, pass the # num_new_tokens as-is so the caller can ignore @@ -1647,6 +1626,22 @@ def _get_num_new_tokens(self, seq_group: SequenceGroup, num_new_tokens = 0 \ if num_new_tokens > remaining_token_budget \ else num_new_tokens + elif self.cache_config.enable_prefix_caching: + # When prefix caching is enabled, we always allocate + # the number of new tokens that is dividable by the block + # size to avoid partial block matching. + block_size = self.cache_config.block_size + remainder = budget.token_budget % block_size + if remainder != 0: + raise ValueError("When enabling chunked prefill and " + "prefix caching, max_num_batched_tokens " + "(chunk size) must be dividable by " + "block size, but got chunk_size " + f"({budget.token_budget}) % block_size " + f"({block_size}) = {remainder}") + if remaining_token_budget < num_new_tokens: + num_new_tokens = (remaining_token_budget // + block_size) * block_size else: num_new_tokens = min(num_new_tokens, remaining_token_budget) return num_new_tokens
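
For reference, the chunking rule that the final hunk settles on can be restated outside the scheduler. The sketch below is an illustrative, self-contained summary only: it is not vLLM's `Scheduler._get_num_new_tokens`, and the flat parameter list stands in for the `SchedulingBudget`, config objects, and `SequenceGroup` used by the real method.

def sketch_num_new_tokens(num_new_tokens: int,
                          remaining_token_budget: int,
                          token_budget: int,
                          block_size: int,
                          prompt_limit: int,
                          is_multi_step: bool,
                          enable_prefix_caching: bool) -> int:
    """Simplified restatement of the chunking rule above; not vLLM code."""
    if is_multi_step:
        # Multi-step cannot chunk a prompt. A prompt over the prompt limit is
        # returned as-is so the caller can reject the sequence; a prompt over
        # the remaining per-step budget is postponed (0 tokens this step).
        if num_new_tokens > prompt_limit:
            return num_new_tokens
        return 0 if num_new_tokens > remaining_token_budget else num_new_tokens
    if enable_prefix_caching:
        # Chunks must be whole blocks so cached prefix blocks match exactly,
        # which requires the per-step budget itself to be a block multiple.
        if token_budget % block_size != 0:
            raise ValueError("max_num_batched_tokens must be divisible by "
                             "block_size when prefix caching is enabled")
        if remaining_token_budget < num_new_tokens:
            return (remaining_token_budget // block_size) * block_size
        return num_new_tokens
    # Plain chunked prefill: just clamp to the remaining budget.
    return min(num_new_tokens, remaining_token_budget)

Under the contrived settings used by the new test (block_size=16, a 48-token budget with 24 tokens remaining), a 30-token prompt yields 0 new tokens with multi-step enabled (the prefill is postponed to a later scheduler step), and, with prefix caching but without the multi-step special case, a 16-token chunk, which is exactly the distinction the test exercises.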