Commit 3772037

fix batching inference
1 parent 12a7c6e commit 3772037

2 files changed: +22 −6 lines changed

vllm/model_executor/models/qwen2_vl.py

Lines changed: 2 additions & 2 deletions

@@ -283,9 +283,9 @@ def forward(
                                       "(b s) ... -> b s ...",
                                       b=batch_size)
         elif is_cpu():
-            bs, seq_length, _, _ = q.shape
+            seq_length = q.size(1)
             q, k, v = [rearrange(x, "b s h d -> b h s d") for x in [q, k, v]]
-            attention_mask = torch.zeros([bs, 1, seq_length, seq_length],
+            attention_mask = torch.zeros([1, seq_length, seq_length],
                                          device=q.device,
                                          dtype=torch.bool)
             for i in range(1, len(cu_seqlens)):
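
Why the batch dimension can be dropped (a minimal sketch, not code from this repo; the shapes and cu_seqlens values below are illustrative assumptions): a boolean mask of shape [1, seq, seq] broadcasts across both the batch and head dimensions of a [batch, heads, seq, seq] score tensor, so a single block-diagonal mask built from cu_seqlens serves every batch entry, while the old [bs, 1, seq, seq] allocation tied the mask to one particular batch size.

import torch

# Illustrative sizes: batch 2, 4 heads, 6 tokens packed as two
# windows of 3 via cu_seqlens-style boundaries.
b, h, s = 2, 4, 6
scores = torch.randn(b, h, s, s)

# Build the mask once, with a leading broadcast dimension.
cu_seqlens = [0, 3, 6]
attention_mask = torch.zeros([1, s, s], dtype=torch.bool)
for i in range(1, len(cu_seqlens)):
    lo, hi = cu_seqlens[i - 1], cu_seqlens[i]
    attention_mask[..., lo:hi, lo:hi] = True

# The [1, s, s] mask broadcasts over the [b, h, s, s] scores.
masked = scores.masked_fill(~attention_mask, float("-inf"))
print(masked.shape)  # torch.Size([2, 4, 6, 6])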

vllm/worker/cpu_model_runner.py

Lines changed: 20 additions & 4 deletions

@@ -187,6 +187,7 @@ def _prepare_prompt(
         assert len(seq_group_metadata_list) > 0
         input_tokens: List[int] = []
         input_positions: List[int] = []
+        input_mrope_positions: List[List[int]] = [[] for _ in range(3)]

         slot_mapping: List[int] = []
         seq_lens: List[int] = []
@@ -216,7 +217,8 @@
             # NOTE(woosuk): Here we assume that the first token in the prompt
             # is always the first token in the sequence.
             if mrope_positions:
-                input_positions.extend(mrope_positions)
+                for idx in range(3):
+                    input_mrope_positions[idx].extend(mrope_positions[idx])
             else:
                 input_positions.extend(list(range(computed_len, seq_len)))

@@ -242,12 +244,18 @@
                 slot = block_number * self.block_size + block_offset
                 slot_mapping.append(slot)

+        if any(input_mrope_positions):
+            input_positions = None  # type: ignore
+        else:
+            input_mrope_positions = None  # type: ignore
+
         num_prompt_tokens = len(input_tokens)

         input_tokens = torch.tensor(input_tokens,
                                     dtype=torch.long,
                                     device=self.device)  # type: ignore
-        input_positions = torch.tensor(input_positions,
+        input_positions = torch.tensor(input_positions
+                                       or input_mrope_positions,
                                        dtype=torch.long,
                                        device=self.device)  # type: ignore
         slot_mapping = torch.tensor(slot_mapping,
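
What the prompt-path change fixes (sketched below with made-up position values; seq_a and seq_b are hypothetical): M-RoPE yields three position streams per token (temporal, height, width). Extending the flat input_positions list with a 3-row mrope_positions block appends three new rows per sequence, so batching two or more prompts no longer produces a (3, num_tokens) layout. Keeping one accumulator per stream and extending row-wise does.

from typing import List

# Hypothetical (temporal, height, width) positions for two prompts.
seq_a = [[0, 1, 2], [0, 0, 1], [0, 1, 0]]
seq_b = [[0, 1], [0, 1], [0, 0]]

# Old behaviour: a flat extend stacks 3 rows per sequence,
# giving 6 rows for 2 sequences instead of 3.
flat: List[List[int]] = []
flat.extend(seq_a)
flat.extend(seq_b)
print(len(flat))  # 6

# Fixed behaviour: one list per stream, extended row-wise.
input_mrope_positions: List[List[int]] = [[] for _ in range(3)]
for seq in (seq_a, seq_b):
    for idx in range(3):
        input_mrope_positions[idx].extend(seq[idx])
print(input_mrope_positions)
# [[0, 1, 2, 0, 1], [0, 0, 1, 0, 1], [0, 1, 0, 0, 0]]
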
@@ -278,6 +286,7 @@ def _prepare_decode(
         assert len(seq_group_metadata_list) > 0
         input_tokens: List[int] = []
         input_positions: List[int] = []
+        input_mrope_positions: List[List[int]] = [[] for _ in range(3)]
         slot_mapping: List[int] = []
         seq_lens: List[int] = []
         block_tables: List[List[int]] = []
@@ -302,7 +311,8 @@
                     context_len,
                     seq_len,
                 )
-                input_positions.extend(next_pos)
+                for idx in range(3):
+                    input_mrope_positions[idx].extend(next_pos[idx])
             else:
                 input_positions.append(position)

@@ -322,12 +332,18 @@
                     block_table = block_table[-sliding_window_blocks:]
                 block_tables.append(block_table)

+        if any(input_mrope_positions):
+            input_positions = None  # type: ignore
+        else:
+            input_mrope_positions = None  # type: ignore
+
         max_decode_seq_len = max(seq_lens)

         input_tokens = torch.tensor(input_tokens,
                                     dtype=torch.long,
                                     device=self.device)
-        input_positions = torch.tensor(input_positions,
+        input_positions = torch.tensor(input_positions
+                                       or input_mrope_positions,
                                        dtype=torch.long,
                                        device=self.device)
         slot_mapping = torch.tensor(slot_mapping,
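
The decode path mirrors the prompt path, and both tensor conversions lean on the same hand-off: the if any(...) block leaves exactly one of input_positions / input_mrope_positions non-None, so input_positions or input_mrope_positions selects whichever representation the batch actually uses. A minimal sketch with assumed values:

import torch

# After the selection above, exactly one of the two is non-None.
input_positions = None
input_mrope_positions = [[0, 1], [0, 1], [0, 0]]  # hypothetical values

# `x or y` on a None/list pair picks the populated representation.
positions = torch.tensor(input_positions or input_mrope_positions,
                         dtype=torch.long)
print(positions.shape)  # torch.Size([3, 2]): (streams, tokens)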
