|
15 | 15 | from vllm.distributed import get_pp_group
|
16 | 16 | from vllm.inputs import INPUT_REGISTRY, InputRegistry
|
17 | 17 | from vllm.logger import init_logger
|
| 18 | +from vllm.model_executor import SamplingMetadataCache |
18 | 19 | from vllm.model_executor.layers.sampler import SamplerOutput
|
19 | 20 | from vllm.model_executor.model_loader import get_model
|
20 | 21 | from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
|
@@ -136,7 +137,7 @@ def build(self) -> ModelInputForXPU:
|
136 | 137 | (input_tokens, input_positions,
|
137 | 138 | attn_metadata) = self._prepare_decode(
|
138 | 139 | self.seq_group_metadata_list)
|
139 |
| - seq_lens = [] |
| 140 | + seq_lens = None |
140 | 141 | multi_modal_kwargs = None
|
141 | 142 |
|
142 | 143 | return self.model_input_cls(
|
@@ -390,6 +391,10 @@ def __init__(
|
390 | 391 | # Lazy initialization.
|
391 | 392 | self.model: nn.Module # Set in load_model
|
392 | 393 |
|
| 394 | + self.sampling_metadata_cache: SamplingMetadataCache = \ |
| 395 | + SamplingMetadataCache() \ |
| 396 | + if self.parallel_config.pipeline_parallel_size == 1 else None |
| 397 | + |
393 | 398 | def load_model(self) -> None:
|
394 | 399 | with DeviceMemoryProfiler() as m:
|
395 | 400 | self.model = get_model(
|
@@ -524,12 +529,14 @@ def prepare_model_input(
|
524 | 529 | seq_group_metadata_list, finished_requests_ids)
|
525 | 530 | # Sampling metadata is only required for the final pp group
|
526 | 531 | generators = self.get_generators(finished_requests_ids)
|
527 |
| - sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list, |
528 |
| - model_input.seq_lens, |
529 |
| - model_input.query_lens, |
530 |
| - self.device, |
531 |
| - pin_memory=False, |
532 |
| - generators=generators) |
| 532 | + sampling_metadata = SamplingMetadata.prepare( |
| 533 | + seq_group_metadata_list, |
| 534 | + model_input.seq_lens, |
| 535 | + model_input.query_lens, |
| 536 | + self.device, |
| 537 | + pin_memory=False, |
| 538 | + generators=generators, |
| 539 | + cache=self.sampling_metadata_cache) |
533 | 540 |
|
534 | 541 | return dataclasses.replace(model_input,
|
535 | 542 | sampling_metadata=sampling_metadata,
|
|
0 commit comments