@@ -1,5 +1,5 @@
 import time
-from typing import List, Mapping, Optional, Tuple
+from typing import List, Optional, Tuple
 
 import numpy as np
 import torch
@@ -12,8 +12,6 @@
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
-                             MultiModalInputs)
 from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
                            SamplerOutput, SequenceGroupMetadata,
                            SequenceOutput)
@@ -68,10 +66,6 @@ def __init__(
             False,
         )
 
-        # Multi-modal data support
-        self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
-            .create_input_mapper(self.model_config)
-
     def load_model(self) -> None:
         self.device = self.device_config.device
 
@@ -154,7 +148,7 @@ def _dummy_run(
         # Dummy run.
         num_samples = _MAX_NUM_SAMPLES if is_prompt else 1
         self.model(token_ids, position_ids, kv_caches, attn_metadata,
-                   input_lens, None, t, p, num_samples)
+                   input_lens, t, p, num_samples)
 
     def warmup_model(
         self,
@@ -199,14 +193,12 @@ def warmup_model(
     def _prepare_prompt(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
-    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor,
-               Mapping[str, BatchedTensors]]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]:
         assert len(seq_group_metadata_list) > 0
         input_tokens: List[List[int]] = []
         input_positions: List[List[int]] = []
         prompt_lens: List[int] = []
         slot_mapping: List[List[int]] = []
-        multi_modal_inputs_list: List[MultiModalInputs] = []
 
         for seq_group_metadata in seq_group_metadata_list:
             assert seq_group_metadata.is_prompt
@@ -232,11 +224,6 @@ def _prepare_prompt(
                 slot = block_number * self.block_size + block_offset
                 slot_mapping[-1].append(slot)
 
-            mm_data = seq_group_metadata.multi_modal_data
-            if mm_data:
-                mm_kwargs = self.multi_modal_input_mapper(mm_data)
-                multi_modal_inputs_list.append(mm_kwargs)
-
         assert len(prompt_lens) > 0
         num_prefills = len(prompt_lens)
         num_prefill_tokens = sum(prompt_lens)
@@ -274,24 +261,17 @@ def _prepare_prompt(
             block_tables=None,
             context_lens=None,
         )
-
-        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
-                                                    device=self.device)
-
-        return (input_tokens, input_positions, attn_metadata, prompt_lens,
-                multi_modal_kwargs)
+        return input_tokens, input_positions, attn_metadata, prompt_lens
 
     def _prepare_decode(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
-    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor,
-               Mapping[str, BatchedTensors]]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]:
         assert len(seq_group_metadata_list) > 0
         input_tokens: List[List[int]] = []
         input_positions: List[List[int]] = []
         slot_mapping: List[List[int]] = []
         context_lens: List[int] = []
-        multi_modal_inputs_list: List[MultiModalInputs] = []
 
         batch_idx = 0
         for seq_group_metadata in seq_group_metadata_list:
@@ -317,11 +297,6 @@ def _prepare_decode(
                 slot = block_number * self.block_size + block_offset
                 slot_mapping.append([slot])
 
-            mm_data = seq_group_metadata.multi_modal_data
-            if mm_data:
-                mm_kwargs = self.multi_modal_input_mapper(mm_data)
-                multi_modal_inputs_list.append(mm_kwargs)
-
         batch_size = _get_padded_batch_size(batch_idx)
         num_paddings = batch_size - batch_idx
         input_tokens = input_tokens + [[0]] * num_paddings
@@ -355,12 +330,7 @@ def _prepare_decode(
             block_tables=block_tables,
             context_lens=context_lens,
         )
-
-        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
-                                                    device=self.device)
-
-        return (input_tokens, input_positions, attn_metadata, input_lens,
-                multi_modal_kwargs)
+        return input_tokens, input_positions, attn_metadata, input_lens
 
     def _prepare_sample(
         self,
@@ -513,7 +483,6 @@ def forward(
         kv_caches: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
         attn_metadata: AttentionMetadata,
         input_lens: torch.Tensor,
-        multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]],
         t: torch.Tensor,
         p: torch.Tensor,
         num_samples: int,
@@ -527,8 +496,6 @@ def forward(
                 memory profiling at initialization.
             attn_metadata: The Pallas attention metadata.
             input_lens: The actual input lengths of shape [batch_size].
-            multi_modal_kwargs: Keyword arguments from multi-modal data to
-                pass to the model.
             t: The sampling temperature of shape [batch_size].
             p: The top-p probability of shape [batch_size].
         """
@@ -573,7 +540,6 @@ def forward(
             position_ids,
             kv_caches,
             attn_metadata,
-            **(multi_modal_kwargs or {}),
         )
         hidden_states = hidden_states.flatten(0, 1)
         logits = self.model.compute_logits(hidden_states, sampling_metadata)
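For reference, a minimal sketch of the trimmed forward path after this change: the wrapper is now called with (token_ids, position_ids, kv_caches, attn_metadata, input_lens, t, p, num_samples), with no multi-modal keyword arguments. The tensor shapes and placeholder values below are illustrative assumptions, not part of the commit; kv_caches and attn_metadata are backend-specific objects elided here.

import torch

# Illustrative only: mirrors the call order used in _dummy_run above.
batch_size, seq_len = 8, 128
token_ids = torch.zeros(batch_size, seq_len, dtype=torch.long)
position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1)
input_lens = torch.full((batch_size,), seq_len, dtype=torch.int32)
t = torch.ones(batch_size)  # sampling temperature, shape [batch_size]
p = torch.ones(batch_size)  # top-p probability, shape [batch_size]

# model(token_ids, position_ids, kv_caches, attn_metadata,
#       input_lens, t, p, num_samples)  # no multi_modal_kwargs argument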