|
15 | 15 | from vllm.distributed import get_pp_group
|
16 | 16 | from vllm.inputs import INPUT_REGISTRY, InputRegistry
|
17 | 17 | from vllm.logger import init_logger
|
| 18 | +from vllm.model_executor import SamplingMetadataCache |
18 | 19 | from vllm.model_executor.layers.sampler import SamplerOutput
|
19 | 20 | from vllm.model_executor.model_loader import get_model
|
20 | 21 | from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
|
@@ -136,7 +137,7 @@ def build(self) -> ModelInputForXPU:
|
136 | 137 | (input_tokens, input_positions,
|
137 | 138 | attn_metadata) = self._prepare_decode(
|
138 | 139 | self.seq_group_metadata_list)
|
139 |
| - seq_lens = [] |
| 140 | + seq_lens = None |
140 | 141 | multi_modal_kwargs = None
|
141 | 142 |
|
142 | 143 | return self.model_input_cls(
|
@@ -406,6 +407,10 @@ def __init__(
|
406 | 407 | # Lazy initialization.
|
407 | 408 | self.model: nn.Module # Set after init_Model
|
408 | 409 |
|
| 410 | + self.sampling_metadata_cache: SamplingMetadataCache = \ |
| 411 | + SamplingMetadataCache() \ |
| 412 | + if self.parallel_config.pipeline_parallel_size == 1 else None |
| 413 | + |
409 | 414 | def load_model(self) -> None:
|
410 | 415 | with DeviceMemoryProfiler() as m:
|
411 | 416 | self.model = get_model(
|
@@ -540,12 +545,14 @@ def prepare_model_input(
|
540 | 545 | seq_group_metadata_list, finished_requests_ids)
|
541 | 546 | # Sampling metadata is only required for the final pp group
|
542 | 547 | generators = self.get_generators(finished_requests_ids)
|
543 |
| - sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list, |
544 |
| - model_input.seq_lens, |
545 |
| - model_input.query_lens, |
546 |
| - self.device, |
547 |
| - pin_memory=False, |
548 |
| - generators=generators) |
| 548 | + sampling_metadata = SamplingMetadata.prepare( |
| 549 | + seq_group_metadata_list, |
| 550 | + model_input.seq_lens, |
| 551 | + model_input.query_lens, |
| 552 | + self.device, |
| 553 | + pin_memory=False, |
| 554 | + generators=generators, |
| 555 | + cache=self.sampling_metadata_cache) |
549 | 556 |
|
550 | 557 | return dataclasses.replace(model_input,
|
551 | 558 | sampling_metadata=sampling_metadata,
|
|
0 commit comments