[Core] Optimize SPMD architecture with delta + serialization optimization #7109

Merged
48 commits merged on Aug 19, 2024

Changes from 21 commits

Commits (48)
d41f4c5
wip
rkooo567 Jul 25, 2024
5741a83
fix original arch issue
rkooo567 Jul 25, 2024
d31d73f
should work now.
rkooo567 Jul 25, 2024
36e786d
working
rkooo567 Jul 25, 2024
71e40c1
.
rkooo567 Jul 25, 2024
7e69242
pickle
rkooo567 Jul 25, 2024
0de9f23
msgpack optimization
rkooo567 Jul 27, 2024
64faf75
Merge branch 'main' into serialization-opt
rkooo567 Jul 29, 2024
de4e43e
ip
rkooo567 Jul 30, 2024
dc7c445
.
rkooo567 Jul 30, 2024
700e4a3
Merge branch 'main' into serialization-opt
rkooo567 Jul 30, 2024
a906a9d
msgspec migration done
rkooo567 Jul 31, 2024
4af6699
ip. preemption and chunked prefill not working yet.
rkooo567 Aug 1, 2024
1e6196b
working e2e
rkooo567 Aug 3, 2024
0ea6e41
Merge branch 'main-before-server' into spmd-and-pp
rkooo567 Aug 3, 2024
35e9637
working finally
rkooo567 Aug 3, 2024
912b88b
.
rkooo567 Aug 5, 2024
5bab192
working
rkooo567 Aug 5, 2024
eb2cb14
working
rkooo567 Aug 5, 2024
007fe86
fix a test failure.
rkooo567 Aug 5, 2024
ce64b8d
.
rkooo567 Aug 7, 2024
e8e29e1
fixed
rkooo567 Aug 10, 2024
751bdb1
addressed code review.
rkooo567 Aug 12, 2024
d91aa78
lint
rkooo567 Aug 12, 2024
06774d1
Merge branch 'main' into spmd-and-pp
rkooo567 Aug 12, 2024
1af8dc2
ip
rkooo567 Aug 12, 2024
6e6ac92
all working
rkooo567 Aug 12, 2024
fa0d077
lint
rkooo567 Aug 12, 2024
b5a88ec
done
rkooo567 Aug 12, 2024
d2e14ca
code review.
rkooo567 Aug 12, 2024
8be3c8e
addressed code review.
rkooo567 Aug 13, 2024
c42c6c5
Merge branch 'main' into spmd-and-pp
rkooo567 Aug 13, 2024
c55c8f6
Merge branch 'main' into spmd-and-pp
rkooo567 Aug 13, 2024
2ba99e2
lint fix
rkooo567 Aug 13, 2024
e2c850b
Merge branch 'main' into spmd-and-pp
rkooo567 Aug 14, 2024
41ec6d1
Merge branch 'main' into spmd-and-pp
rkooo567 Aug 14, 2024
925c928
fix lint
rkooo567 Aug 14, 2024
9d3dee5
Merge branch 'main' into spmd-and-pp
rkooo567 Aug 15, 2024
d041e9c
Addressed code review.
rkooo567 Aug 15, 2024
c4b3682
Merge branch 'main' into spmd-and-pp
rkooo567 Aug 16, 2024
f938e00
fix pydantic not compatible to msggspec.Struct.
rkooo567 Aug 17, 2024
32cb984
addressed
rkooo567 Aug 17, 2024
5a4f27e
Merge branch 'main' into spmd-and-pp
rkooo567 Aug 17, 2024
c921877
fixed
rkooo567 Aug 17, 2024
ae1fb21
temporarily use dataclass
rkooo567 Aug 17, 2024
c3abcc5
Merge branch 'main' into spmd-and-pp
rkooo567 Aug 17, 2024
3e1325e
Addressed code review.
rkooo567 Aug 18, 2024
652c258
lint
rkooo567 Aug 18, 2024
2 changes: 2 additions & 0 deletions .buildkite/test-pipeline.yaml
@@ -87,6 +87,7 @@ steps:
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray VLLM_USE_RAY_SPMD_WORKER=1 VLLM_USE_RAY_COMPILED_DAG=1 pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray VLLM_USE_RAY_SPMD_WORKER=1 VLLM_USE_RAY_COMPILED_DAG=1 pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray VLLM_USE_RAY_SPMD_WORKER=1 VLLM_USE_RAY_COMPILED_DAG=1 pytest -v -s distributed/test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=llava-hf/llava-1.5-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_multimodal_broadcast.py
- TEST_DIST_MODEL=llava-hf/llava-v1.6-mistral-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_multimodal_broadcast.py
@@ -97,6 +98,7 @@ steps:
- TEST_DIST_MODEL=llava-hf/llava-1.5-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_multimodal_broadcast.py
- TEST_DIST_MODEL=llava-hf/llava-v1.6-mistral-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_multimodal_broadcast.py
- pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
- DISTRIBUTED_EXECUTOR_BACKEND=ray VLLM_USE_RAY_SPMD_WORKER=1 VLLM_USE_RAY_COMPILED_DAG=1 pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py

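The new pipeline entries exercise the Ray SPMD worker together with the compiled-DAG path purely through environment variables. As a hedged sketch of what they toggle (the opt-125m model and TP=2 mirror the CI setup; the prompt and sampling parameters are illustrative):

```python
# Sketch only: approximates what the new CI entries enable. The env vars are
# read by vllm.envs, so they must be set before the engine spins up its Ray
# workers.
import os

os.environ["VLLM_USE_RAY_SPMD_WORKER"] = "1"
os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "1"

from vllm import LLM, SamplingParams

llm = LLM(
    model="facebook/opt-125m",
    tensor_parallel_size=2,
    distributed_executor_backend="ray",  # SPMD worker requires the Ray backend
)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)
```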
1 change: 1 addition & 0 deletions requirements-common.txt
@@ -22,3 +22,4 @@ outlines >= 0.0.43, < 0.1 # Requires torch >= 2.1.0
typing_extensions
filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
pyzmq
msgspec
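msgspec is added for the serialization optimization: scheduler-to-worker messages become msgspec.Struct types that encode to msgpack far more cheaply than pickled dataclasses. A minimal sketch of that pattern (the WorkerInput type and its fields are hypothetical, not vLLM's real schema):

```python
# Hypothetical illustration of the msgspec.Struct + msgpack round-trip the PR
# relies on; field names are made up for this example.
import msgspec


class WorkerInput(msgspec.Struct):
    request_id: str
    token_ids: list[int]
    num_computed_tokens: int


encoder = msgspec.msgpack.Encoder()
decoder = msgspec.msgpack.Decoder(WorkerInput)

payload = encoder.encode(
    WorkerInput(request_id="req-0", token_ids=[1, 2, 3], num_computed_tokens=0))
restored = decoder.decode(payload)  # typed round-trip, no pickle involved
assert restored.token_ids == [1, 2, 3]
```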
18 changes: 18 additions & 0 deletions tests/basic_correctness/test_preemption.py
@@ -8,6 +8,7 @@
import pytest
from prometheus_client import REGISTRY

import vllm.envs as envs
from vllm import SamplingParams
from vllm.core.scheduler import (ARTIFICIAL_PREEMPTION_MAX_CNT,
ENABLE_ARTIFICIAL_PREEMPT)
@@ -24,6 +25,13 @@
"tests/basic_correctness/test_preemption.py`")


@pytest.fixture
def worker_use_ray() -> bool:
# When the SPMD worker is used, use worker_use_ray=True
# to test delta input optimization works with preemption.
return envs.VLLM_USE_RAY_SPMD_WORKER


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [96])
@@ -36,6 +44,7 @@ def test_chunked_prefill_recompute(
dtype: str,
max_tokens: int,
chunked_prefill_token_size: int,
worker_use_ray: bool,
) -> None:
"""Ensure that chunked prefill works with preemption."""
max_num_seqs = min(chunked_prefill_token_size, 256)
@@ -54,6 +63,7 @@ def test_chunked_prefill_recompute(
max_num_batched_tokens=max_num_batched_tokens,
enable_chunked_prefill=enable_chunked_prefill,
max_num_seqs=max_num_seqs,
worker_use_ray=worker_use_ray,
) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
@@ -79,6 +89,7 @@ def test_preemption(
model: str,
dtype: str,
max_tokens: int,
worker_use_ray: bool,
) -> None:
"""By default, recompute preemption is enabled"""

@@ -89,6 +100,7 @@ def test_preemption(
model,
dtype=dtype,
disable_log_stats=False,
worker_use_ray=worker_use_ray,
) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
@@ -132,6 +144,7 @@ def test_swap(
dtype: str,
max_tokens: int,
beam_width: int,
worker_use_ray: bool,
) -> None:
"""Use beam search enables swapping."""
example_prompts = example_prompts[:1]
@@ -144,6 +157,7 @@ def test_swap(
dtype=dtype,
swap_space=10,
disable_log_stats=False,
worker_use_ray=worker_use_ray,
) as vllm_model:
vllm_outputs = vllm_model.generate_beam_search(example_prompts,
beam_width, max_tokens)
@@ -188,6 +202,7 @@ def test_swap_infeasible(
dtype: str,
max_tokens: int,
beam_width: int,
worker_use_ray: bool,
) -> None:
"""Verify infeasible swap request will be ignored."""
BLOCK_SIZE = 16
@@ -204,6 +219,7 @@ def test_swap_infeasible(
# decode blocks are not enough to finish.
num_gpu_blocks_override=prefill_blocks + decode_blocks,
max_model_len=(prefill_blocks + decode_blocks) * BLOCK_SIZE,
worker_use_ray=worker_use_ray,
) as vllm_model:
sampling_params = SamplingParams(n=beam_width,
use_beam_search=True,
@@ -230,6 +246,7 @@ def test_preemption_infeasible(
model: str,
dtype: str,
max_tokens: int,
worker_use_ray: bool,
) -> None:
"""Verify infeasible preemption request will be ignored."""
BLOCK_SIZE = 16
@@ -244,6 +261,7 @@ def test_preemption_infeasible(
# ignored instead of hanging forever.
num_gpu_blocks_override=prefill_blocks + decode_blocks // 2,
max_model_len=((prefill_blocks + decode_blocks // 2) * BLOCK_SIZE),
worker_use_ray=worker_use_ray,
) as vllm_model:
sampling_params = SamplingParams(max_tokens=max_tokens,
ignore_eos=True)
41 changes: 21 additions & 20 deletions tests/distributed/test_pipeline_parallel.py
@@ -16,28 +16,29 @@

@pytest.mark.parametrize(
("TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, "
"MODEL_NAME, DIST_BACKEND, USE_RAY_ADAG, USE_RAY_ADAG_NCCL"), [
(2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray", False, False),
(2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray", False, False),
(1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray", False, False),
(1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray", False, False),
(1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray", False, False),
(2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray", True, False),
(2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray", True, False),
(1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray", True, False),
(1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray", True, False),
(1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray", True, False),
"MODEL_NAME, DIST_BACKEND, USE_RAY_ADAG, USE_RAY_ADAG_NCCL"),
[
# (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray", False, False),
# (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray", False, False),
# (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray", False, False),
# (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray", False, False),
# (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray", False, False),
# (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray", True, False),
# (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray", True, False),
# (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray", True, False),
# (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray", True, False),
# (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray", True, False),
(2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray", True, True),
(2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray", True, True),
(1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray", True, True),
# (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray", True, True),
# (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray", True, True),
(1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray", True, True),
(1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray", True, True),
(2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp", False, False),
(2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp", False, False),
(1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp", False, False),
(1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp", False, False),
(1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp", False, False),
])
# (1, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray", True, True),
# (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp", False, False),
# (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp", False, False),
# (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp", False, False),
# (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp", False, False),
# (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp", False, False),
])
def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME,
DIST_BACKEND, USE_RAY_ADAG, USE_RAY_ADAG_NCCL):
if VLLM_MULTI_NODE and DIST_BACKEND == "mp":
1 change: 1 addition & 0 deletions tests/prompts/example.txt
@@ -6,3 +6,4 @@ Write a short story about a robot that dreams for the first time.
Analyze the impact of the COVID-19 pandemic on global economic structures and future business models.
Explain the cultural significance of the Mona Lisa painting, and how its perception might vary in Western versus Eastern societies.
Translate the following English sentence into Japanese, French, and Swahili: 'The early bird catches the worm.'
vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.
2 changes: 0 additions & 2 deletions vllm/adapter_commons/request.py
@@ -1,8 +1,6 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass


@dataclass
class AdapterRequest(ABC):
"""
Base class for adapter requests.
8 changes: 6 additions & 2 deletions vllm/config.py
@@ -691,8 +691,8 @@ def __init__(
self.tokenizer_pool_config = tokenizer_pool_config
self.ray_workers_use_nsight = ray_workers_use_nsight
self.placement_group = placement_group

self.world_size = pipeline_parallel_size * self.tensor_parallel_size

if worker_use_ray:
if self.distributed_executor_backend is None:
self.distributed_executor_backend = "ray"
@@ -788,6 +788,8 @@ class SchedulerConfig:
swapping. However, when the sequence group has multiple sequences
(e.g., beam search), recomputation is not currently supported. In
such a case, we use swapping instead.
_use_delta: Private API. If used, scheduler sends delta data to
workers instead of an entire data.

Collaborator:
Suggest to have a more clear naming, such as _send_delta_data or something like that.

Collaborator Author:
great suggestion!
"""

def __init__(self,
@@ -799,7 +801,8 @@ def __init__(self,
delay_factor: float = 0.0,
enable_chunked_prefill: bool = False,
embedding_mode: Optional[bool] = False,
preemption_mode: Optional[str] = None) -> None:
preemption_mode: Optional[str] = None,
_use_delta: bool = False) -> None:
if max_num_batched_tokens is not None:
self.max_num_batched_tokens = max_num_batched_tokens
else:
@@ -828,6 +831,7 @@ def __init__(self,
self.chunked_prefill_enabled = enable_chunked_prefill
self.embedding_mode = embedding_mode
self.preemption_mode = preemption_mode
self._use_delta = _use_delta
self._verify_args()

def _verify_args(self) -> None:
69 changes: 45 additions & 24 deletions vllm/core/scheduler.py
@@ -12,7 +12,8 @@
from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
SequenceGroupMetadata, SequenceStatus)
SequenceGroupMetadata, SequenceGroupMetadataDelta,
SequenceStatus)

logger = init_logger(__name__)

@@ -974,41 +975,62 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
seq_group.get_seqs(status=SequenceStatus.RUNNING)))

do_sample = True
if seq_group.is_prefill():
is_prompt = seq_group.is_prefill()
# We should send the metadata to workers when the first prefill
# is sent. Subsequent requests could be chunked prefill or decode.
is_first_prefill = False
if is_prompt:
seqs = seq_group.get_seqs()
# Prefill has only 1 sequence.
assert len(seqs) == 1
num_computed_tokens = seqs[0].data.get_num_computed_tokens()
is_first_prefill = num_computed_tokens == 0
# In the next iteration, all prompt tokens are not computed.
# It means the prefill is chunked, and we don't need sampling.
# NOTE: We use get_len instead of get_prompt_len because when
# a sequence is preempted, prefill includes previous generated
# output tokens.
if (token_chunk_size + seqs[0].data.get_num_computed_tokens() <
if (token_chunk_size + num_computed_tokens <
seqs[0].data.get_len()):
do_sample = False

# It assumes the scheduled_seq_groups is ordered by
# prefill < decoding.
Comment on lines 1097 to 1098 (Collaborator):
Not related to this change, but this comment is outdated? With chunked prefill, decode can come first?

Collaborator Author:
it is still true. the order is prefill -> decode (literal list order, not the scheduling order)

is_prompt = seq_group.is_prefill()
seq_group_metadata = SequenceGroupMetadata(
request_id=seq_group.request_id,
is_prompt=is_prompt,
seq_data=seq_data,
sampling_params=seq_group.sampling_params,
block_tables=block_tables,
do_sample=do_sample,
pooling_params=seq_group.pooling_params,
token_chunk_size=token_chunk_size,
lora_request=seq_group.lora_request,
computed_block_nums=common_computed_block_nums,
# `multi_modal_data` will only be present for the 1st comm
# between engine and worker.
# the subsequent comms can still use delta, but
# `multi_modal_data` will be None.
multi_modal_data=seq_group.multi_modal_data
if scheduler_outputs.num_prefill_groups > 0 else None,
prompt_adapter_request=seq_group.prompt_adapter_request,
)
if is_first_prefill or not self.scheduler_config._use_delta:
seq_group_metadata = SequenceGroupMetadata(
Collaborator Author:
NOTE: I removed _seq_group_metadata_cache.

  • I ran this against master and found no perf difference even without the cache. I suspect it is because SequenceGroupMetadata is now a C object thanks to msgspec.Struct.
  • __init__ doesn't work with msgspec.Struct.

python3 benchmark_throughput.py --backend vllm --input-len 512 --output-len 256 --num-prompts 1000 --tensor-parallel 1
# this commit
Throughput: 20.72 requests/s, 15915.01 tokens/s
# main
Throughput: 20.49 requests/s, 15735.44 tokens/s

Member:
cc @alexm-neuralmagic

request_id=seq_group.request_id,
is_prompt=is_prompt,
seq_data=seq_data,
sampling_params=seq_group.sampling_params,
block_tables=block_tables,
do_sample=do_sample,
pooling_params=seq_group.pooling_params,
token_chunk_size=token_chunk_size,
lora_request=seq_group.lora_request,
computed_block_nums=common_computed_block_nums,
# `multi_modal_data` will only be present for the 1st comm
# between engine and worker.
# the subsequent comms can still use delta, but
# `multi_modal_data` will be None.
multi_modal_data=seq_group.multi_modal_data
if scheduler_outputs.num_prefill_groups > 0 else None,
prompt_adapter_request=seq_group.prompt_adapter_request,
)
else:
# When SPMD mode is enabled, we only send delta data except for
# the first request to reduce serialization cost.
seq_data_delta = {}
for id, data in seq_data.items():
seq_data_delta[id] = data.get_delta()
seq_group_metadata = SequenceGroupMetadataDelta(
seq_data_delta,
seq_group.request_id,
block_tables,
is_prompt,
do_sample=do_sample,
token_chunk_size=token_chunk_size,
computed_block_nums=common_computed_block_nums,
)
seq_group_metadata_list.append(seq_group_metadata)

# Now that the batch has been created, we can assume all blocks in the
@@ -1018,7 +1040,6 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
self.block_manager.mark_blocks_as_computed(
scheduled_seq_group.seq_group)

return seq_group_metadata_list, scheduler_outputs

def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
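To make the delta branch above concrete, here is a minimal, self-contained sketch of the idea: the scheduler ships full metadata on the first prefill, and afterwards only the fields that changed, which the worker applies to its cached copy. The classes below are hypothetical stand-ins, not vLLM's SequenceData / SequenceGroupMetadataDelta:

```python
# Hedged sketch of the delta-update pattern; not the actual vLLM classes.
import msgspec


class SeqStateDelta(msgspec.Struct):
    new_output_token_ids: list[int]
    new_num_computed_tokens: int


class SeqState(msgspec.Struct):
    prompt_token_ids: list[int]
    output_token_ids: list[int]
    num_computed_tokens: int

    def apply_delta(self, delta: SeqStateDelta) -> None:
        # Only newly generated tokens and the updated counter cross the wire;
        # the rest of the state already lives on the worker.
        self.output_token_ids.extend(delta.new_output_token_ids)
        self.num_computed_tokens = delta.new_num_computed_tokens


# Worker side: full state arrives with the first prefill ...
state = SeqState(prompt_token_ids=[1, 2, 3],
                 output_token_ids=[],
                 num_computed_tokens=0)
# ... and every later step only ships a small delta.
state.apply_delta(SeqStateDelta(new_output_token_ids=[7],
                                new_num_computed_tokens=4))
assert state.output_token_ids == [7] and state.num_computed_tokens == 4
```

The per-step payload then scales with the number of new tokens rather than with the full prompt, which is what cuts the serialization cost in the SPMD path.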
4 changes: 3 additions & 1 deletion vllm/engine/arg_utils.py
@@ -4,6 +4,7 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union

import vllm.envs as envs
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
MultiModalConfig, ObservabilityConfig, ParallelConfig,
@@ -815,7 +816,8 @@ def create_engine_config(self, ) -> EngineConfig:
enable_chunked_prefill=self.enable_chunked_prefill,
embedding_mode=model_config.embedding_mode,
preemption_mode=self.preemption_mode,
)
_use_delta=(envs.VLLM_USE_RAY_SPMD_WORKER
and parallel_config.use_ray))
lora_config = LoRAConfig(
max_lora_rank=self.max_lora_rank,
max_loras=self.max_loras,
1 change: 0 additions & 1 deletion vllm/engine/llm_engine.py
@@ -217,7 +217,6 @@ def __init__(
cache_config.enable_prefix_caching,
)
# TODO(woosuk): Print more configs in debug mode.

self.model_config = model_config
self.cache_config = cache_config
self.lora_config = lora_config