@@ -12,11 +12,13 @@
                          SchedulerConfig)
 from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadata
+from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.model_loader import get_model
 from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
                              MultiModalInputs)
-from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
+from vllm.sequence import (IntermediateTensors, SequenceData,
+                           SequenceGroupMetadata)
 from vllm.utils import STR_NOT_IMPL_ENC_DEC_ERR_STRS, make_tensor_with_pad
 from vllm.worker.model_runner_base import (
     ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
@@ -145,6 +147,38 @@ def build(self) -> ModelInputForCPU:
             query_lens=seq_lens,
         )

+    def _compute_multi_modal_input(self, seq_data: SequenceData, mm_data,
+                                   computed_len: int):
+        mm_kwargs = self.multi_modal_input_mapper(mm_data)
+
+        # Special processing for mrope position deltas.
+        mrope_positions = None
+        if self.runner.model_is_mrope:
+            image_grid_thw = mm_kwargs.get("image_grid_thw", None)
+            video_grid_thw = mm_kwargs.get("video_grid_thw", None)
+            assert image_grid_thw is not None or video_grid_thw is not None, (
+                "mrope embedding type requires the multi-modal input mapper "
+                "to return 'image_grid_thw' or 'video_grid_thw'.")
+
+            hf_config = self.runner.model_config.hf_config
+            token_ids = seq_data.get_token_ids()
+
+            mrope_positions, mrope_position_delta = \
+                MRotaryEmbedding.get_input_positions(
+                    token_ids,
+                    image_grid_thw=image_grid_thw,
+                    video_grid_thw=video_grid_thw,
+                    image_token_id=hf_config.image_token_id,
+                    video_token_id=hf_config.video_token_id,
+                    vision_start_token_id=hf_config.vision_start_token_id,
+                    vision_end_token_id=hf_config.vision_end_token_id,
+                    spatial_merge_size=hf_config.vision_config.
+                    spatial_merge_size,
+                    context_len=computed_len,
+                )
+            seq_data.mrope_position_delta = mrope_position_delta
+        return mm_kwargs, mrope_positions
+
     def _prepare_prompt(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
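For context, the `(mrope_positions, mrope_position_delta)` pair returned above follows M-RoPE's convention of one position row per temporal/height/width channel plus a scalar offset that is stashed on `seq_data` for reuse at decode time. A toy, text-only sketch of that contract (illustrative code, not the vLLM implementation):

```python
# Toy stand-in for the (positions, delta) contract of get_input_positions.
# With no vision tokens, all three channels are identical and the delta is
# zero, so M-RoPE degenerates to ordinary RoPE.
from typing import List, Tuple

def toy_mrope_positions(seq_len: int) -> Tuple[List[List[int]], int]:
    positions = [list(range(seq_len)) for _ in range(3)]  # t/h/w channels
    delta = positions[0][-1] + 1 - seq_len  # max position + 1 - num tokens
    return positions, delta

positions, delta = toy_mrope_positions(4)
assert positions == [[0, 1, 2, 3]] * 3 and delta == 0
```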
@@ -155,6 +189,8 @@ def _prepare_prompt(
         input_positions: List[int] = []
         # The number of original input tokens of each sequence
         num_orig_input_tokens_list: List[int] = []
+        input_mrope_positions: List[List[int]] = [[] for _ in range(3)]
+
         slot_mapping: List[int] = []
         seq_lens: List[int] = []
         multi_modal_inputs_list: List[MultiModalInputs] = []
@@ -173,17 +209,23 @@ def _prepare_prompt(
             seq_lens.append(seq_len)  # Prompt token num
             input_tokens.extend(prompt_tokens)  # Token ids

+            mrope_positions = None
+            if (mm_data := seq_group_metadata.multi_modal_data):
+                mm_kwargs, mrope_positions = self._compute_multi_modal_input(
+                    seq_data, mm_data, computed_len)
+                multi_modal_inputs_list.append(mm_kwargs)
+
             # Token position ids
             # NOTE(woosuk): Here we assume that the first token in the prompt
             # is always the first token in the sequence.
-            input_positions.extend(list(range(computed_len, seq_len)))
+            if mrope_positions:
+                for idx in range(3):
+                    input_mrope_positions[idx].extend(mrope_positions[idx])
+            else:
+                input_positions.extend(list(range(computed_len, seq_len)))
             num_orig_input_tokens_list.extend([seq_data.get_prompt_len()] *
                                               (seq_len - computed_len))

-            if (mm_data := seq_group_metadata.multi_modal_data):
-                mm_kwargs = self.multi_modal_input_mapper(mm_data)
-                multi_modal_inputs_list.append(mm_kwargs)
-
             # Compute the slot mapping.
             block_table = seq_group_metadata.block_tables[seq_id]
             # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
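The routing above leaves exactly one of the two accumulators populated: `_compute_multi_modal_input` returns `mrope_positions = None` for non-mrope models (even multimodal ones), so those still take the flat-positions branch. A toy trace of one loop iteration per style (illustrative only):

```python
# Toy trace of the two accumulators the branches above fill.
input_positions: list = []
input_mrope_positions: list = [[], [], []]

# Non-mrope sequence, 3 prompt tokens: flat positions.
computed_len, seq_len = 0, 3
input_positions.extend(range(computed_len, seq_len))

# mrope sequence: the mapper produced per-channel positions.
for idx, chan in enumerate([[0, 1], [0, 1], [0, 1]]):
    input_mrope_positions[idx].extend(chan)

# In practice a given model fills only one of the two lists; the
# any()/None swap after the loop assumes the styles are not mixed.
assert input_positions == [0, 1, 2] and any(input_mrope_positions)
```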
@@ -206,12 +248,18 @@ def _prepare_prompt(
                 slot = block_number * self.block_size + block_offset
                 slot_mapping.append(slot)

+        if any(input_mrope_positions):
+            input_positions = None  # type: ignore
+        else:
+            input_mrope_positions = None  # type: ignore
+
         num_prompt_tokens = len(input_tokens)

         input_tokens = torch.tensor(input_tokens,
                                     dtype=torch.long,
                                     device=self.device)  # type: ignore
-        input_positions = torch.tensor(input_positions,
+        input_positions = torch.tensor(input_positions
+                                       or input_mrope_positions,
                                        dtype=torch.long,
                                        device=self.device)  # type: ignore
         num_orig_input_tokens_tensor = torch.tensor(
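Because the swap above nulls out whichever list is unused, `input_positions or input_mrope_positions` picks the populated one, and `torch.tensor` yields either a 1-D `[num_tokens]` tensor or a 2-D `[3, num_tokens]` tensor. A quick shape check (assumes only that `torch` is available):

```python
import torch

flat = [0, 1, 2, 3]            # standard RoPE: one position per token
mrope = [[0, 1, 2, 3]] * 3     # M-RoPE: one row per t/h/w channel

assert torch.tensor(flat, dtype=torch.long).shape == (4,)
assert torch.tensor(mrope, dtype=torch.long).shape == (3, 4)
```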
@@ -248,6 +296,7 @@ def _prepare_decode(
         input_positions: List[int] = []
         # The number of original input tokens of each sequence
         num_orig_input_tokens_list: List[int] = []
+        input_mrope_positions: List[List[int]] = [[] for _ in range(3)]
         slot_mapping: List[int] = []
         seq_lens: List[int] = []
         block_tables: List[List[int]] = []
@@ -265,8 +314,18 @@ def _prepare_decode(

             seq_len = seq_data.get_len()
             position = seq_len - 1
-            input_positions.append(position)
+            if seq_data.mrope_position_delta is not None:
+                context_len = seq_data.get_num_computed_tokens()
+                next_pos = MRotaryEmbedding.get_next_input_positions(
+                    seq_data.mrope_position_delta,
+                    context_len,
+                    seq_len,
+                )
+                for idx in range(3):
+                    input_mrope_positions[idx].extend(next_pos[idx])
+            else:
+                input_positions.append(position)
             num_orig_input_tokens_list.append(seq_data.get_prompt_len())

             seq_len = seq_len if self.sliding_window is None else min(
                 seq_len, self.sliding_window)
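At decode time only the saved delta is needed: each channel advances linearly from `context_len + delta` to `seq_len + delta`. A sketch of that contract, assuming `get_next_input_positions` behaves like upstream vLLM's (identical ranges in all three channels):

```python
# Sketch of the decode-time position contract used above.
def next_mrope_positions(delta: int, context_len: int, seq_len: int):
    return [list(range(context_len + delta, seq_len + delta))
            for _ in range(3)]

# One new token at seq_len=10 with delta=-4 lands at position 5
# in every channel.
assert next_mrope_positions(-4, 9, 10) == [[5], [5], [5]]
```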
@@ -284,12 +343,18 @@ def _prepare_decode(
                 block_table = block_table[-sliding_window_blocks:]
             block_tables.append(block_table)

+        if any(input_mrope_positions):
+            input_positions = None  # type: ignore
+        else:
+            input_mrope_positions = None  # type: ignore
+
         max_decode_seq_len = max(seq_lens)

         input_tokens = torch.tensor(input_tokens,
                                     dtype=torch.long,
                                     device=self.device)
-        input_positions = torch.tensor(input_positions,
+        input_positions = torch.tensor(input_positions
+                                       or input_mrope_positions,
                                        dtype=torch.long,
                                        device=self.device)
         num_orig_input_tokens_tensor = torch.tensor(
@@ -388,6 +453,15 @@ def __init__(
             raise NotImplementedError(
                 STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_CPU'])

+    @property
+    def model_is_mrope(self) -> bool:
+        """Detect if the model has the "mrope" rope_scaling type.
+        mrope requires keeping "rope_deltas" between prompt and decode."""
+        rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {})
+        if rope_scaling is None:
+            return False
+        return rope_scaling.get("type", None) == "mrope"
+
     def load_model(self) -> None:
         self.model = get_model(model_config=self.model_config,
                                load_config=self.load_config,
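The property keys off the HF config's `rope_scaling` dict; Qwen2-VL, for example, ships `{"type": "mrope", ...}` in its config.json (the `mrope_section` values below are illustrative). A minimal check of the detection logic:

```python
# Minimal check of the rope_scaling detection in model_is_mrope.
qwen2_vl_like = {"type": "mrope", "mrope_section": [16, 24, 24]}
text_only = None  # many text-only models have no rope_scaling at all

def is_mrope(rope_scaling) -> bool:
    if rope_scaling is None:
        return False
    return rope_scaling.get("type", None) == "mrope"

assert is_mrope(qwen2_vl_like) and not is_mrope(text_only)
```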