  14 |   14 | import torch.distributed
  15 |   15 | import torch.nn as nn
  16 |   16 |
  17 |      | -import vllm.envs as envs
  18 |   17 | from vllm.attention import AttentionMetadata, get_attn_backend
  19 |   18 | from vllm.attention.backends.abstract import AttentionState
  20 |   19 | from vllm.attention.backends.utils import CommonAttentionState
     |   20 | +from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
  21 |   21 | from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
  22 |   22 |                          ModelConfig, ObservabilityConfig, ParallelConfig,
  23 |   23 |                          PromptAdapterConfig, SchedulerConfig)

  47 |   47 | from vllm.sampling_params import SamplingParams
  48 |   48 | from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
  49 |   49 | from vllm.utils import (DeviceMemoryProfiler, PyObjectCache, async_tensor_h2d,
  50 |      | -                        flatten_2d_lists, is_hip, is_pin_memory_available,
  51 |      | -                        supports_dynamo)
     |   50 | +                        flatten_2d_lists, is_hip, is_pin_memory_available)
  52 |   51 | from vllm.worker.model_runner_base import (
  53 |   52 |     ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
  54 |   53 |     _add_attn_metadata_broadcastable_dict,
@@ -1125,15 +1124,6 @@ def load_model(self) -> None:
1125 | 1124 |                     "provided. Defaulting to scaling factors of 1.0. "
1126 | 1125 |                     "This may lead to less accurate results!")
1127 | 1126 |
1128 |      | -        if envs.VLLM_TEST_DYNAMO_GRAPH_CAPTURE and supports_dynamo():
1129 |      | -            from vllm.compilation.backends import vllm_backend
1130 |      | -            from vllm.plugins import get_torch_compile_backend
1131 |      | -            backend = get_torch_compile_backend() or vllm_backend
1132 |      | -            self.model = torch.compile(
1133 |      | -                self.model,
1134 |      | -                fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
1135 |      | -                backend=backend)
1136 |      | -
1137 | 1127 |     def save_sharded_state(
1138 | 1128 |         self,
1139 | 1129 |         path: str,
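The hunk above removes the opt-in path where `load_model` wrapped `self.model` with `torch.compile` behind the `VLLM_TEST_DYNAMO_GRAPH_CAPTURE` environment variable; with this change, the decision to compile moves to the `TorchCompileWrapperWithCustomDispatcher` imported at the top of the file. As a rough standalone sketch of the pattern being removed, with a placeholder flag name and backend rather than vLLM's actual values:

```python
import os

import torch
import torch.nn as nn


def maybe_compile(model: nn.Module) -> nn.Module:
    # Wrap the model only when the opt-in flag is set; otherwise return
    # the eager module unchanged. Flag name and backend are placeholders.
    if os.environ.get("DEMO_DYNAMO_GRAPH_CAPTURE", "0") == "1":
        return torch.compile(model, fullgraph=False, backend="eager")
    return model


model = maybe_compile(nn.Linear(16, 16))
print(type(model).__name__, model(torch.randn(2, 16)).shape)
```

Gating compilation this way keeps the default path untouched while letting experiments opt in, which is presumably why the block lived behind a test-only environment variable in the first place.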
@@ -1426,7 +1416,8 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
1426 | 1416 |         batch_size_capture_list = [
1427 | 1417 |             bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
1428 | 1418 |         ]
1429 |      | -
     | 1419 | +        if isinstance(self.model, TorchCompileWrapperWithCustomDispatcher):
     | 1420 | +            self.model.set_sizes_to_specialize(batch_size_capture_list)
1430 | 1421 |         with self.attn_state.graph_capture(
1431 | 1422 |                 max_batch_size), graph_capture() as graph_capture_context:
1432 | 1423 |             # NOTE: Capturing the largest batch size first may help reduce the
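The `capture_model` hunk announces the CUDA graph capture batch sizes to the compile wrapper, so the wrapper can specialize on exactly those sizes and dispatch by batch size at call time. A minimal sketch of that size-based dispatch idea follows; the `set_sizes_to_specialize` name is taken from the diff, but the class and everything else here are illustrative, not vLLM's implementation:

```python
import torch
import torch.nn as nn


class SizeDispatchingWrapper(nn.Module):
    """Toy wrapper: keep one torch.compile artifact per announced batch
    size and dispatch on the runtime batch size, falling back to eager
    execution for sizes that were never announced."""

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model
        self.sizes_to_specialize = set()
        self.compiled = {}  # batch size -> compiled callable

    def set_sizes_to_specialize(self, sizes) -> None:
        # Announce the batch sizes that graph capture will replay, so each
        # one gets its own static-shape compilation on first use.
        self.sizes_to_specialize = set(sizes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        bs = x.shape[0]
        if bs in self.sizes_to_specialize:
            if bs not in self.compiled:
                # dynamic=False keeps the compiled graph shape-static,
                # matching the fixed sizes used for CUDA graph capture.
                self.compiled[bs] = torch.compile(self.model, dynamic=False)
            return self.compiled[bs](x)
        return self.model(x)


wrapper = SizeDispatchingWrapper(nn.Linear(8, 8))
wrapper.set_sizes_to_specialize([1, 2, 4, 8])
print(wrapper(torch.randn(4, 8)).shape)  # dispatches to the bs=4 specialization
```

The `isinstance` check in the diff keeps the announcement optional: a plain eager model simply skips it, while a wrapped model learns the capture sizes before graph capture begins.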