Skip to content

Commit 1cd50f1

Browse files
jikunshanggarg-amit
authored andcommitted
[Intel GPU] Fix xpu decode input (vllm-project#9145)
Signed-off-by: Amit Garg <mitgarg17495@gmail.com>
1 parent 22ef180 commit 1cd50f1

File tree

1 file changed

+14
-7
lines changed

1 file changed

+14
-7
lines changed

vllm/worker/xpu_model_runner.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from vllm.distributed import get_pp_group
1616
from vllm.inputs import INPUT_REGISTRY, InputRegistry
1717
from vllm.logger import init_logger
18+
from vllm.model_executor import SamplingMetadataCache
1819
from vllm.model_executor.layers.sampler import SamplerOutput
1920
from vllm.model_executor.model_loader import get_model
2021
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
@@ -136,7 +137,7 @@ def build(self) -> ModelInputForXPU:
136137
(input_tokens, input_positions,
137138
attn_metadata) = self._prepare_decode(
138139
self.seq_group_metadata_list)
139-
seq_lens = []
140+
seq_lens = None
140141
multi_modal_kwargs = None
141142

142143
return self.model_input_cls(
@@ -406,6 +407,10 @@ def __init__(
406407
# Lazy initialization.
407408
self.model: nn.Module # Set after init_Model
408409

410+
self.sampling_metadata_cache: SamplingMetadataCache = \
411+
SamplingMetadataCache() \
412+
if self.parallel_config.pipeline_parallel_size == 1 else None
413+
409414
def load_model(self) -> None:
410415
with DeviceMemoryProfiler() as m:
411416
self.model = get_model(
@@ -540,12 +545,14 @@ def prepare_model_input(
540545
seq_group_metadata_list, finished_requests_ids)
541546
# Sampling metadata is only required for the final pp group
542547
generators = self.get_generators(finished_requests_ids)
543-
sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
544-
model_input.seq_lens,
545-
model_input.query_lens,
546-
self.device,
547-
pin_memory=False,
548-
generators=generators)
548+
sampling_metadata = SamplingMetadata.prepare(
549+
seq_group_metadata_list,
550+
model_input.seq_lens,
551+
model_input.query_lens,
552+
self.device,
553+
pin_memory=False,
554+
generators=generators,
555+
cache=self.sampling_metadata_cache)
549556

550557
return dataclasses.replace(model_input,
551558
sampling_metadata=sampling_metadata,

0 commit comments

Comments
 (0)