
Commit a6500d2

WoosukKwon authored and jimpang committed
[TPU] Remove multi-modal args in TPU backend (vllm-project#6504)
1 parent 8adf32a commit a6500d2

File tree

1 file changed: +6 -40 lines changed


vllm/worker/tpu_model_runner.py

Lines changed: 6 additions & 40 deletions
@@ -1,5 +1,5 @@
 import time
-from typing import List, Mapping, Optional, Tuple
+from typing import List, Optional, Tuple
 
 import numpy as np
 import torch
@@ -12,8 +12,6 @@
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
-                             MultiModalInputs)
 from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
                            SamplerOutput, SequenceGroupMetadata,
                            SequenceOutput)
@@ -68,10 +66,6 @@ def __init__(
             False,
         )
 
-        # Multi-modal data support
-        self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
-            .create_input_mapper(self.model_config)
-
     def load_model(self) -> None:
         self.device = self.device_config.device
 
@@ -154,7 +148,7 @@ def _dummy_run(
         # Dummy run.
         num_samples = _MAX_NUM_SAMPLES if is_prompt else 1
         self.model(token_ids, position_ids, kv_caches, attn_metadata,
-                   input_lens, None, t, p, num_samples)
+                   input_lens, t, p, num_samples)
 
     def warmup_model(
         self,
@@ -199,14 +193,12 @@ def warmup_model(
     def _prepare_prompt(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
-    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor,
-               Mapping[str, BatchedTensors]]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]:
         assert len(seq_group_metadata_list) > 0
         input_tokens: List[List[int]] = []
         input_positions: List[List[int]] = []
         prompt_lens: List[int] = []
         slot_mapping: List[List[int]] = []
-        multi_modal_inputs_list: List[MultiModalInputs] = []
 
         for seq_group_metadata in seq_group_metadata_list:
             assert seq_group_metadata.is_prompt
@@ -232,11 +224,6 @@ def _prepare_prompt(
                 slot = block_number * self.block_size + block_offset
                 slot_mapping[-1].append(slot)
 
-            mm_data = seq_group_metadata.multi_modal_data
-            if mm_data:
-                mm_kwargs = self.multi_modal_input_mapper(mm_data)
-                multi_modal_inputs_list.append(mm_kwargs)
-
         assert len(prompt_lens) > 0
         num_prefills = len(prompt_lens)
         num_prefill_tokens = sum(prompt_lens)
@@ -274,24 +261,17 @@ def _prepare_prompt(
             block_tables=None,
             context_lens=None,
         )
-
-        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
-                                                    device=self.device)
-
-        return (input_tokens, input_positions, attn_metadata, prompt_lens,
-                multi_modal_kwargs)
+        return input_tokens, input_positions, attn_metadata, prompt_lens
 
     def _prepare_decode(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
-    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor,
-               Mapping[str, BatchedTensors]]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]:
         assert len(seq_group_metadata_list) > 0
         input_tokens: List[List[int]] = []
         input_positions: List[List[int]] = []
         slot_mapping: List[List[int]] = []
         context_lens: List[int] = []
-        multi_modal_inputs_list: List[MultiModalInputs] = []
 
         batch_idx = 0
         for seq_group_metadata in seq_group_metadata_list:
@@ -317,11 +297,6 @@ def _prepare_decode(
                 slot = block_number * self.block_size + block_offset
                 slot_mapping.append([slot])
 
-            mm_data = seq_group_metadata.multi_modal_data
-            if mm_data:
-                mm_kwargs = self.multi_modal_input_mapper(mm_data)
-                multi_modal_inputs_list.append(mm_kwargs)
-
         batch_size = _get_padded_batch_size(batch_idx)
         num_paddings = batch_size - batch_idx
         input_tokens = input_tokens + [[0]] * num_paddings
@@ -355,12 +330,7 @@ def _prepare_decode(
             block_tables=block_tables,
             context_lens=context_lens,
         )
-
-        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
-                                                    device=self.device)
-
-        return (input_tokens, input_positions, attn_metadata, input_lens,
-                multi_modal_kwargs)
+        return input_tokens, input_positions, attn_metadata, input_lens
 
     def _prepare_sample(
         self,
@@ -513,7 +483,6 @@ def forward(
         kv_caches: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
         attn_metadata: AttentionMetadata,
         input_lens: torch.Tensor,
-        multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]],
         t: torch.Tensor,
         p: torch.Tensor,
         num_samples: int,
@@ -527,8 +496,6 @@ def forward(
                 memory profiling at initialization.
             attn_metadata: The Pallas attention metadata.
             input_lens: The actual input lengths of shape [batch_size].
-            multi_modal_kwargs: Keyword arguments from multi-modal data to
-                pass to the model.
             t: The sampling temperature of shape [batch_size].
             p: The top-p probability of shape [batch_size].
         """
@@ -573,7 +540,6 @@ def forward(
             position_ids,
             kv_caches,
             attn_metadata,
-            **(multi_modal_kwargs or {}),
         )
         hidden_states = hidden_states.flatten(0, 1)
         logits = self.model.compute_logits(hidden_states, sampling_metadata)
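
Taken together, the change simplifies the runner's interfaces: _prepare_prompt and _prepare_decode now return a plain 4-tuple, and the model call drops the None placeholder that previously stood in for multi-modal kwargs. The snippet below is an illustrative sketch only, not code from this commit; the *_stub functions are hypothetical stand-ins used purely to show the new call shape.

from typing import Any, List, Tuple


def _prepare_prompt_stub() -> Tuple[Any, Any, Any, Any]:
    # Mirrors the new return shape: four values, no multi-modal kwargs mapping.
    return "token_ids", "position_ids", "attn_metadata", "prompt_lens"


def model_stub(token_ids: Any, position_ids: Any, kv_caches: List[Any],
               attn_metadata: Any, input_lens: Any, t: float, p: float,
               num_samples: int) -> str:
    # Mirrors the new forward signature: multi_modal_kwargs has been removed.
    return "hidden_states"


# Callers unpack four values and no longer pass a None placeholder to the model:
token_ids, position_ids, attn_metadata, prompt_lens = _prepare_prompt_stub()
hidden = model_stub(token_ids, position_ids, [], attn_metadata, prompt_lens,
                    t=1.0, p=1.0, num_samples=1)
print(hidden)  # -> "hidden_states"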
