From 246a46c73befb101cc423c2e9229174a577a8e5a Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Mon, 23 Sep 2024 15:18:08 +0800 Subject: [PATCH 1/7] redactor cpu_model_runner --- vllm/worker/cpu_model_runner.py | 297 ++++++++++++++++++++------------ 1 file changed, 191 insertions(+), 106 deletions(-) diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 7b2caf49735..07fdbb580ee 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -1,3 +1,5 @@ +import dataclasses +import weakref from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union @@ -17,7 +19,7 @@ from vllm.sequence import IntermediateTensors, SequenceGroupMetadata from vllm.utils import STR_NOT_IMPL_ENC_DEC_ERR_STRS, make_tensor_with_pad from vllm.worker.model_runner_base import ( - ModelRunnerBase, ModelRunnerInputBase, + ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase, _add_attn_metadata_broadcastable_dict, _add_sampling_metadata_broadcastable_dict, _init_attn_metadata_from_tensor_dict, @@ -32,16 +34,17 @@ @dataclass(frozen=True) -class CPUModelInput(ModelRunnerInputBase): +class ModelInputForCPU(ModelRunnerInputBase): """ Used by the CPUModelRunner. """ input_tokens: Optional[torch.Tensor] = None input_positions: Optional[torch.Tensor] = None attn_metadata: Optional["AttentionMetadata"] = None - sampling_metadata: Optional["SamplingMetadata"] = None multi_modal_kwargs: Optional[BatchedTensorInputs] = None virtual_engine: Optional[int] = None + seq_lens: Optional[List[int]] = None + query_lens: Optional[List[int]] = None def as_broadcastable_tensor_dict( self) -> Dict[str, Union[int, torch.Tensor]]: @@ -51,88 +54,96 @@ def as_broadcastable_tensor_dict( "multi_modal_kwargs": self.multi_modal_kwargs, } _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) - _add_sampling_metadata_broadcastable_dict(tensor_dict, - self.sampling_metadata) + return tensor_dict @classmethod def from_broadcasted_tensor_dict( - cls: Type["CPUModelInput"], - tensor_dict: Dict[str, Any], - attn_backend: Optional["AttentionBackend"] = None - ) -> "CPUModelInput": - tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict) + cls: Type["ModelInputForCPU"], + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None + ) -> "ModelInputForCPU": if attn_backend is not None: tensor_dict = _init_attn_metadata_from_tensor_dict( attn_backend, tensor_dict) return cls(**tensor_dict) -class CPUModelRunner(ModelRunnerBase[CPUModelInput]): +@dataclass(frozen=True) +class ModelInputForCPUWithSamplingMetadata(ModelInputForCPU): + """ + Used by the ModelRunner. + """ + sampling_metadata: Optional["SamplingMetadata"] = None - def __init__( - self, - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - cache_config: CacheConfig, - load_config: LoadConfig, - lora_config: Optional[LoRAConfig], - kv_cache_dtype: Optional[str] = "auto", - prompt_adapter_config: Optional[PromptAdapterConfig] = None, - is_driver_worker: bool = False, - *args, - **kwargs, - ): - self.model_config = model_config - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - # Currently, CPU worker doesn't support chunked prefill. 
- assert self.scheduler_config.chunked_prefill_enabled is False - self.device_config = device_config - self.cache_config = cache_config - self.lora_config = lora_config - self.prompt_adapter_config = prompt_adapter_config - self.load_config = load_config - self.is_driver_worker = is_driver_worker + def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: + tensor_dict = { + "input_tokens": self.input_tokens, + "input_positions": self.input_positions, + } + _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) + _add_sampling_metadata_broadcastable_dict(tensor_dict, + self.sampling_metadata) + return tensor_dict - self.device = self.device_config.device + @classmethod + def from_broadcasted_tensor_dict( + cls, + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None, + ) -> "ModelInputForCPUWithSamplingMetadata": + tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict) + if attn_backend is not None: + tensor_dict = _init_attn_metadata_from_tensor_dict( + attn_backend, tensor_dict) + return cls(**tensor_dict) - self.kv_cache_dtype = kv_cache_dtype - self.sliding_window = model_config.get_sliding_window() - self.block_size = cache_config.block_size - self.attn_backend = get_attn_backend( - self.model_config.get_num_attention_heads(self.parallel_config), - self.model_config.get_head_size(), - self.model_config.get_num_kv_heads(self.parallel_config), - self.model_config.get_sliding_window(), - self.model_config.dtype, - self.kv_cache_dtype, - self.block_size, - ) - # Multi-modal data support - self.mm_registry = MULTIMODAL_REGISTRY - self.multi_modal_input_mapper = self.mm_registry \ - .create_input_mapper(self.model_config) - self.mm_registry.init_mm_limits_per_prompt(self.model_config) +class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]): - # Lazy initialization. - self.model: nn.Module # Set after init_Model + def __init__(self, + runner: "CPUModelRunner", + finished_requests_ids: Optional[List[str]] = None) -> None: + super().__init__() + self.seq_group_metadata_list: List[SequenceGroupMetadata] = [] + self.runner = runner + self.model_input_cls = self.runner._model_input_cls + self.attn_backend = self.runner.attn_backend + self.sliding_window = self.runner.sliding_window + self.block_size = self.runner.block_size + self.device = self.runner.device + self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper - if self.model_config.is_encoder_decoder_model: - raise NotImplementedError( - STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_CPU']) + def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata): + self.seq_group_metadata_list.append(seq_group_metadata) - def load_model(self) -> None: - self.model = get_model(model_config=self.model_config, - load_config=self.load_config, - device_config=self.device_config, - lora_config=self.lora_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - cache_config=self.cache_config) + def build(self) -> ModelInputForCPU: + multi_modal_kwargs = None + # NOTE: We assume that all sequences in the group are all prompts or + # all decodes. + is_prompt = self.seq_group_metadata_list[0].is_prompt + # Prepare input tensors. 
+ if is_prompt: + (input_tokens, input_positions, attn_metadata, seq_lens, + multi_modal_kwargs) = self._prepare_prompt( + self.seq_group_metadata_list) + else: + (input_tokens, input_positions, + attn_metadata) = self._prepare_decode( + self.seq_group_metadata_list) + seq_lens = [] + + return self.model_input_cls( + input_tokens=input_tokens, + input_positions=input_positions, + attn_metadata=attn_metadata, + multi_modal_kwargs=multi_modal_kwargs, + # query_lens is not needed if chunked prefill is not + # supported. Since CPU worker doesn't support chunked prefill + # just use seq_lens instead. + seq_lens=seq_lens, + query_lens=seq_lens, + ) def _prepare_prompt( self, @@ -302,56 +313,130 @@ def _prepare_decode( attn_metadata, ) + +class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]): + _model_input_cls: Type[ModelInputForCPUWithSamplingMetadata] = ( + ModelInputForCPUWithSamplingMetadata) + _builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + cache_config: CacheConfig, + load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + kv_cache_dtype: Optional[str] = "auto", + prompt_adapter_config: Optional[PromptAdapterConfig] = None, + is_driver_worker: bool = False, + *args, + **kwargs, + ): + self.model_config = model_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + # Currently, CPU worker doesn't support chunked prefill. + assert self.scheduler_config.chunked_prefill_enabled is False + self.device_config = device_config + self.cache_config = cache_config + self.lora_config = lora_config + self.prompt_adapter_config = prompt_adapter_config + self.load_config = load_config + self.is_driver_worker = is_driver_worker + + self.device = self.device_config.device + + self.kv_cache_dtype = kv_cache_dtype + self.sliding_window = model_config.get_sliding_window() + self.block_size = cache_config.block_size + self.attn_backend = get_attn_backend( + self.model_config.get_num_attention_heads(self.parallel_config), + self.model_config.get_head_size(), + self.model_config.get_num_kv_heads(self.parallel_config), + self.model_config.get_sliding_window(), + self.model_config.dtype, + self.kv_cache_dtype, + self.block_size, + ) + + # Multi-modal data support + self.mm_registry = MULTIMODAL_REGISTRY + self.multi_modal_input_mapper = self.mm_registry \ + .create_input_mapper(self.model_config) + self.mm_registry.init_mm_limits_per_prompt(self.model_config) + + # Lazy initialization. 
+ self.model: nn.Module # Set after init_Model + + if self.model_config.is_encoder_decoder_model: + raise NotImplementedError( + STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_CPU']) + + def load_model(self) -> None: + self.model = get_model(model_config=self.model_config, + load_config=self.load_config, + device_config=self.device_config, + lora_config=self.lora_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + cache_config=self.cache_config) + def make_model_input_from_broadcasted_tensor_dict( self, tensor_dict: Dict[str, Any], - ) -> CPUModelInput: - return CPUModelInput.from_broadcasted_tensor_dict( + ) -> ModelInputForCPU: + return ModelInputForCPU.from_broadcasted_tensor_dict( tensor_dict, attn_backend=self.attn_backend, ) + def _prepare_model_input_tensors( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + finished_requests_ids: Optional[List[str]] = None + ) -> ModelInputForCPUWithSamplingMetadata: + """Helper method to prepare the model input based on a given sequence + group. Prepares metadata needed for the base model forward pass but not + metadata for possible additional steps, e.g., sampling. + + """ + builder = self._builder_cls(weakref.proxy(self), finished_requests_ids) + for seq_group_metadata in seq_group_metadata_list: + builder.add_seq_group(seq_group_metadata) + + return builder.build() # type: ignore + def prepare_model_input( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - virtual_engine: int = 0, - finished_requests_ids: Optional[List[str]] = None - ) -> CPUModelInput: - multi_modal_kwargs = None - # NOTE: We assume that all sequences in the group are all prompts or - # all decodes. - is_prompt = seq_group_metadata_list[0].is_prompt - # Prepare input tensors. - if is_prompt: - (input_tokens, input_positions, attn_metadata, seq_lens, - multi_modal_kwargs - ) = self._prepare_prompt(seq_group_metadata_list) - else: - (input_tokens, input_positions, - attn_metadata) = self._prepare_decode(seq_group_metadata_list) - seq_lens = [] - sampling_metadata = SamplingMetadata.prepare( - seq_group_metadata_list, - seq_lens, - # query_lens is not needed if chunked prefill is not - # supported. Since CPU worker doesn't support chunked prefill - # just use seq_lens instead. - seq_lens, - self.device, - pin_memory=False, - generators=self.get_generators(finished_requests_ids)) - return CPUModelInput( - input_tokens=input_tokens, - input_positions=input_positions, - attn_metadata=attn_metadata, - sampling_metadata=sampling_metadata, - multi_modal_kwargs=multi_modal_kwargs, - ) + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + virtual_engine: int = 0, + finished_requests_ids: Optional[List[str]] = None + ) -> ModelInputForCPUWithSamplingMetadata: + """Prepare the model input based on a given sequence group, including + metadata for the sampling step. 
+ + """ + model_input = self._prepare_model_input_tensors( + seq_group_metadata_list, finished_requests_ids) + # Sampling metadata is only required for the final pp group + generators = self.get_generators(finished_requests_ids) + sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list, + model_input.seq_lens, + model_input.query_lens, + self.device, + pin_memory=False, + generators=generators) + + return dataclasses.replace(model_input, + sampling_metadata=sampling_metadata, + virtual_engine=virtual_engine) @torch.no_grad() def execute_model( self, - model_input: CPUModelInput, + model_input: ModelInputForCPUWithSamplingMetadata, kv_caches: List[torch.Tensor], intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1, From 8fec49e543cde25773ad7ca5b486a7729946d62a Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Mon, 23 Sep 2024 15:44:28 +0800 Subject: [PATCH 2/7] optimize docs --- vllm/worker/cpu_model_runner.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 07fdbb580ee..b7002e75c9e 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -36,7 +36,7 @@ @dataclass(frozen=True) class ModelInputForCPU(ModelRunnerInputBase): """ - Used by the CPUModelRunner. + Base class contains metadata needed for the base model forward pass on CPU """ input_tokens: Optional[torch.Tensor] = None input_positions: Optional[torch.Tensor] = None @@ -176,8 +176,7 @@ def _prepare_prompt( # is always the first token in the sequence. input_positions.extend(list(range(computed_len, seq_len))) - mm_data = seq_group_metadata.multi_modal_data - if mm_data: + if (mm_data := seq_group_metadata.multi_modal_data): mm_kwargs = self.multi_modal_input_mapper(mm_data) multi_modal_inputs_list.append(mm_kwargs) From dd9bd6f21b0596515ecbd2aac7fbb761f40c72cf Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Mon, 23 Sep 2024 21:01:00 +0800 Subject: [PATCH 3/7] support qwen2-vl on CPU --- vllm/model_executor/models/qwen2_vl.py | 9 ++++++ vllm/worker/cpu_model_runner.py | 42 ++++++++++++++++++++++++++ 2 files changed, 51 insertions(+) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 1011c925679..46bb626f007 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -67,6 +67,7 @@ from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors, SequenceData from vllm.transformers_utils.processor import get_processor +from vllm.utils import is_cpu logger = init_logger(__name__) @@ -278,6 +279,14 @@ def forward( context_layer = rearrange(output, "(b s) ... 
-> b s ...", b=batch_size) + elif is_cpu(): + seq_length = q.size(1) + q, k, v = [rearrange(x, "b s h d -> b h s d") for x in [q, k, v]] + attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool) + for i in range(1, len(cu_seqlens)): + attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True + output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0) + context_layer = rearrange(output, "b h s d -> b s h d ") else: from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalMask diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index b7002e75c9e..008bed837f5 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -12,6 +12,7 @@ SchedulerConfig) from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata +from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader import get_model from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, @@ -153,6 +154,8 @@ def _prepare_prompt( assert len(seq_group_metadata_list) > 0 input_tokens: List[int] = [] input_positions: List[int] = [] + mrope_input_positions: List[List[int]] = [] + slot_mapping: List[int] = [] seq_lens: List[int] = [] multi_modal_inputs_list: List[MultiModalInputs] = [] @@ -179,6 +182,33 @@ def _prepare_prompt( if (mm_data := seq_group_metadata.multi_modal_data): mm_kwargs = self.multi_modal_input_mapper(mm_data) multi_modal_inputs_list.append(mm_kwargs) + + # special processing for mrope position deltas. + if self.runner.model_is_mrope: + image_grid_thw = mm_kwargs.get("image_grid_thw", None) + video_grid_thw = mm_kwargs.get("video_grid_thw", None) + assert image_grid_thw is not None or video_grid_thw is not None, ( + "mrope embedding type requires multi-modal input mapper " + "returns 'image_grid_thw' or 'video_grid_thw'.") + + hf_config = self.runner.model_config.hf_config + token_ids = seq_data.get_token_ids() + + mrope_positions, mrope_position_delta = \ + MRotaryEmbedding.get_input_positions( + token_ids, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + image_token_id=hf_config.image_token_id, + video_token_id=hf_config.video_token_id, + vision_start_token_id=hf_config.vision_start_token_id, + vision_end_token_id=hf_config.vision_end_token_id, + spatial_merge_size=hf_config.vision_config. + spatial_merge_size, + context_len=computed_len, + ) + seq_data.mrope_position_delta = mrope_position_delta + mrope_input_positions.extend(mrope_positions) # Compute the slot mapping. block_table = seq_group_metadata.block_tables[seq_id] @@ -202,6 +232,9 @@ def _prepare_prompt( slot = block_number * self.block_size + block_offset slot_mapping.append(slot) + if mrope_input_positions: + input_positions = mrope_input_positions + num_prompt_tokens = len(input_tokens) input_tokens = torch.tensor(input_tokens, @@ -432,6 +465,15 @@ def prepare_model_input( sampling_metadata=sampling_metadata, virtual_engine=virtual_engine) + @property + def model_is_mrope(self) -> bool: + """Detect if the model has "mrope" rope_scaling type. 
+ mrope requires keep "rope_deltas" between prompt and decoding phases.""" + rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {}) + if rope_scaling is None: + return False + return rope_scaling.get("type", None) == "mrope" + @torch.no_grad() def execute_model( self, From 6bb4230601f862101ec05b5a0cf47ba8c4d32203 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Tue, 24 Sep 2024 16:35:27 +0800 Subject: [PATCH 4/7] mrope decoding --- vllm/model_executor/models/qwen2_vl.py | 15 ++++++++--- vllm/worker/cpu_model_runner.py | 35 ++++++++++++++++++-------- 2 files changed, 36 insertions(+), 14 deletions(-) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 46bb626f007..af8927f3128 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -280,12 +280,19 @@ def forward( "(b s) ... -> b s ...", b=batch_size) elif is_cpu(): - seq_length = q.size(1) + bs, seq_length, _, _ = q.shape q, k, v = [rearrange(x, "b s h d -> b h s d") for x in [q, k, v]] - attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool) + attention_mask = torch.zeros([bs, 1, seq_length, seq_length], + device=q.device, + dtype=torch.bool) for i in range(1, len(cu_seqlens)): - attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True - output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0) + attention_mask[..., cu_seqlens[i - 1]:cu_seqlens[i], + cu_seqlens[i - 1]:cu_seqlens[i]] = True + output = F.scaled_dot_product_attention(q, + k, + v, + attention_mask, + dropout_p=0.0) context_layer = rearrange(output, "b h s d -> b s h d ") else: from xformers import ops as xops diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 008bed837f5..af19a100f22 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -182,7 +182,7 @@ def _prepare_prompt( if (mm_data := seq_group_metadata.multi_modal_data): mm_kwargs = self.multi_modal_input_mapper(mm_data) multi_modal_inputs_list.append(mm_kwargs) - + # special processing for mrope position deltas. 
if self.runner.model_is_mrope: image_grid_thw = mm_kwargs.get("image_grid_thw", None) @@ -232,17 +232,19 @@ def _prepare_prompt( slot = block_number * self.block_size + block_offset slot_mapping.append(slot) - if mrope_input_positions: - input_positions = mrope_input_positions - num_prompt_tokens = len(input_tokens) input_tokens = torch.tensor(input_tokens, dtype=torch.long, device=self.device) # type: ignore - input_positions = torch.tensor(input_positions, - dtype=torch.long, - device=self.device) # type: ignore + if mrope_input_positions: + input_positions = torch.tensor(mrope_input_positions, + dtype=torch.long, + device=self.device) + else: + input_positions = torch.tensor(input_positions, + dtype=torch.long, + device=self.device) slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=self.device) # type: ignore @@ -271,6 +273,7 @@ def _prepare_decode( assert len(seq_group_metadata_list) > 0 input_tokens: List[int] = [] input_positions: List[int] = [] + mrope_input_positions: List[List[int]] = [] slot_mapping: List[int] = [] seq_lens: List[int] = [] block_tables: List[List[int]] = [] @@ -289,6 +292,13 @@ def _prepare_decode( seq_len = seq_data.get_len() position = seq_len - 1 input_positions.append(position) + if seq_data.mrope_position_delta is not None: + next_mrope_positions = MRotaryEmbedding.get_next_input_positions( + seq_data.mrope_position_delta, + position, + seq_len, + ) + mrope_input_positions.extend(next_mrope_positions) seq_len = seq_len if self.sliding_window is None else min( seq_len, self.sliding_window) @@ -311,9 +321,14 @@ def _prepare_decode( input_tokens = torch.tensor(input_tokens, dtype=torch.long, device=self.device) - input_positions = torch.tensor(input_positions, - dtype=torch.long, - device=self.device) + if mrope_input_positions: + input_positions = torch.tensor(mrope_input_positions, + dtype=torch.long, + device=self.device) + else: + input_positions = torch.tensor(input_positions, + dtype=torch.long, + device=self.device) slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=self.device) From aa73afdbe75b50977e970ea9fbe39ad36031eba0 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Tue, 24 Sep 2024 17:41:26 +0800 Subject: [PATCH 5/7] refactor --- vllm/worker/cpu_model_runner.py | 111 ++++++++++++++++---------------- 1 file changed, 56 insertions(+), 55 deletions(-) diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index af19a100f22..ec46a238a5a 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -17,7 +17,8 @@ from vllm.model_executor.model_loader import get_model from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, MultiModalInputs) -from vllm.sequence import IntermediateTensors, SequenceGroupMetadata +from vllm.sequence import (IntermediateTensors, SequenceData, + SequenceGroupMetadata) from vllm.utils import STR_NOT_IMPL_ENC_DEC_ERR_STRS, make_tensor_with_pad from vllm.worker.model_runner_base import ( ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase, @@ -146,6 +147,38 @@ def build(self) -> ModelInputForCPU: query_lens=seq_lens, ) + def _compute_multi_modal_input(self, seq_data: SequenceData, mm_data, + computed_len: int): + mm_kwargs = self.multi_modal_input_mapper(mm_data) + + # special processing for mrope position deltas. 
+ mrope_positions = None + if self.runner.model_is_mrope: + image_grid_thw = mm_kwargs.get("image_grid_thw", None) + video_grid_thw = mm_kwargs.get("video_grid_thw", None) + assert image_grid_thw is not None or video_grid_thw is not None, ( + "mrope embedding type requires multi-modal input mapper " + "returns 'image_grid_thw' or 'video_grid_thw'.") + + hf_config = self.runner.model_config.hf_config + token_ids = seq_data.get_token_ids() + + mrope_positions, mrope_position_delta = \ + MRotaryEmbedding.get_input_positions( + token_ids, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + image_token_id=hf_config.image_token_id, + video_token_id=hf_config.video_token_id, + vision_start_token_id=hf_config.vision_start_token_id, + vision_end_token_id=hf_config.vision_end_token_id, + spatial_merge_size=hf_config.vision_config. + spatial_merge_size, + context_len=computed_len, + ) + seq_data.mrope_position_delta = mrope_position_delta + return mm_kwargs, mrope_positions + def _prepare_prompt( self, seq_group_metadata_list: List[SequenceGroupMetadata], @@ -154,7 +187,6 @@ def _prepare_prompt( assert len(seq_group_metadata_list) > 0 input_tokens: List[int] = [] input_positions: List[int] = [] - mrope_input_positions: List[List[int]] = [] slot_mapping: List[int] = [] seq_lens: List[int] = [] @@ -174,41 +206,19 @@ def _prepare_prompt( seq_lens.append(seq_len) # Prompt token num input_tokens.extend(prompt_tokens) # Token ids - # Token position ids - # NOTE(woosuk): Here we assume that the first token in the prompt - # is always the first token in the sequence. - input_positions.extend(list(range(computed_len, seq_len))) - + mrope_positions = None if (mm_data := seq_group_metadata.multi_modal_data): - mm_kwargs = self.multi_modal_input_mapper(mm_data) + mm_kwargs, mrope_positions = self._compute_multi_modal_input( + seq_data, mm_data, computed_len) multi_modal_inputs_list.append(mm_kwargs) - # special processing for mrope position deltas. - if self.runner.model_is_mrope: - image_grid_thw = mm_kwargs.get("image_grid_thw", None) - video_grid_thw = mm_kwargs.get("video_grid_thw", None) - assert image_grid_thw is not None or video_grid_thw is not None, ( - "mrope embedding type requires multi-modal input mapper " - "returns 'image_grid_thw' or 'video_grid_thw'.") - - hf_config = self.runner.model_config.hf_config - token_ids = seq_data.get_token_ids() - - mrope_positions, mrope_position_delta = \ - MRotaryEmbedding.get_input_positions( - token_ids, - image_grid_thw=image_grid_thw, - video_grid_thw=video_grid_thw, - image_token_id=hf_config.image_token_id, - video_token_id=hf_config.video_token_id, - vision_start_token_id=hf_config.vision_start_token_id, - vision_end_token_id=hf_config.vision_end_token_id, - spatial_merge_size=hf_config.vision_config. - spatial_merge_size, - context_len=computed_len, - ) - seq_data.mrope_position_delta = mrope_position_delta - mrope_input_positions.extend(mrope_positions) + # Token position ids + # NOTE(woosuk): Here we assume that the first token in the prompt + # is always the first token in the sequence. + if mrope_positions: + input_positions.extend(mrope_positions) + else: + input_positions.extend(list(range(computed_len, seq_len))) # Compute the slot mapping. 
block_table = seq_group_metadata.block_tables[seq_id] @@ -237,14 +247,9 @@ def _prepare_prompt( input_tokens = torch.tensor(input_tokens, dtype=torch.long, device=self.device) # type: ignore - if mrope_input_positions: - input_positions = torch.tensor(mrope_input_positions, - dtype=torch.long, - device=self.device) - else: - input_positions = torch.tensor(input_positions, - dtype=torch.long, - device=self.device) + input_positions = torch.tensor(input_positions, + dtype=torch.long, + device=self.device) # type: ignore slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=self.device) # type: ignore @@ -273,7 +278,6 @@ def _prepare_decode( assert len(seq_group_metadata_list) > 0 input_tokens: List[int] = [] input_positions: List[int] = [] - mrope_input_positions: List[List[int]] = [] slot_mapping: List[int] = [] seq_lens: List[int] = [] block_tables: List[List[int]] = [] @@ -291,14 +295,16 @@ def _prepare_decode( seq_len = seq_data.get_len() position = seq_len - 1 - input_positions.append(position) if seq_data.mrope_position_delta is not None: - next_mrope_positions = MRotaryEmbedding.get_next_input_positions( + context_len = seq_data.get_num_computed_tokens() + next_pos = MRotaryEmbedding.get_next_input_positions( seq_data.mrope_position_delta, - position, + context_len, seq_len, ) - mrope_input_positions.extend(next_mrope_positions) + input_positions.extend(next_pos) + else: + input_positions.append(position) seq_len = seq_len if self.sliding_window is None else min( seq_len, self.sliding_window) @@ -321,14 +327,9 @@ def _prepare_decode( input_tokens = torch.tensor(input_tokens, dtype=torch.long, device=self.device) - if mrope_input_positions: - input_positions = torch.tensor(mrope_input_positions, - dtype=torch.long, - device=self.device) - else: - input_positions = torch.tensor(input_positions, - dtype=torch.long, - device=self.device) + input_positions = torch.tensor(input_positions, + dtype=torch.long, + device=self.device) slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=self.device) From 12a7c6ea2b5991888ba05f8049680e082ca121fb Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Tue, 24 Sep 2024 18:08:28 +0800 Subject: [PATCH 6/7] add model_is_mrope --- vllm/worker/cpu_model_runner.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 0907652ef3e..f005ff3d409 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -422,6 +422,15 @@ def __init__( raise NotImplementedError( STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_CPU']) + @property + def model_is_mrope(self) -> bool: + """Detect if the model has "mrope" rope_scaling type. 
+ mrope requires keep "rope_deltas" between prompt and decoding phases.""" + rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {}) + if rope_scaling is None: + return False + return rope_scaling.get("type", None) == "mrope" + def load_model(self) -> None: self.model = get_model(model_config=self.model_config, load_config=self.load_config, From 3772037cb456e74534ece2aa2518b18c7e8ca804 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Wed, 25 Sep 2024 10:58:03 +0800 Subject: [PATCH 7/7] fix batching inference --- vllm/model_executor/models/qwen2_vl.py | 4 ++-- vllm/worker/cpu_model_runner.py | 24 ++++++++++++++++++++---- 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 71431a7d5c6..889ebc6c2e1 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -283,9 +283,9 @@ def forward( "(b s) ... -> b s ...", b=batch_size) elif is_cpu(): - bs, seq_length, _, _ = q.shape + seq_length = q.size(1) q, k, v = [rearrange(x, "b s h d -> b h s d") for x in [q, k, v]] - attention_mask = torch.zeros([bs, 1, seq_length, seq_length], + attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool) for i in range(1, len(cu_seqlens)): diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index f005ff3d409..cebb0f36a2b 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -187,6 +187,7 @@ def _prepare_prompt( assert len(seq_group_metadata_list) > 0 input_tokens: List[int] = [] input_positions: List[int] = [] + input_mrope_positions: List[List[int]] = [[] for _ in range(3)] slot_mapping: List[int] = [] seq_lens: List[int] = [] @@ -216,7 +217,8 @@ def _prepare_prompt( # NOTE(woosuk): Here we assume that the first token in the prompt # is always the first token in the sequence. 
if mrope_positions: - input_positions.extend(mrope_positions) + for idx in range(3): + input_mrope_positions[idx].extend(mrope_positions[idx]) else: input_positions.extend(list(range(computed_len, seq_len))) @@ -242,12 +244,18 @@ def _prepare_prompt( slot = block_number * self.block_size + block_offset slot_mapping.append(slot) + if any(input_mrope_positions): + input_positions = None # type: ignore + else: + input_mrope_positions = None # type: ignore + num_prompt_tokens = len(input_tokens) input_tokens = torch.tensor(input_tokens, dtype=torch.long, device=self.device) # type: ignore - input_positions = torch.tensor(input_positions, + input_positions = torch.tensor(input_positions + or input_mrope_positions, dtype=torch.long, device=self.device) # type: ignore slot_mapping = torch.tensor(slot_mapping, @@ -278,6 +286,7 @@ def _prepare_decode( assert len(seq_group_metadata_list) > 0 input_tokens: List[int] = [] input_positions: List[int] = [] + input_mrope_positions: List[List[int]] = [[] for _ in range(3)] slot_mapping: List[int] = [] seq_lens: List[int] = [] block_tables: List[List[int]] = [] @@ -302,7 +311,8 @@ def _prepare_decode( context_len, seq_len, ) - input_positions.extend(next_pos) + for idx in range(3): + input_mrope_positions[idx].extend(next_pos[idx]) else: input_positions.append(position) @@ -322,12 +332,18 @@ def _prepare_decode( block_table = block_table[-sliding_window_blocks:] block_tables.append(block_table) + if any(input_mrope_positions): + input_positions = None # type: ignore + else: + input_mrope_positions = None # type: ignore + max_decode_seq_len = max(seq_lens) input_tokens = torch.tensor(input_tokens, dtype=torch.long, device=self.device) - input_positions = torch.tensor(input_positions, + input_positions = torch.tensor(input_positions + or input_mrope_positions, dtype=torch.long, device=self.device) slot_mapping = torch.tensor(slot_mapping,
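
The CPU branch added to Qwen2-VL's vision attention in PATCH 3/7 (and reshaped in PATCH 7/7) sits alongside the existing flash-attn and xformers paths: it builds a boolean block-diagonal mask from cu_seqlens, the cumulative lengths of the sequences packed into one batch, and passes it to torch.nn.functional.scaled_dot_product_attention, so each packed sequence attends only to its own tokens. Below is a minimal, standalone sketch of that masking idea in plain PyTorch; shapes, sizes, and variable names are illustrative only and are not taken from vLLM.

# Standalone sketch of the block-diagonal SDPA mask used by the CPU path.
# Everything here is illustrative; only PyTorch APIs are assumed.
import torch
import torch.nn.functional as F

num_heads, head_dim = 4, 32
cu_seqlens = torch.tensor([0, 3, 7, 12])   # boundaries of 3 packed sequences
total_tokens = int(cu_seqlens[-1])

# q, k, v laid out as (batch=1, heads, tokens, head_dim), the layout SDPA expects.
q = torch.randn(1, num_heads, total_tokens, head_dim)
k = torch.randn(1, num_heads, total_tokens, head_dim)
v = torch.randn(1, num_heads, total_tokens, head_dim)

# Boolean mask, True = "may attend". Filling one square block per sequence
# gives the block-diagonal pattern, so tokens never attend across sequences.
mask = torch.zeros(1, total_tokens, total_tokens, dtype=torch.bool)
for i in range(1, len(cu_seqlens)):
    start, end = int(cu_seqlens[i - 1]), int(cu_seqlens[i])
    mask[..., start:end, start:end] = True

out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
print(out.shape)  # torch.Size([1, 4, 12, 32])

On the runner side, PATCH 7/7 keeps M-RoPE positions as three parallel lists (filled via the "for idx in range(3)" loops above), so input_positions becomes a 3 x num_tokens tensor when M-RoPE positions were computed and stays one-dimensional otherwise; that is what the "input_positions or input_mrope_positions" expression selects between.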