Skip to content

Commit 6531fd6

Browse files
jikunshang authored and sumitd2 committed
[Intel GPU] Fix xpu decode input (vllm-project#9145)
Signed-off-by: Sumit Dubey <sumit.dubey2@ibm.com>
1 parent 4e60a33 commit 6531fd6

File tree

1 file changed

+14
-7
lines changed

1 file changed

+14
-7
lines changed

vllm/worker/xpu_model_runner.py

Lines changed: 14 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@
1515
from vllm.distributed import get_pp_group
1616
from vllm.inputs import INPUT_REGISTRY, InputRegistry
1717
from vllm.logger import init_logger
18+
from vllm.model_executor import SamplingMetadataCache
1819
from vllm.model_executor.layers.sampler import SamplerOutput
1920
from vllm.model_executor.model_loader import get_model
2021
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
@@ -136,7 +137,7 @@ def build(self) -> ModelInputForXPU:
136137
(input_tokens, input_positions,
137138
attn_metadata) = self._prepare_decode(
138139
self.seq_group_metadata_list)
139-
seq_lens = []
140+
seq_lens = None
140141
multi_modal_kwargs = None
141142

142143
return self.model_input_cls(
@@ -390,6 +391,10 @@ def __init__(
390391
# Lazy initialization.
391392
self.model: nn.Module # Set after init_Model
392393

394+
self.sampling_metadata_cache: SamplingMetadataCache = \
395+
SamplingMetadataCache() \
396+
if self.parallel_config.pipeline_parallel_size == 1 else None
397+
393398
def load_model(self) -> None:
394399
with DeviceMemoryProfiler() as m:
395400
self.model = get_model(
@@ -524,12 +529,14 @@ def prepare_model_input(
524529
seq_group_metadata_list, finished_requests_ids)
525530
# Sampling metadata is only required for the final pp group
526531
generators = self.get_generators(finished_requests_ids)
527-
sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
528-
model_input.seq_lens,
529-
model_input.query_lens,
530-
self.device,
531-
pin_memory=False,
532-
generators=generators)
532+
sampling_metadata = SamplingMetadata.prepare(
533+
seq_group_metadata_list,
534+
model_input.seq_lens,
535+
model_input.query_lens,
536+
self.device,
537+
pin_memory=False,
538+
generators=generators,
539+
cache=self.sampling_metadata_cache)
533540

534541
return dataclasses.replace(model_input,
535542
sampling_metadata=sampling_metadata,

0 commit comments

Comments
 (0)