Commit ce1907f

update model runner
1 parent abd1a65

1 file changed: vllm/worker/model_runner.py

Lines changed: 4 additions & 13 deletions
@@ -14,10 +14,10 @@
 import torch.distributed
 import torch.nn as nn
 
-import vllm.envs as envs
 from vllm.attention import AttentionMetadata, get_attn_backend
 from vllm.attention.backends.abstract import AttentionState
 from vllm.attention.backends.utils import CommonAttentionState
+from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          ModelConfig, ObservabilityConfig, ParallelConfig,
                          PromptAdapterConfig, SchedulerConfig)
@@ -47,8 +47,7 @@
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
 from vllm.utils import (DeviceMemoryProfiler, PyObjectCache, async_tensor_h2d,
-                        flatten_2d_lists, is_hip, is_pin_memory_available,
-                        supports_dynamo)
+                        flatten_2d_lists, is_hip, is_pin_memory_available)
 from vllm.worker.model_runner_base import (
     ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
     _add_attn_metadata_broadcastable_dict,
@@ -1125,15 +1124,6 @@ def load_model(self) -> None:
             "provided. Defaulting to scaling factors of 1.0. "
             "This may lead to less accurate results!")
 
-        if envs.VLLM_TEST_DYNAMO_GRAPH_CAPTURE and supports_dynamo():
-            from vllm.compilation.backends import vllm_backend
-            from vllm.plugins import get_torch_compile_backend
-            backend = get_torch_compile_backend() or vllm_backend
-            self.model = torch.compile(
-                self.model,
-                fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
-                backend=backend)
-
     def save_sharded_state(
         self,
         path: str,
@@ -1426,7 +1416,8 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
         batch_size_capture_list = [
             bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
         ]
-
+        if isinstance(self.model, TorchCompileWrapperWithCustomDispatcher):
+            self.model.set_sizes_to_specialize(batch_size_capture_list)
         with self.attn_state.graph_capture(
                 max_batch_size), graph_capture() as graph_capture_context:
             # NOTE: Capturing the largest batch size first may help reduce the
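
Net effect of the diff: the env-gated torch.compile call is dropped from load_model, and capture_model instead hands the CUDA-graph batch sizes to the model when it is wrapped in TorchCompileWrapperWithCustomDispatcher. Below is a minimal sketch of that size-based dispatch idea; the wrapper class name, its fallback-to-eager behavior, and its internals are illustrative assumptions, not the actual vllm.compilation.wrapper implementation. Only the set_sizes_to_specialize entry point comes from the diff above.

# Sketch only: NOT the real TorchCompileWrapperWithCustomDispatcher.
# Assumes a plain torch.compile backend and an eager fallback for
# batch sizes that were not registered for specialization.
import torch
import torch.nn as nn


class SizeSpecializingWrapper:

    def __init__(self, model: nn.Module):
        self.model = model
        self.compiled_model = torch.compile(model)
        self.sizes_to_specialize = set()

    def set_sizes_to_specialize(self, sizes) -> None:
        # Record the batch sizes that capture_model() will capture as
        # CUDA graphs; only these shapes take the compiled path.
        self.sizes_to_specialize = set(sizes)

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        # Dispatch on batch size: registered sizes hit the compiled
        # model, everything else runs eagerly to avoid recompilation.
        if x.shape[0] in self.sizes_to_specialize:
            return self.compiled_model(x)
        return self.model(x)

With a wrapper of this shape, the isinstance check added in capture_model is a cheap, opt-in hook: runners whose model is not wrapped skip the call entirely, so the capture path is unchanged when compilation is disabled.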
