@@ -1,45 +1,36 @@
import dataclasses
-from typing import Any, Dict, List, Optional, Tuple, Type, cast
+from typing import Any, Dict, List, Optional, Tuple, Type

import torch

-from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                         ModelConfig, ObservabilityConfig, ParallelConfig,
                         PromptAdapterConfig, SchedulerConfig)
+from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.multimodal import MultiModalInputs
from vllm.pooling_params import PoolingParams
from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData,
                           SequenceGroupMetadata)
-from vllm.worker.enc_dec_model_runner import (EncoderDecoderModelInput,
-                                              EncoderDecoderModelRunnerBase)
-from vllm.worker.model_runner import ModelInputForGPUBuilder
+from vllm.worker.model_runner import (GPUModelRunnerBase, ModelInputForGPU,
+                                      ModelInputForGPUBuilder)

logger = init_logger(__name__)


@dataclasses.dataclass(frozen=True)
-class EmbeddingModelInput(EncoderDecoderModelInput):
+class ModelInputForGPUWithPoolingMetadata(ModelInputForGPU):
    """
    Used by the EmbeddingModelRunner.
    """
    pooling_metadata: Optional["PoolingMetadata"] = None

-    @classmethod
-    def from_broadcasted_tensor_dict(
-        cls,
-        tensor_dict: Dict[str, Any],
-        attn_backend: Optional["AttentionBackend"] = None,
-    ) -> "EmbeddingModelInput":
-        return cast(
-            EmbeddingModelInput,
-            super().from_broadcasted_tensor_dict(tensor_dict, attn_backend))

-
-class EmbeddingModelRunner(EncoderDecoderModelRunnerBase[EmbeddingModelInput]):
-    _model_input_cls: Type[EmbeddingModelInput] = EmbeddingModelInput
+class EmbeddingModelRunner(
+        GPUModelRunnerBase[ModelInputForGPUWithPoolingMetadata]):
+    _model_input_cls: Type[ModelInputForGPUWithPoolingMetadata] = (
+        ModelInputForGPUWithPoolingMetadata)
    _builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder

    def __init__(
@@ -71,7 +62,7 @@ def __init__(
    @torch.inference_mode()
    def execute_model(
        self,
-        model_input: EmbeddingModelInput,
+        model_input: ModelInputForGPUWithPoolingMetadata,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
@@ -121,10 +112,6 @@ def execute_model(
            model_input.input_tokens,
            "positions":
            model_input.input_positions,
-            "encoder_input_ids":
-            model_input.encoder_input_tokens,
-            "encoder_positions":
-            model_input.encoder_input_positions,
            "kv_caches":
            kv_caches,
            "attn_metadata":
@@ -133,7 +120,8 @@ def execute_model(
                                         device=self.device),
        }

-        hidden_states = model_executable(**execute_model_kwargs)
+        with set_forward_context(model_input.attn_metadata):
+            hidden_states = model_executable(**execute_model_kwargs)

        # Only perform pooling in the driver worker.
        if not self.is_driver_worker:
@@ -145,8 +133,10 @@ def execute_model(
        ]

    def make_model_input_from_broadcasted_tensor_dict(
-            self, tensor_dict: Dict[str, Any]) -> EmbeddingModelInput:
-        return EmbeddingModelInput.from_broadcasted_tensor_dict(
+            self,
+            tensor_dict: Dict[str,
+                              Any]) -> ModelInputForGPUWithPoolingMetadata:
+        return ModelInputForGPUWithPoolingMetadata.from_broadcasted_tensor_dict(
            tensor_dict,
            attn_backend=self.attn_backend,
        )
@@ -156,34 +146,14 @@ def prepare_model_input(
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
-    ) -> EmbeddingModelInput:
+    ) -> ModelInputForGPUWithPoolingMetadata:
        assert seq_group_metadata_list is not None
        model_input = self._prepare_model_input_tensors(
            seq_group_metadata_list, finished_requests_ids)
-
-        (
-            attn_metadata,
-            encoder_input_tokens_tensor,
-            encoder_input_positions_tensor,
-            encoder_seq_lens,
-        ) = super()._prepare_encoder_model_input_tensors(
-            seq_group_metadata_list, model_input)
-
-        model_input = dataclasses.replace(
-            model_input,
-            attn_metadata=attn_metadata,
-            encoder_input_tokens=encoder_input_tokens_tensor,
-            encoder_input_positions=encoder_input_positions_tensor,
-        )
-
        # Prepare PoolingMetadata.
-        seq_lens = model_input.seq_lens \
-            if not self.model_config.is_encoder_model \
-            else encoder_seq_lens
-        assert seq_lens is not None, "model is_encoder_model: " \
-            f"{self.model_config.is_encoder_model}"
+        assert model_input.seq_lens is not None
        pooling_metadata = self._prepare_pooling(seq_group_metadata_list,
-                                                 seq_lens)
+                                                 model_input.seq_lens)

        return dataclasses.replace(model_input,
                                   pooling_metadata=pooling_metadata)
@@ -195,7 +165,7 @@ def _prepare_pooling(
    ) -> PoolingMetadata:
        """Prepare PoolingMetadata for the sequence group metadata list."""
        seq_groups: List[Tuple[List[int], PoolingParams]] = []
-        for seq_group_metadata in seq_group_metadata_list:
+        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
            seq_ids = list(seq_group_metadata.seq_data.keys())
            pooling_params = seq_group_metadata.pooling_params
            seq_groups.append((seq_ids, pooling_params))
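
Note on the set_forward_context change above: instead of threading attention metadata through keyword arguments alone, the forward pass is now wrapped in a context manager so that code running inside the pass can look the metadata up. The sketch below only illustrates that pattern; the real implementation lives in vllm/forward_context.py, which is not part of this diff, and the _forward_context and get_forward_context names here are assumptions made for illustration.

from contextlib import contextmanager
from typing import Any, Optional

# Hypothetical module-level slot; the actual vllm/forward_context.py may differ.
_forward_context: Optional[Any] = None

@contextmanager
def set_forward_context(attn_metadata: Any):
    """Expose attn_metadata to code running inside one forward pass."""
    global _forward_context
    prev = _forward_context
    _forward_context = attn_metadata
    try:
        yield
    finally:
        # Restore the previous value even if the forward pass raises.
        _forward_context = prev

def get_forward_context() -> Optional[Any]:
    """Read whatever the enclosing set_forward_context installed."""
    return _forward_context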