Skip to content

Commit d5708b9

Browse files
Author: rshaw@neuralmagic.com
Committed: "embedding model runner"
1 parent 64a429d commit d5708b9

File tree

1 file changed: +20 -50 lines changed

1 file changed

+20
-50
lines changed

vllm/worker/embedding_model_runner.py

Lines changed: 20 additions & 50 deletions
Original file line number | Diff line number | Diff line change
@@ -1,45 +1,36 @@
11
import dataclasses
2-
from typing import Any, Dict, List, Optional, Tuple, Type, cast
2+
from typing import Any, Dict, List, Optional, Tuple, Type
33

44
import torch
55

6-
from vllm.attention.backends.abstract import AttentionBackend
76
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
87
ModelConfig, ObservabilityConfig, ParallelConfig,
98
PromptAdapterConfig, SchedulerConfig)
9+
from vllm.forward_context import set_forward_context
1010
from vllm.logger import init_logger
1111
from vllm.model_executor.pooling_metadata import PoolingMetadata
1212
from vllm.multimodal import MultiModalInputs
1313
from vllm.pooling_params import PoolingParams
1414
from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData,
1515
SequenceGroupMetadata)
16-
from vllm.worker.enc_dec_model_runner import (EncoderDecoderModelInput,
17-
EncoderDecoderModelRunnerBase)
18-
from vllm.worker.model_runner import ModelInputForGPUBuilder
16+
from vllm.worker.model_runner import (GPUModelRunnerBase, ModelInputForGPU,
17+
ModelInputForGPUBuilder)
1918

2019
logger = init_logger(__name__)
2120

2221

2322
@dataclasses.dataclass(frozen=True)
24-
class EmbeddingModelInput(EncoderDecoderModelInput):
23+
class ModelInputForGPUWithPoolingMetadata(ModelInputForGPU):
2524
"""
2625
Used by the EmbeddingModelRunner.
2726
"""
2827
pooling_metadata: Optional["PoolingMetadata"] = None
2928

30-
@classmethod
31-
def from_broadcasted_tensor_dict(
32-
cls,
33-
tensor_dict: Dict[str, Any],
34-
attn_backend: Optional["AttentionBackend"] = None,
35-
) -> "EmbeddingModelInput":
36-
return cast(
37-
EmbeddingModelInput,
38-
super().from_broadcasted_tensor_dict(tensor_dict, attn_backend))
3929

40-
41-
class EmbeddingModelRunner(EncoderDecoderModelRunnerBase[EmbeddingModelInput]):
42-
_model_input_cls: Type[EmbeddingModelInput] = EmbeddingModelInput
30+
class EmbeddingModelRunner(
31+
GPUModelRunnerBase[ModelInputForGPUWithPoolingMetadata]):
32+
_model_input_cls: Type[ModelInputForGPUWithPoolingMetadata] = (
33+
ModelInputForGPUWithPoolingMetadata)
4334
_builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder
4435

4536
def __init__(
@@ -71,7 +62,7 @@ def __init__(
7162
@torch.inference_mode()
7263
def execute_model(
7364
self,
74-
model_input: EmbeddingModelInput,
65+
model_input: ModelInputForGPUWithPoolingMetadata,
7566
kv_caches: List[torch.Tensor],
7667
intermediate_tensors: Optional[IntermediateTensors] = None,
7768
num_steps: int = 1,
@@ -121,10 +112,6 @@ def execute_model(
121112
model_input.input_tokens,
122113
"positions":
123114
model_input.input_positions,
124-
"encoder_input_ids":
125-
model_input.encoder_input_tokens,
126-
"encoder_positions":
127-
model_input.encoder_input_positions,
128115
"kv_caches":
129116
kv_caches,
130117
"attn_metadata":
@@ -133,7 +120,8 @@ def execute_model(
133120
device=self.device),
134121
}
135122

136-
hidden_states = model_executable(**execute_model_kwargs)
123+
with set_forward_context(model_input.attn_metadata):
124+
hidden_states = model_executable(**execute_model_kwargs)
137125

138126
# Only perform pooling in the driver worker.
139127
if not self.is_driver_worker:
@@ -145,8 +133,10 @@ def execute_model(
145133
]
146134

147135
def make_model_input_from_broadcasted_tensor_dict(
148-
self, tensor_dict: Dict[str, Any]) -> EmbeddingModelInput:
149-
return EmbeddingModelInput.from_broadcasted_tensor_dict(
136+
self,
137+
tensor_dict: Dict[str,
138+
Any]) -> ModelInputForGPUWithPoolingMetadata:
139+
return ModelInputForGPUWithPoolingMetadata.from_broadcasted_tensor_dict(
150140
tensor_dict,
151141
attn_backend=self.attn_backend,
152142
)
@@ -156,34 +146,14 @@ def prepare_model_input(
156146
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
157147
virtual_engine: int = 0,
158148
finished_requests_ids: Optional[List[str]] = None
159-
) -> EmbeddingModelInput:
149+
) -> ModelInputForGPUWithPoolingMetadata:
160150
assert seq_group_metadata_list is not None
161151
model_input = self._prepare_model_input_tensors(
162152
seq_group_metadata_list, finished_requests_ids)
163-
164-
(
165-
attn_metadata,
166-
encoder_input_tokens_tensor,
167-
encoder_input_positions_tensor,
168-
encoder_seq_lens,
169-
) = super()._prepare_encoder_model_input_tensors(
170-
seq_group_metadata_list, model_input)
171-
172-
model_input = dataclasses.replace(
173-
model_input,
174-
attn_metadata=attn_metadata,
175-
encoder_input_tokens=encoder_input_tokens_tensor,
176-
encoder_input_positions=encoder_input_positions_tensor,
177-
)
178-
179153
# Prepare PoolingMetadata.
180-
seq_lens = model_input.seq_lens\
181-
if not self.model_config.is_encoder_model \
182-
else encoder_seq_lens
183-
assert seq_lens is not None, "model is_encoder_model: "\
184-
f"{self.model_config.is_encoder_model}"
154+
assert model_input.seq_lens is not None
185155
pooling_metadata = self._prepare_pooling(seq_group_metadata_list,
186-
seq_lens)
156+
model_input.seq_lens)
187157

188158
return dataclasses.replace(model_input,
189159
pooling_metadata=pooling_metadata)
@@ -195,7 +165,7 @@ def _prepare_pooling(
195165
) -> PoolingMetadata:
196166
"""Prepare PoolingMetadata for the sequence group metadata list."""
197167
seq_groups: List[Tuple[List[int], PoolingParams]] = []
198-
for seq_group_metadata in seq_group_metadata_list:
168+
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
199169
seq_ids = list(seq_group_metadata.seq_data.keys())
200170
pooling_params = seq_group_metadata.pooling_params
201171
seq_groups.append((seq_ids, pooling_params))

0 commit comments

Comments
 (0)