@@ -187,6 +187,7 @@ def _prepare_prompt(
         assert len(seq_group_metadata_list) > 0
         input_tokens: List[int] = []
         input_positions: List[int] = []
+        input_mrope_positions: List[List[int]] = [[] for _ in range(3)]

         slot_mapping: List[int] = []
         seq_lens: List[int] = []
@@ -216,7 +217,8 @@ def _prepare_prompt(
             # NOTE(woosuk): Here we assume that the first token in the prompt
             # is always the first token in the sequence.
             if mrope_positions:
-                input_positions.extend(mrope_positions)
+                for idx in range(3):
+                    input_mrope_positions[idx].extend(mrope_positions[idx])
             else:
                 input_positions.extend(list(range(computed_len, seq_len)))

@@ -242,12 +244,18 @@ def _prepare_prompt(
                 slot = block_number * self.block_size + block_offset
                 slot_mapping.append(slot)

+        if any(input_mrope_positions):
+            input_positions = None  # type: ignore
+        else:
+            input_mrope_positions = None  # type: ignore
+
         num_prompt_tokens = len(input_tokens)

         input_tokens = torch.tensor(input_tokens,
                                     dtype=torch.long,
                                     device=self.device)  # type: ignore
-        input_positions = torch.tensor(input_positions,
+        input_positions = torch.tensor(input_positions
+                                       or input_mrope_positions,
                                        dtype=torch.long,
                                        device=self.device)  # type: ignore
         slot_mapping = torch.tensor(slot_mapping,
@@ -278,6 +286,7 @@ def _prepare_decode(
         assert len(seq_group_metadata_list) > 0
         input_tokens: List[int] = []
         input_positions: List[int] = []
+        input_mrope_positions: List[List[int]] = [[] for _ in range(3)]
         slot_mapping: List[int] = []
         seq_lens: List[int] = []
         block_tables: List[List[int]] = []
@@ -302,7 +311,8 @@ def _prepare_decode(
                     context_len,
                     seq_len,
                 )
-                input_positions.extend(next_pos)
+                for idx in range(3):
+                    input_mrope_positions[idx].extend(next_pos[idx])
             else:
                 input_positions.append(position)

@@ -322,12 +332,18 @@ def _prepare_decode(
                 block_table = block_table[-sliding_window_blocks:]
             block_tables.append(block_table)

+        if any(input_mrope_positions):
+            input_positions = None  # type: ignore
+        else:
+            input_mrope_positions = None  # type: ignore
+
         max_decode_seq_len = max(seq_lens)

         input_tokens = torch.tensor(input_tokens,
                                     dtype=torch.long,
                                     device=self.device)
-        input_positions = torch.tensor(input_positions,
+        input_positions = torch.tensor(input_positions
+                                       or input_mrope_positions,
                                        dtype=torch.long,
                                        device=self.device)
         slot_mapping = torch.tensor(slot_mapping,
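Note on the pattern above: both _prepare_prompt and _prepare_decode now accumulate multimodal rotary positions as three parallel rows, keep exactly one of input_positions / input_mrope_positions (setting the other to None), and build the positions tensor with torch.tensor(input_positions or input_mrope_positions, ...), yielding a 1-D tensor for text-only batches and a 3-row 2-D tensor otherwise. The following is a minimal standalone sketch of that selection logic, not the runner code itself; the sample position values (and the temporal/height/width reading of the three rows) are illustrative assumptions.

# Standalone sketch of the selection pattern used in the diff above.
from typing import List, Optional

import torch

# Case 1: text-only batch -> flat positions, the mrope rows stay empty.
input_positions: Optional[List[int]] = [0, 1, 2, 3]
input_mrope_positions: Optional[List[List[int]]] = [[] for _ in range(3)]

if any(input_mrope_positions):
    input_positions = None
else:
    input_mrope_positions = None

positions = torch.tensor(input_positions or input_mrope_positions,
                         dtype=torch.long)
print(positions.shape)  # torch.Size([4]) -- 1-D positions

# Case 2: multimodal batch -> three parallel rows (one per mrope dimension),
# extended per sequence the same way the diff does with
# input_mrope_positions[idx].extend(mrope_positions[idx]).
input_positions = []
input_mrope_positions = [[] for _ in range(3)]
mrope_positions = [[0, 1, 2, 3],   # illustrative values, e.g. a temporal row
                   [0, 0, 1, 1],   # e.g. a height row
                   [0, 1, 0, 1]]   # e.g. a width row
for idx in range(3):
    input_mrope_positions[idx].extend(mrope_positions[idx])

if any(input_mrope_positions):
    input_positions = None
else:
    input_mrope_positions = None

positions = torch.tensor(input_positions or input_mrope_positions,
                         dtype=torch.long)
print(positions.shape)  # torch.Size([3, 4]) -- one row per mrope dimension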