12
12
SchedulerConfig )
13
13
from vllm .logger import init_logger
14
14
from vllm .model_executor import SamplingMetadata
15
+ from vllm .model_executor .layers .rotary_embedding import MRotaryEmbedding
15
16
from vllm .model_executor .layers .sampler import SamplerOutput
16
17
from vllm .model_executor .model_loader import get_model
17
18
from vllm .multimodal import (MULTIMODAL_REGISTRY , BatchedTensorInputs ,
18
19
MultiModalInputs )
19
- from vllm .sequence import IntermediateTensors , SequenceGroupMetadata
20
+ from vllm .sequence import (IntermediateTensors , SequenceData ,
21
+ SequenceGroupMetadata )
20
22
from vllm .utils import STR_NOT_IMPL_ENC_DEC_ERR_STRS , make_tensor_with_pad
21
23
from vllm .worker .model_runner_base import (
22
24
ModelRunnerBase , ModelRunnerInputBase , ModelRunnerInputBuilderBase ,
@@ -145,6 +147,38 @@ def build(self) -> ModelInputForCPU:
145
147
query_lens = seq_lens ,
146
148
)
147
149
150
    def _compute_multi_modal_input(self, seq_data: SequenceData, mm_data,
                                   computed_len: int):
        """Map raw multi-modal data to model kwargs and, for mrope models,
        precompute the 3-axis rotary positions for the prompt.

        Args:
            seq_data: The sequence whose prompt is being prepared. For mrope
                models this is mutated: ``seq_data.mrope_position_delta`` is
                stored so the decode phase can continue the position stream.
            mm_data: Raw multi-modal data from the sequence group metadata
                (schema is whatever ``multi_modal_input_mapper`` accepts —
                not visible here).
            computed_len: Number of prompt tokens already computed; used as
                ``context_len`` so positions start where prefix caching left
                off.

        Returns:
            A tuple ``(mm_kwargs, mrope_positions)`` where ``mrope_positions``
            is ``None`` unless the model uses mrope rope scaling.
        """
        mm_kwargs = self.multi_modal_input_mapper(mm_data)

        # special processing for mrope position deltas.
        mrope_positions = None
        if self.runner.model_is_mrope:
            # mrope needs the vision grid sizes to lay out temporal/spatial
            # position ids; the input mapper must have produced at least one.
            image_grid_thw = mm_kwargs.get("image_grid_thw", None)
            video_grid_thw = mm_kwargs.get("video_grid_thw", None)
            assert image_grid_thw is not None or video_grid_thw is not None, (
                "mrope embedding type requires multi-modal input mapper "
                "returns 'image_grid_thw' or 'video_grid_thw'.")

            hf_config = self.runner.model_config.hf_config
            token_ids = seq_data.get_token_ids()

            mrope_positions, mrope_position_delta = \
                MRotaryEmbedding.get_input_positions(
                    token_ids,
                    image_grid_thw=image_grid_thw,
                    video_grid_thw=video_grid_thw,
                    image_token_id=hf_config.image_token_id,
                    video_token_id=hf_config.video_token_id,
                    vision_start_token_id=hf_config.vision_start_token_id,
                    vision_end_token_id=hf_config.vision_end_token_id,
                    spatial_merge_size=hf_config.vision_config.
                    spatial_merge_size,
                    context_len=computed_len,
                )
            # Persist the delta on the sequence so decode steps can derive
            # the next mrope positions without re-running the vision mapping.
            seq_data.mrope_position_delta = mrope_position_delta
        return mm_kwargs, mrope_positions
