Commit 4a03479

rkooo567 authored and LeiWang1999 committed

[Core] Optimize SPMD architecture with delta + serialization optimization (vllm-project#7109)

Signed-off-by: LeiWang1999 <leiwang1999@outlook.com>

1 parent 3e02a74 · commit 4a03479

36 files changed: +727, -351 lines
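The "delta" in the commit title refers to how the driver talks to the SPMD Ray workers: rather than re-serializing and re-sending the full per-sequence scheduling metadata on every engine step, only the fields that changed since the previous step are sent, and each worker patches a locally cached copy. The sketch below is purely illustrative of that idea; the names (CachedSeqMetadata, SeqMetadataDelta, apply_delta) are hypothetical and are not taken from this commit.

# Illustrative only: a minimal delta-input scheme (hypothetical names).
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class CachedSeqMetadata:
    # Full per-request state, sent to a worker once when the request is scheduled.
    request_id: str
    prompt_token_ids: List[int]
    block_tables: Dict[int, List[int]] = field(default_factory=dict)
    num_computed_tokens: int = 0


@dataclass
class SeqMetadataDelta:
    # Only the fields that change from one scheduler step to the next.
    request_id: str
    block_tables: Dict[int, List[int]]
    num_computed_tokens: int


def apply_delta(cache: Dict[str, CachedSeqMetadata],
                delta: SeqMetadataDelta) -> CachedSeqMetadata:
    # The worker patches its cached copy instead of receiving full metadata again.
    entry = cache[delta.request_id]
    entry.block_tables = delta.block_tables
    entry.num_computed_tokens = delta.num_computed_tokens
    return entry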

requirements-common.txt

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@ outlines >= 0.0.43, < 0.1 # Requires torch >= 2.1.0
 typing_extensions >= 4.10
 filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
 pyzmq
+msgspec
 librosa # Required for audio processing
 soundfile # Required for audio processing
 gguf == 0.9.1
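msgspec is the new dependency behind the serialization optimization: it provides fast, schema-typed MessagePack encoding as an alternative to pickle for the objects exchanged between the driver and workers. A minimal round-trip example, independent of vLLM's own types (the Point struct is only an illustration):

import msgspec


class Point(msgspec.Struct):
    # msgspec.Struct instances encode to compact MessagePack without pickle.
    x: int
    y: int


data = msgspec.msgpack.encode(Point(x=1, y=2))    # MessagePack bytes
point = msgspec.msgpack.decode(data, type=Point)  # typed decode back to a Point
assert point == Point(x=1, y=2)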

tests/basic_correctness/test_preemption.py

Lines changed: 18 additions & 0 deletions
@@ -8,6 +8,7 @@
 import pytest
 from prometheus_client import REGISTRY
 
+import vllm.envs as envs
 from vllm import SamplingParams
 from vllm.core.scheduler import (ARTIFICIAL_PREEMPTION_MAX_CNT,
                                  ENABLE_ARTIFICIAL_PREEMPT)
@@ -24,6 +25,13 @@
     "tests/basic_correctness/test_preemption.py`")
 
 
+@pytest.fixture
+def worker_use_ray() -> bool:
+    # When SPMD worker is used, use ray_use_worker=True
+    # to test delta input optimization works with preemption.
+    return envs.VLLM_USE_RAY_SPMD_WORKER
+
+
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [96])
@@ -36,6 +44,7 @@ def test_chunked_prefill_recompute(
     dtype: str,
     max_tokens: int,
     chunked_prefill_token_size: int,
+    worker_use_ray: bool,
 ) -> None:
     """Ensure that chunked prefill works with preemption."""
     max_num_seqs = min(chunked_prefill_token_size, 256)
@@ -54,6 +63,7 @@ def test_chunked_prefill_recompute(
             max_num_batched_tokens=max_num_batched_tokens,
             enable_chunked_prefill=enable_chunked_prefill,
             max_num_seqs=max_num_seqs,
+            worker_use_ray=worker_use_ray,
     ) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
         assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
@@ -79,6 +89,7 @@ def test_preemption(
     model: str,
     dtype: str,
     max_tokens: int,
+    worker_use_ray: bool,
 ) -> None:
     """By default, recompute preemption is enabled"""
 
@@ -89,6 +100,7 @@ def test_preemption(
             model,
             dtype=dtype,
             disable_log_stats=False,
+            worker_use_ray=worker_use_ray,
     ) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
         assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
@@ -132,6 +144,7 @@ def test_swap(
     dtype: str,
     max_tokens: int,
     beam_width: int,
+    worker_use_ray: bool,
 ) -> None:
     """Use beam search enables swapping."""
     example_prompts = example_prompts[:1]
@@ -144,6 +157,7 @@ def test_swap(
             dtype=dtype,
             swap_space=10,
             disable_log_stats=False,
+            worker_use_ray=worker_use_ray,
     ) as vllm_model:
         vllm_outputs = vllm_model.generate_beam_search(example_prompts,
                                                        beam_width, max_tokens)
@@ -188,6 +202,7 @@ def test_swap_infeasible(
     dtype: str,
     max_tokens: int,
     beam_width: int,
+    worker_use_ray: bool,
 ) -> None:
     """Verify infeasible swap request will be ignored."""
     BLOCK_SIZE = 16
@@ -204,6 +219,7 @@ def test_swap_infeasible(
             # decode blocks are not enough to finish.
             num_gpu_blocks_override=prefill_blocks + decode_blocks,
             max_model_len=(prefill_blocks + decode_blocks) * BLOCK_SIZE,
+            worker_use_ray=worker_use_ray,
     ) as vllm_model:
         sampling_params = SamplingParams(n=beam_width,
                                          use_beam_search=True,
@@ -230,6 +246,7 @@ def test_preemption_infeasible(
     model: str,
     dtype: str,
     max_tokens: int,
+    worker_use_ray: bool,
 ) -> None:
     """Verify infeasible preemption request will be ignored."""
     BLOCK_SIZE = 16
@@ -244,6 +261,7 @@ def test_preemption_infeasible(
             # ignored instead of hanging forever.
             num_gpu_blocks_override=prefill_blocks + decode_blocks // 2,
             max_model_len=((prefill_blocks + decode_blocks // 2) * BLOCK_SIZE),
+            worker_use_ray=worker_use_ray,
     ) as vllm_model:
         sampling_params = SamplingParams(max_tokens=max_tokens,
                                          ignore_eos=True)

tests/core/test_serialization.py

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+import msgspec
+
+from vllm.executor.msgspec_utils import decode_hook, encode_hook
+from vllm.sequence import ExecuteModelRequest
+
+from ..spec_decode.utils import create_batch
+
+
+def test_msgspec_serialization():
+    num_lookahead_slots = 4
+    seq_group_metadata_list, _, _ = create_batch(16, num_lookahead_slots)
+    execute_model_req = ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list,
+        num_lookahead_slots=num_lookahead_slots,
+        running_queue_size=4)
+
+    encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)
+    decoder = msgspec.msgpack.Decoder(ExecuteModelRequest,
+                                      dec_hook=decode_hook)
+    req = decoder.decode(encoder.encode(execute_model_req))
+    expected = execute_model_req.seq_group_metadata_list
+    actual = req.seq_group_metadata_list
+    assert (len(expected) == len(actual))
+    expected = expected[0]
+    actual = actual[0]
+
+    assert expected.block_tables == actual.block_tables
+    assert expected.is_prompt == actual.is_prompt
+    assert expected.request_id == actual.request_id
+    assert (expected.seq_data[0].prompt_token_ids ==
+            actual.seq_data[0].prompt_token_ids)
+    assert (expected.seq_data[0].output_token_ids ==
+            actual.seq_data[0].output_token_ids)
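The encoder and decoder above use custom hooks from vllm.executor.msgspec_utils for types msgspec cannot handle natively, notably the array.array that now holds token ids. That module is not part of this excerpt; the sketch below only shows the general shape such hooks can take and is not the actual vLLM implementation (the 'l' typecode is an assumption):

from array import array
from typing import Any, Type

TOKEN_ID_ARRAY_TYPE = "l"  # assumed typecode for token-id arrays


def encode_hook(obj: Any) -> Any:
    # Possible enc_hook: serialize array.array as raw bytes.
    if isinstance(obj, array):
        return obj.tobytes()
    raise NotImplementedError(f"Type {type(obj)} is not supported")


def decode_hook(type_: Type, obj: Any) -> Any:
    # Possible dec_hook: rebuild the array.array from those bytes.
    if type_ is array:
        result = array(TOKEN_ID_ARRAY_TYPE)
        result.frombytes(obj)
        return result
    return obj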

tests/distributed/test_basic_distributed_correctness.py

Lines changed: 2 additions & 1 deletion
@@ -22,7 +22,8 @@
 @pytest.mark.skipif(cuda_device_count_stateless() < 2,
                     reason="Need at least 2 GPUs to run the test.")
 @pytest.mark.parametrize(
-    "model, distributed_executor_backend, attention_backend, test_suite", [
+    "model, distributed_executor_backend, attention_backend, "
+    "test_suite", [
         ("facebook/opt-125m", "ray", "", "L4"),
         ("facebook/opt-125m", "mp", "", "L4"),
         ("meta-llama/Llama-2-7b-hf", "ray", "", "L4"),

tests/distributed/test_chunked_prefill_distributed.py

Lines changed: 7 additions & 0 deletions
@@ -6,6 +6,8 @@
 ```
 """
 
+import os
+
 import pytest
 
 from vllm.utils import cuda_device_count_stateless
@@ -30,6 +32,11 @@ def test_models(
     model: str,
     distributed_executor_backend: str,
 ) -> None:
+    if model == "meta-llama/Llama-2-7b-hf" and distributed_executor_backend == "ray":  # noqa
+        assert distributed_executor_backend == "ray"
+        # test ray adag
+        os.environ['VLLM_USE_RAY_SPMD_WORKER'] = "1"
+        os.environ['VLLM_USE_RAY_COMPILED_DAG'] = "1"
 
     dtype = "half"
     max_tokens = 5

tests/samplers/test_sampler.py

Lines changed: 19 additions & 6 deletions
@@ -1,5 +1,6 @@
 import itertools
 import random
+from array import array
 from typing import Dict, List, Optional, Tuple
 from unittest.mock import Mock, patch
 
@@ -10,7 +11,8 @@
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_random_seed
-from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
+from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams,
+                           SequenceData, SequenceGroupMetadata)
 from vllm.utils import Counter, is_pin_memory_available
 
 
@@ -56,7 +58,9 @@ def _do_sample(
             SequenceGroupMetadata(
                 request_id=f"test_{i}",
                 is_prompt=True,
-                seq_data={0: SequenceData([1, 2, 3])},
+                seq_data={
+                    0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
+                },
                 sampling_params=sampling_params,
                 block_tables={0: [1]},
             ))
@@ -201,7 +205,8 @@ def create_sampling_params(min_tokens,
 
 def create_sequence_data(num_input=3, num_generated=0):
     seq_data = SequenceData(
-        random.choices(range(0, VOCAB_SIZE), k=num_input))
+        array(VLLM_TOKEN_ID_ARRAY_TYPE,
+              random.choices(range(0, VOCAB_SIZE), k=num_input)))
     if num_generated > 0:
         seq_data.output_token_ids = random.choices(range(0, VOCAB_SIZE),
                                                    k=num_generated)
@@ -504,7 +509,9 @@ def test_sampler_mixed(seed: int, device: str):
             SequenceGroupMetadata(
                 request_id=f"test_{i}",
                 is_prompt=True,
-                seq_data={0: SequenceData([1, 2, 3])},
+                seq_data={
+                    0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
+                },
                 sampling_params=sampling_params,
                 block_tables={0: [1]},
             ))
@@ -600,7 +607,9 @@ def test_sampler_top_k_top_p(seed: int, device: str):
             SequenceGroupMetadata(
                 request_id=f"test_{i}",
                 is_prompt=True,
-                seq_data={0: SequenceData([1, 2, 3])},
+                seq_data={
+                    0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
+                },
                 sampling_params=SamplingParams(
                     temperature=1,
                     top_k=top_k,
@@ -650,7 +659,11 @@ def test_sampling_params(sampling_params: List[SamplingParams]):
             SequenceGroupMetadata(
                 request_id=f"test_{i}",
                 is_prompt=True,
-                seq_data={0: SequenceData([1, 2, 3])},
+                seq_data={
+                    0:
+                    SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE,
+                                       [1, 2, 3]))
+                },
                 sampling_params=sampling_params[i],
                 block_tables={0: [1]},
             ))
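The repeated pattern in these test changes reflects the underlying API change: SequenceData now takes its token ids as a typed array.array built with the VLLM_TOKEN_ID_ARRAY_TYPE typecode instead of a plain Python list, which the msgspec hooks can serialize as a single compact byte string. The new construction pattern, taken directly from the diffs above:

from array import array

from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData

# Before this commit: SequenceData([1, 2, 3])
# After this commit: token ids live in a typed array.
seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))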

tests/spec_decode/utils.py

Lines changed: 6 additions & 3 deletions
@@ -1,3 +1,4 @@
+from array import array
 from itertools import count
 from typing import Callable, Dict, List, Optional
 from typing import Sequence as GenericSequence
@@ -9,7 +10,8 @@
 from vllm.engine.arg_utils import EngineArgs
 from vllm.model_executor.utils import set_random_seed
 from vllm.sampling_params import SamplingParams
-from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
+from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE,
+                           CompletionSequenceGroupOutput, Logprob,
                            SamplerOutput, SequenceData, SequenceGroupMetadata,
                            SequenceOutput)
 from vllm.utils import get_distributed_init_method, get_ip, get_open_port
@@ -138,8 +140,9 @@ def create_seq_group_metadata_from_prompts(
             seq_data={
                 i:
                 SequenceData(
-                    prompt_token_ids=prompt_token_ids[:],
-                    output_token_ids=cont_token_ids[:],
+                    array(VLLM_TOKEN_ID_ARRAY_TYPE, prompt_token_ids[:]),
+                    _output_token_ids=array(VLLM_TOKEN_ID_ARRAY_TYPE,
+                                            cont_token_ids[:]),
                 ),
             },
             sampling_params=SamplingParams(temperature=0.0, ),

tests/test_logits_processor.py

Lines changed: 6 additions & 2 deletions
@@ -1,4 +1,5 @@
 import random
+from array import array
 from typing import Tuple
 from unittest.mock import patch
 
@@ -8,7 +9,8 @@
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_random_seed
-from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
+from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams,
+                           SequenceData, SequenceGroupMetadata)
 from vllm.utils import is_pin_memory_available
 
 
@@ -69,7 +71,9 @@ def pick_ith(token_ids, logits):
         SequenceGroupMetadata(
             request_id=f"test_{i}",
             is_prompt=True,
-            seq_data={0: SequenceData([1, 2, 3])},
+            seq_data={
+                0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
+            },
             sampling_params=SamplingParams(temperature=0,
                                            logits_processors=[pick_ith]),
             block_tables={0: [1]},

tests/test_sequence.py

Lines changed: 5 additions & 2 deletions
@@ -1,6 +1,9 @@
+from array import array
+
 import pytest
 
-from vllm.sequence import (CompletionSequenceGroupOutput, SamplerOutput,
+from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE,
+                           CompletionSequenceGroupOutput, SamplerOutput,
                            SequenceData, SequenceOutput)
 
 from .core.utils import create_dummy_prompt
@@ -54,7 +57,7 @@ def test_sampler_output_eq(sample_outputs):
 
 
 def test_sequence_data_prefill():
-    seq_data = SequenceData(prompt_token_ids=[1, 2, 3, 4])
+    seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3, 4]))
     assert seq_data.get_num_uncomputed_tokens() == 4
     assert seq_data.get_num_computed_tokens() == 0
     # advance by 2

tests/worker/test_encoder_decoder_model_runner.py

Lines changed: 11 additions & 5 deletions
@@ -1,10 +1,12 @@
+from array import array
 from typing import List
 
 import pytest
 import torch
 
 from vllm.engine.arg_utils import EngineArgs
-from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
+from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams,
+                           SequenceData, SequenceGroupMetadata)
 from vllm.utils import is_cpu
 from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
 
@@ -125,10 +127,12 @@ def test_prepare_prompt(
         # make sure all tokens fit into one block
         seq_len = i % (model_runner.block_size - 1) + 1
         seq_lens.append(seq_len)
-        seq_data = SequenceData(list(range(seq_len)))
+        seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE,
+                                      range(seq_len)))
         encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
         encoder_seq_lens.append(encoder_seq_len)
-        encoder_seq_data = SequenceData(list(range(encoder_seq_len)))
+        encoder_seq_data = SequenceData(
+            array(VLLM_TOKEN_ID_ARRAY_TYPE, range(encoder_seq_len)))
         seq_group_metadata = SequenceGroupMetadata(
             request_id=f"test_{i}",
             is_prompt=True,
@@ -319,10 +323,12 @@ def test_prepare_decode(
         # make sure all tokens fit into one block
         seq_len = i % (model_runner.block_size - 1) + 1
         seq_lens.append(seq_len)
-        seq_data = SequenceData(list(range(seq_len)))
+        seq_data = SequenceData(
+            array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(seq_len))))
         encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
         encoder_seq_lens.append(encoder_seq_len)
-        encoder_seq_data = SequenceData(list(range(encoder_seq_len)))
+        encoder_seq_data = SequenceData(
+            array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(encoder_seq_len))))
         seq_group_metadata = SequenceGroupMetadata(
             request_id=f"test_{i}",
             is_prompt=False,
