# [Core] Optimize SPMD architecture with delta + serialization optimization #7109
**`vllm/config.py`**
```diff
@@ -691,8 +691,8 @@ def __init__(
         self.tokenizer_pool_config = tokenizer_pool_config
         self.ray_workers_use_nsight = ray_workers_use_nsight
         self.placement_group = placement_group

         self.world_size = pipeline_parallel_size * self.tensor_parallel_size

         if worker_use_ray:
             if self.distributed_executor_backend is None:
                 self.distributed_executor_backend = "ray"
```
```diff
@@ -788,6 +788,8 @@ class SchedulerConfig:
         swapping. However, when the sequence group has multiple sequences
         (e.g., beam search), recomputation is not currently supported. In
         such a case, we use swapping instead.
+        _use_delta: Private API. If used, scheduler sends delta data to
+            workers instead of an entire data.
     """

     def __init__(self,
```

> **Reviewer** (on the `_use_delta` docstring): Suggest to have a more clear naming, such as …
>
> **Author:** great suggestion!
```diff
@@ -799,7 +801,8 @@ def __init__(self,
                  delay_factor: float = 0.0,
                  enable_chunked_prefill: bool = False,
                  embedding_mode: Optional[bool] = False,
-                 preemption_mode: Optional[str] = None) -> None:
+                 preemption_mode: Optional[str] = None,
+                 _use_delta: bool = False) -> None:
         if max_num_batched_tokens is not None:
             self.max_num_batched_tokens = max_num_batched_tokens
         else:
```
```diff
@@ -828,6 +831,7 @@ def __init__(self,
         self.chunked_prefill_enabled = enable_chunked_prefill
         self.embedding_mode = embedding_mode
         self.preemption_mode = preemption_mode
+        self._use_delta = _use_delta
         self._verify_args()

     def _verify_args(self) -> None:
```
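The `_use_delta` flag exists purely to cut serialization cost on the engine-to-worker path. Below is a rough, standalone illustration of the asymmetry it exploits; the payload shapes and field names are made up for the example, not vLLM code:

```python
# Made-up payloads (not vLLM code) showing why deltas shrink the per-step
# message: full metadata re-sends the whole token history every step, while
# a delta carries only what changed since the previous step.
import pickle

prompt_token_ids = list(range(2048))   # sent once, with the first prefill
output_token_ids = list(range(200))    # grows by one token per decode step

full = {
    "request_id": "req-0",
    "prompt_token_ids": prompt_token_ids,
    "output_token_ids": output_token_ids,
}
delta = {
    "request_id": "req-0",
    "new_output_token_id": output_token_ids[-1],
    "new_num_computed_tokens": len(prompt_token_ids) + len(output_token_ids) - 1,
}

print(len(pickle.dumps(full)))    # grows with sequence length (kilobytes here)
print(len(pickle.dumps(delta)))   # small and roughly constant per step
```

In SPMD mode this message is serialized once per scheduler step for every worker, so the delta's constant size is what keeps the overhead flat as sequences grow.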
**`vllm/core/scheduler.py`**
```diff
@@ -12,7 +12,8 @@
 from vllm.lora.request import LoRARequest
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
-                           SequenceGroupMetadata, SequenceStatus)
+                           SequenceGroupMetadata, SequenceGroupMetadataDelta,
+                           SequenceStatus)

 logger = init_logger(__name__)
```
```diff
@@ -974,41 +975,62 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
                     seq_group.get_seqs(status=SequenceStatus.RUNNING)))

             do_sample = True
-            if seq_group.is_prefill():
+            is_prompt = seq_group.is_prefill()
+            # We should send the metadata to workers when the first prefill
+            # is sent. Subsequent requests could be chunked prefill or decode.
+            is_first_prefill = False
+            if is_prompt:
                 seqs = seq_group.get_seqs()
                 # Prefill has only 1 sequence.
                 assert len(seqs) == 1
+                num_computed_tokens = seqs[0].data.get_num_computed_tokens()
+                is_first_prefill = num_computed_tokens == 0
                 # In the next iteration, all prompt tokens are not computed.
                 # It means the prefill is chunked, and we don't need sampling.
                 # NOTE: We use get_len instead of get_prompt_len because when
                 # a sequence is preempted, prefill includes previous generated
                 # output tokens.
-                if (token_chunk_size + seqs[0].data.get_num_computed_tokens() <
+                if (token_chunk_size + num_computed_tokens <
                         seqs[0].data.get_len()):
                     do_sample = False

             # It assumes the scheduled_seq_groups is ordered by
             # prefill < decoding.
```
> **Reviewer** (on lines 1097–1098): Not related to this change, but this comment is outdated? With chunked prefill, decode can come first?
>
> **Author:** it is still true. the order is prefill -> decode (literal list order, not the scheduling order)
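To make the `do_sample` condition above concrete, here is a small standalone walk-through (illustrative only, not vLLM code) of a 10-token prompt scheduled with a chunk size of 4. Only the chunk that finishes the prompt samples a token:

```python
# Walk through the condition (token_chunk_size + num_computed_tokens <
# seq_len): every chunk of a chunked prefill skips sampling except the last.
prompt_len = 10          # stands in for seqs[0].data.get_len()
token_chunk_size = 4

num_computed_tokens = 0
step = 1
while num_computed_tokens < prompt_len:
    chunk = min(token_chunk_size, prompt_len - num_computed_tokens)
    do_sample = not (chunk + num_computed_tokens < prompt_len)
    print(f"step {step}: computed={num_computed_tokens}, "
          f"chunk={chunk}, do_sample={do_sample}")
    num_computed_tokens += chunk
    step += 1
# step 1: computed=0, chunk=4, do_sample=False
# step 2: computed=4, chunk=4, do_sample=False
# step 3: computed=8, chunk=2, do_sample=True
```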
```diff
-            is_prompt = seq_group.is_prefill()
-            seq_group_metadata = SequenceGroupMetadata(
-                request_id=seq_group.request_id,
-                is_prompt=is_prompt,
-                seq_data=seq_data,
-                sampling_params=seq_group.sampling_params,
-                block_tables=block_tables,
-                do_sample=do_sample,
-                pooling_params=seq_group.pooling_params,
-                token_chunk_size=token_chunk_size,
-                lora_request=seq_group.lora_request,
-                computed_block_nums=common_computed_block_nums,
-                # `multi_modal_data` will only be present for the 1st comm
-                # between engine and worker.
-                # the subsequent comms can still use delta, but
-                # `multi_modal_data` will be None.
-                multi_modal_data=seq_group.multi_modal_data
-                if scheduler_outputs.num_prefill_groups > 0 else None,
-                prompt_adapter_request=seq_group.prompt_adapter_request,
-            )
+            if is_first_prefill or not self.scheduler_config._use_delta:
+                seq_group_metadata = SequenceGroupMetadata(
```
> NOTE: I removed _seq_group_metadata_cache.

> cc @alexm-neuralmagic
```diff
+                    request_id=seq_group.request_id,
+                    is_prompt=is_prompt,
+                    seq_data=seq_data,
+                    sampling_params=seq_group.sampling_params,
+                    block_tables=block_tables,
+                    do_sample=do_sample,
+                    pooling_params=seq_group.pooling_params,
+                    token_chunk_size=token_chunk_size,
+                    lora_request=seq_group.lora_request,
+                    computed_block_nums=common_computed_block_nums,
+                    # `multi_modal_data` will only be present for the 1st comm
+                    # between engine and worker.
+                    # the subsequent comms can still use delta, but
+                    # `multi_modal_data` will be None.
+                    multi_modal_data=seq_group.multi_modal_data
+                    if scheduler_outputs.num_prefill_groups > 0 else None,
+                    prompt_adapter_request=seq_group.prompt_adapter_request,
+                )
+            else:
+                # When SPMD mode is enabled, we only send delta data except for
+                # the first request to reduce serialization cost.
+                seq_data_delta = {}
+                for id, data in seq_data.items():
+                    seq_data_delta[id] = data.get_delta()
+                seq_group_metadata = SequenceGroupMetadataDelta(
+                    seq_data_delta,
+                    seq_group.request_id,
+                    block_tables,
+                    is_prompt,
+                    do_sample=do_sample,
+                    token_chunk_size=token_chunk_size,
+                    computed_block_nums=common_computed_block_nums,
+                )
             seq_group_metadata_list.append(seq_group_metadata)

         # Now that the batch has been created, we can assume all blocks in the
```
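For context on how the delta is consumed: the worker side has to keep the full metadata from the first prefill and patch it with each subsequent delta. The sketch below is a hypothetical reconstruction of that pattern with stand-in classes and field names; it is not the actual vLLM implementation:

```python
# Hypothetical sketch (stand-in classes; not the vLLM implementation) of the
# worker-side pattern: cache the full metadata sent with the first prefill,
# then patch only the mutable fields from each delta.
from dataclasses import dataclass
from typing import Dict, List, Union


@dataclass
class FullMeta:            # stand-in for SequenceGroupMetadata
    request_id: str
    is_prompt: bool
    do_sample: bool
    token_chunk_size: int
    block_tables: Dict[int, List[int]]


@dataclass
class MetaDelta:           # stand-in for SequenceGroupMetadataDelta
    request_id: str
    is_prompt: bool
    do_sample: bool
    token_chunk_size: int
    block_tables: Dict[int, List[int]]


class WorkerMetadataCache:
    """Rebuilds full metadata from deltas, keyed by request_id."""

    def __init__(self) -> None:
        self._cache: Dict[str, FullMeta] = {}

    def apply(self, msg: Union[FullMeta, MetaDelta]) -> FullMeta:
        if isinstance(msg, FullMeta):
            # First prefill: the engine sent everything; remember it.
            self._cache[msg.request_id] = msg
            return msg
        # Later steps: only the mutable fields arrive; patch the cached copy.
        cached = self._cache[msg.request_id]
        cached.is_prompt = msg.is_prompt
        cached.do_sample = msg.do_sample
        cached.token_chunk_size = msg.token_chunk_size
        cached.block_tables = msg.block_tables
        return cached


# Usage: one full message, then cheap deltas for the same request.
cache = WorkerMetadataCache()
cache.apply(FullMeta("req-0", True, True, 4, {0: [1, 2]}))
meta = cache.apply(MetaDelta("req-0", False, True, 1, {0: [1, 2, 3]}))
assert meta.block_tables == {0: [1, 2, 3]}
```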
```diff
@@ -1018,7 +1040,6 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
         for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
             self.block_manager.mark_blocks_as_computed(
                 scheduled_seq_group.seq_group)
-
         return seq_group_metadata_list, scheduler_outputs

     def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
```