148
182
def _prepare_prompt (
149
183
self ,
150
184
seq_group_metadata_list : List [SequenceGroupMetadata ],
@@ -153,6 +187,8 @@ def _prepare_prompt(
153
187
assert len (seq_group_metadata_list ) > 0
154
188
input_tokens : List [int ] = []
155
189
input_positions : List [int ] = []
190
+ input_mrope_positions : List [List [int ]] = [[] for _ in range (3 )]
191
+
156
192
slot_mapping : List [int ] = []
157
193
seq_lens : List [int ] = []
158
194
multi_modal_inputs_list : List [MultiModalInputs ] = []
@@ -171,14 +207,20 @@ def _prepare_prompt(
171
207
seq_lens .append (seq_len ) # Prompt token num
172
208
input_tokens .extend (prompt_tokens ) # Token ids
173
209
210
+ mrope_positions = None
211
+ if (mm_data := seq_group_metadata .multi_modal_data ):
212
+ mm_kwargs , mrope_positions = self ._compute_multi_modal_input (
213
+ seq_data , mm_data , computed_len )
214
+ multi_modal_inputs_list .append (mm_kwargs )
215
+
174
216
# Token position ids
175
217
# NOTE(woosuk): Here we assume that the first token in the prompt
176
218
# is always the first token in the sequence.
177
- input_positions . extend ( list ( range ( computed_len , seq_len )))
178
-
179
- if ( mm_data := seq_group_metadata . multi_modal_data ):
180
- mm_kwargs = self . multi_modal_input_mapper ( mm_data )
181
- multi_modal_inputs_list . append ( mm_kwargs )
219
+ if mrope_positions :
220
+ for idx in range ( 3 ):
221
+ input_mrope_positions [ idx ]. extend ( mrope_positions [ idx ])
222
+ else :
223
+ input_positions . extend ( list ( range ( computed_len , seq_len )) )
182
224
183
225
# Compute the slot mapping.
184
226
block_table = seq_group_metadata .block_tables [seq_id ]
@@ -202,12 +244,18 @@ def _prepare_prompt(
202
244
slot = block_number * self .block_size + block_offset
203
245
slot_mapping .append (slot )
204
246
247
+ if any (input_mrope_positions ):
248
+ input_positions = None # type: ignore
249
+ else :
250
+ input_mrope_positions = None # type: ignore
251
+
205
252
num_prompt_tokens = len (input_tokens )
206
253
207
254
input_tokens = torch .tensor (input_tokens ,
208
255
dtype = torch .long ,
209
256
device = self .device ) # type: ignore
210
- input_positions = torch .tensor (input_positions ,
257
+ input_positions = torch .tensor (input_positions
258
+ or input_mrope_positions ,
211
259
dtype = torch .long ,
212
260
device = self .device ) # type: ignore
213
261
slot_mapping = torch .tensor (slot_mapping ,
@@ -238,6 +286,7 @@ def _prepare_decode(
238
286
assert len (seq_group_metadata_list ) > 0
239
287
input_tokens : List [int ] = []
240
288
input_positions : List [int ] = []
289
+ input_mrope_positions : List [List [int ]] = [[] for _ in range (3 )]
241
290
slot_mapping : List [int ] = []
242
291
seq_lens : List [int ] = []
243
292
block_tables : List [List [int ]] = []
@@ -255,7 +304,17 @@ def _prepare_decode(
255
304
256
305
seq_len = seq_data .get_len ()
257
306
position = seq_len - 1
258
- input_positions .append (position )
307
+ if seq_data .mrope_position_delta is not None :
308
+ context_len = seq_data .get_num_computed_tokens ()
309
+ next_pos = MRotaryEmbedding .get_next_input_positions (
310
+ seq_data .mrope_position_delta ,
311
+ context_len ,
312
+ seq_len ,
313
+ )
314
+ for idx in range (3 ):
315
+ input_mrope_positions [idx ].extend (next_pos [idx ])
316
+ else :
317
+ input_positions .append (position )
259
318
260
319
seq_len = seq_len if self .sliding_window is None else min (
261
320
seq_len , self .sliding_window )
@@ -273,12 +332,18 @@ def _prepare_decode(
273
332
block_table = block_table [- sliding_window_blocks :]
274
333
block_tables .append (block_table )
275
334
335
+ if any (input_mrope_positions ):
336
+ input_positions = None # type: ignore
337
+ else :
338
+ input_mrope_positions = None # type: ignore
339
+
276
340
max_decode_seq_len = max (seq_lens )
277
341
278
342
input_tokens = torch .tensor (input_tokens ,
279
343
dtype = torch .long ,
280
344
device = self .device )
281
- input_positions = torch .tensor (input_positions ,
345
+ input_positions = torch .tensor (input_positions
346
+ or input_mrope_positions ,
282
347
dtype = torch .long ,
283
348
device = self .device )
284
349
slot_mapping = torch .tensor (slot_mapping ,
@@ -373,6 +438,15 @@ def __init__(
373
438
raise NotImplementedError (
374
439
STR_NOT_IMPL_ENC_DEC_ERR_STRS ['STR_NOT_IMPL_ENC_DEC_CPU' ])
375
440
441
+ @property
442
+ def model_is_mrope (self ) -> bool :
443
+ """Detect if the model has "mrope" rope_scaling type.
444
+ mrope requires keep "rope_deltas" between prompt and decoding phases."""
445
+ rope_scaling = getattr (self .model_config .hf_config , "rope_scaling" , {})
446
+ if rope_scaling is None :
447
+ return False
448
+ return rope_scaling .get ("type" , None ) == "mrope"
449
+
376
450
def load_model (self ) -> None :
377
451
self .model = get_model (model_config = self .model_config ,
378
452
load_config = self .load_config ,
0 commit comments