
Commit d214697

Isotr0py authored and garg-amit committed
[Hardware][CPU] Enable mrope and support Qwen2-VL on CPU backend (vllm-project#8770)
Signed-off-by: Amit Garg <mitgarg17495@gmail.com>
1 parent cc982dc commit d214697

2 files changed: +101 -11 lines changed

vllm/model_executor/models/qwen2_vl.py

Lines changed: 16 additions & 0 deletions

@@ -67,6 +67,7 @@
 from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors, SequenceData
 from vllm.transformers_utils.processor import get_processor
+from vllm.utils import is_cpu
 
 from .utils import (PPMissingLayer, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory)
@@ -281,6 +282,21 @@ def forward(
             context_layer = rearrange(output,
                                       "(b s) ... -> b s ...",
                                       b=batch_size)
+        elif is_cpu():
+            seq_length = q.size(1)
+            q, k, v = [rearrange(x, "b s h d -> b h s d") for x in [q, k, v]]
+            attention_mask = torch.zeros([1, seq_length, seq_length],
+                                         device=q.device,
+                                         dtype=torch.bool)
+            for i in range(1, len(cu_seqlens)):
+                attention_mask[..., cu_seqlens[i - 1]:cu_seqlens[i],
+                               cu_seqlens[i - 1]:cu_seqlens[i]] = True
+            output = F.scaled_dot_product_attention(q,
+                                                    k,
+                                                    v,
+                                                    attention_mask,
+                                                    dropout_p=0.0)
+            context_layer = rearrange(output, "b h s d -> b s h d ")
         else:
             from xformers import ops as xops
             from xformers.ops.fmha.attn_bias import BlockDiagonalMask
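
The new elif is_cpu() branch stands in for xformers' BlockDiagonalMask on the CPU backend: it builds a boolean block-diagonal mask from cu_seqlens so that tokens only attend within their own packed sub-sequence, then calls torch.nn.functional.scaled_dot_product_attention. A minimal standalone sketch of the same masking idea, using toy sizes and made-up cu_seqlens rather than the real vLLM tensors:

import torch
import torch.nn.functional as F

# Two packed sequences of length 3 and 4 -> cumulative boundaries [0, 3, 7].
cu_seqlens = [0, 3, 7]
seq_length = cu_seqlens[-1]
num_heads, head_dim = 2, 8

q = torch.randn(1, num_heads, seq_length, head_dim)
k = torch.randn(1, num_heads, seq_length, head_dim)
v = torch.randn(1, num_heads, seq_length, head_dim)

# Boolean mask: True = may attend, False = masked out.
attention_mask = torch.zeros([1, seq_length, seq_length], dtype=torch.bool)
for i in range(1, len(cu_seqlens)):
    attention_mask[..., cu_seqlens[i - 1]:cu_seqlens[i],
                   cu_seqlens[i - 1]:cu_seqlens[i]] = True

# Each sub-sequence attends only to itself, mirroring a block-diagonal bias.
output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
print(output.shape)  # torch.Size([1, 2, 7, 8])

With a boolean attn_mask, scaled_dot_product_attention masks out the False positions (they contribute -inf before the softmax), so the result matches running attention on each packed sequence separately, which is what BlockDiagonalMask achieves on the xformers path.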

vllm/worker/cpu_model_runner.py

Lines changed: 85 additions & 11 deletions

@@ -12,11 +12,13 @@
                          SchedulerConfig)
 from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadata
+from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.model_loader import get_model
 from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
                              MultiModalInputs)
-from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
+from vllm.sequence import (IntermediateTensors, SequenceData,
+                           SequenceGroupMetadata)
 from vllm.utils import STR_NOT_IMPL_ENC_DEC_ERR_STRS, make_tensor_with_pad
 from vllm.worker.model_runner_base import (
     ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
@@ -145,6 +147,38 @@ def build(self) -> ModelInputForCPU:
             query_lens=seq_lens,
         )
 
+    def _compute_multi_modal_input(self, seq_data: SequenceData, mm_data,
+                                   computed_len: int):
+        mm_kwargs = self.multi_modal_input_mapper(mm_data)
+
+        # special processing for mrope position deltas.
+        mrope_positions = None
+        if self.runner.model_is_mrope:
+            image_grid_thw = mm_kwargs.get("image_grid_thw", None)
+            video_grid_thw = mm_kwargs.get("video_grid_thw", None)
+            assert image_grid_thw is not None or video_grid_thw is not None, (
+                "mrope embedding type requires multi-modal input mapper "
+                "returns 'image_grid_thw' or 'video_grid_thw'.")
+
+            hf_config = self.runner.model_config.hf_config
+            token_ids = seq_data.get_token_ids()
+
+            mrope_positions, mrope_position_delta = \
+                MRotaryEmbedding.get_input_positions(
+                    token_ids,
+                    image_grid_thw=image_grid_thw,
+                    video_grid_thw=video_grid_thw,
+                    image_token_id=hf_config.image_token_id,
+                    video_token_id=hf_config.video_token_id,
+                    vision_start_token_id=hf_config.vision_start_token_id,
+                    vision_end_token_id=hf_config.vision_end_token_id,
+                    spatial_merge_size=hf_config.vision_config.
+                    spatial_merge_size,
+                    context_len=computed_len,
+                )
+            seq_data.mrope_position_delta = mrope_position_delta
+        return mm_kwargs, mrope_positions
+
     def _prepare_prompt(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
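
In the hunks below, mrope positions are collected as three parallel lists (one per rotary axis: temporal, height and width in Qwen2-VL), while non-mrope models keep the usual flat position list; torch.tensor(input_positions or input_mrope_positions) then produces either a 1-D (num_tokens,) tensor or a 2-D (3, num_tokens) tensor. A small illustration of that shape contract, with made-up position values:

import torch

# Text-only model: one position id per token.
input_positions = [0, 1, 2, 3]
print(torch.tensor(input_positions, dtype=torch.long).shape)
# torch.Size([4])

# mrope model: temporal / height / width positions per token
# (values here are illustrative only).
input_mrope_positions = [[0, 1, 2, 3],   # temporal
                         [0, 0, 1, 1],   # height
                         [0, 1, 0, 1]]   # width
print(torch.tensor(input_mrope_positions, dtype=torch.long).shape)
# torch.Size([3, 4])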
@@ -155,6 +189,8 @@ def _prepare_prompt(
         input_positions: List[int] = []
         # The number of original input tokens of each sequence
         num_orig_input_tokens_list: List[int] = []
+        input_mrope_positions: List[List[int]] = [[] for _ in range(3)]
+
         slot_mapping: List[int] = []
         seq_lens: List[int] = []
         multi_modal_inputs_list: List[MultiModalInputs] = []
@@ -173,17 +209,23 @@ def _prepare_prompt(
             seq_lens.append(seq_len)  # Prompt token num
             input_tokens.extend(prompt_tokens)  # Token ids
 
+            mrope_positions = None
+            if (mm_data := seq_group_metadata.multi_modal_data):
+                mm_kwargs, mrope_positions = self._compute_multi_modal_input(
+                    seq_data, mm_data, computed_len)
+                multi_modal_inputs_list.append(mm_kwargs)
+
             # Token position ids
             # NOTE(woosuk): Here we assume that the first token in the prompt
             # is always the first token in the sequence.
-            input_positions.extend(list(range(computed_len, seq_len)))
-            num_orig_input_tokens_list.extend([seq_data.get_prompt_len()] *
+            if mrope_positions:
+                for idx in range(3):
+                    input_mrope_positions[idx].extend(mrope_positions[idx])
+            else:
+                input_positions.extend(list(range(computed_len, seq_len)))
+                num_orig_input_tokens_list.extend([seq_data.get_prompt_len()] *
                                               (seq_len - computed_len))
 
-            if (mm_data := seq_group_metadata.multi_modal_data):
-                mm_kwargs = self.multi_modal_input_mapper(mm_data)
-                multi_modal_inputs_list.append(mm_kwargs)
-
             # Compute the slot mapping.
             block_table = seq_group_metadata.block_tables[seq_id]
             # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
@@ -206,12 +248,18 @@ def _prepare_prompt(
                 slot = block_number * self.block_size + block_offset
                 slot_mapping.append(slot)
 
+        if any(input_mrope_positions):
+            input_positions = None  # type: ignore
+        else:
+            input_mrope_positions = None  # type: ignore
+
         num_prompt_tokens = len(input_tokens)
 
         input_tokens = torch.tensor(input_tokens,
                                     dtype=torch.long,
                                     device=self.device)  # type: ignore
-        input_positions = torch.tensor(input_positions,
+        input_positions = torch.tensor(input_positions
+                                       or input_mrope_positions,
                                        dtype=torch.long,
                                        device=self.device)  # type: ignore
         num_orig_input_tokens_tensor = torch.tensor(
@@ -248,6 +296,7 @@ def _prepare_decode(
         input_positions: List[int] = []
         # The number of original input tokens of each sequence
         num_orig_input_tokens_list: List[int] = []
+        input_mrope_positions: List[List[int]] = [[] for _ in range(3)]
         slot_mapping: List[int] = []
         seq_lens: List[int] = []
         block_tables: List[List[int]] = []
@@ -265,8 +314,18 @@ def _prepare_decode(
 
             seq_len = seq_data.get_len()
             position = seq_len - 1
-            input_positions.append(position)
-            num_orig_input_tokens_list.append(seq_data.get_prompt_len())
+            if seq_data.mrope_position_delta is not None:
+                context_len = seq_data.get_num_computed_tokens()
+                next_pos = MRotaryEmbedding.get_next_input_positions(
+                    seq_data.mrope_position_delta,
+                    context_len,
+                    seq_len,
+                )
+                for idx in range(3):
+                    input_mrope_positions[idx].extend(next_pos[idx])
+            else:
+                input_positions.append(position)
+                num_orig_input_tokens_list.append(seq_data.get_prompt_len())
 
             seq_len = seq_len if self.sliding_window is None else min(
                 seq_len, self.sliding_window)
@@ -284,12 +343,18 @@ def _prepare_decode(
                     block_table = block_table[-sliding_window_blocks:]
             block_tables.append(block_table)
 
+        if any(input_mrope_positions):
+            input_positions = None  # type: ignore
+        else:
+            input_mrope_positions = None  # type: ignore
+
         max_decode_seq_len = max(seq_lens)
 
         input_tokens = torch.tensor(input_tokens,
                                     dtype=torch.long,
                                     device=self.device)
-        input_positions = torch.tensor(input_positions,
+        input_positions = torch.tensor(input_positions
+                                       or input_mrope_positions,
                                        dtype=torch.long,
                                        device=self.device)
         num_orig_input_tokens_tensor = torch.tensor(
@@ -388,6 +453,15 @@ def __init__(
             raise NotImplementedError(
                 STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_CPU'])
 
+    @property
+    def model_is_mrope(self) -> bool:
+        """Detect if the model has "mrope" rope_scaling type.
+        mrope requires keep "rope_deltas" between prompt and decoding phases."""
+        rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {})
+        if rope_scaling is None:
+            return False
+        return rope_scaling.get("type", None) == "mrope"
+
     def load_model(self) -> None:
         self.model = get_model(model_config=self.model_config,
                                load_config=self.load_config,
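
The model_is_mrope property added above gates all of the new position handling on the rope_scaling entry of the model's Hugging Face config. A hedged sketch of that detection, replicated standalone against a Qwen2-VL-style config; the rope_scaling contents below are illustrative assumptions, not values read from a real checkpoint:

from types import SimpleNamespace


def model_is_mrope(hf_config) -> bool:
    # Mirrors the property above: missing or non-mrope rope_scaling -> False.
    rope_scaling = getattr(hf_config, "rope_scaling", {})
    if rope_scaling is None:
        return False
    return rope_scaling.get("type", None) == "mrope"


# Qwen2-VL-style config (illustrative values only).
qwen2_vl_like = SimpleNamespace(
    rope_scaling={"type": "mrope", "mrope_section": [16, 24, 24]})
# Text-only model with no rope scaling configured.
text_only_like = SimpleNamespace(rope_scaling=None)

print(model_is_mrope(qwen2_vl_like))   # True
print(model_is_mrope(text_only_like))  # False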
