From 6ba5346abe7e5d33c817e4a8faba4b6f4f6803d7 Mon Sep 17 00:00:00 2001 From: ywfang <1157670798@qq.com> Date: Mon, 9 Sep 2024 21:36:09 +0800 Subject: [PATCH 01/11] support minicpm3 --- docs/source/models/supported_models.rst | 4 + vllm/model_executor/models/__init__.py | 1 + vllm/model_executor/models/minicpm3.py | 604 ++++++++++++++++++++++++ 3 files changed, 609 insertions(+) create mode 100644 vllm/model_executor/models/minicpm3.py diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 1bb3a448f2c..624b312b282 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -107,6 +107,10 @@ Decoder-only Language Models - MiniCPM - :code:`openbmb/MiniCPM-2B-sft-bf16`, :code:`openbmb/MiniCPM-2B-dpo-bf16`, etc. - + * - :code:`MiniCPM3ForCausalLM` + - MiniCPM3 + - :code:`openbmb/MiniCPM3-4B` + - * - :code:`MistralForCausalLM` - Mistral, Mistral-Instruct - :code:`mistralai/Mistral-7B-v0.1`, :code:`mistralai/Mistral-7B-Instruct-v0.1`, etc. diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 4db84702956..c49b69fff2b 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -43,6 +43,7 @@ "MptForCausalLM": ("mpt", "MPTForCausalLM"), "MPTForCausalLM": ("mpt", "MPTForCausalLM"), "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"), + "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"), "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"), "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"), "OPTForCausalLM": ("opt", "OPTForCausalLM"), diff --git a/vllm/model_executor/models/minicpm3.py b/vllm/model_executor/models/minicpm3.py new file mode 100644 index 00000000000..be35b2e3d6c --- /dev/null +++ b/vllm/model_executor/models/minicpm3.py @@ -0,0 +1,604 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2024 The ModelBest team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only MiniCPM3 model compatible with HuggingFace weights.""" +import math +from typing import Any, Dict, Iterable, List, Optional, Tuple + +import torch +from torch import nn +from transformers import PretrainedConfig + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.utils import set_weight_attrs +from vllm.sequence import IntermediateTensors + + +class MiniCPMMoE(nn.Module): + """A tensor-parallel MoE implementation that shards each expert + across all ranks. + + Each expert's weights are sharded across all ranks and a fused MoE + kernel is used for the forward pass, and finally we reduce the outputs + across ranks. + """ + + def __init__( + self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + tp_size: Optional[int] = None, + ): + super().__init__() + self.tp_size = tp_size or get_tensor_model_parallel_world_size() + self.num_total_experts = num_experts + self.top_k = top_k + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size // self.tp_size + + if params_dtype is None: + params_dtype = torch.get_default_dtype() + self.params_dtype = params_dtype + + self.gate = ReplicatedLinear(self.hidden_size, + self.num_total_experts, + bias=False, + params_dtype=self.params_dtype, + quant_config=None) + + self.ws = nn.Parameter( + torch.empty(self.num_total_experts, + 2 * self.intermediate_size, + self.hidden_size, + device="cuda", + dtype=self.params_dtype)) + self.w2s = nn.Parameter( + torch.empty(self.num_total_experts, + self.hidden_size, + self.intermediate_size, + device="cuda", + dtype=self.params_dtype)) + + set_weight_attrs(self.ws, { + "weight_loader": self.weight_loader, + }) + set_weight_attrs(self.w2s, { + "weight_loader": self.weight_loader, + }) + + def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, + weight_name: str, expert_id: int): + tp_rank = get_tensor_model_parallel_rank() + param_data = param.data + shard_size = self.intermediate_size + shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) + if weight_name.endswith("w1.weight"): + param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :] + if weight_name.endswith("w3.weight"): + param_data[expert_id, + shard_size:2 * shard_size, :] = loaded_weight[shard, :] + if weight_name.endswith("w2.weight"): + param_data[expert_id, :, :] = loaded_weight[:, shard] + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_tokens, hidden_size = hidden_states.shape + hidden_states = hidden_states.view(-1, self.hidden_size) + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = fused_moe(hidden_states, + self.ws, + self.w2s, + router_logits, + self.top_k, + renormalize=True, + inplace=True) + + if self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) + + return final_hidden_states.view(num_tokens, hidden_size) + + +class MiniCPMMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config) + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class MiniCPMAttention(nn.Module): + + def __init__( + self, + config, + hidden_size: int, + num_heads: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + v_head_dim: int, + q_lora_rank: int, + kv_lora_rank: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim + self.v_head_dim = v_head_dim + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + self.num_heads = num_heads + + tp_size = get_tensor_model_parallel_world_size() + assert self.num_heads % tp_size == 0 + self.num_local_heads = num_heads // tp_size + + self.scaling = self.qk_head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.q_a_proj = ReplicatedLinear(self.hidden_size, + self.q_lora_rank, + bias=False, + quant_config=quant_config) + self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps) + self.q_b_proj = ColumnParallelLinear(q_lora_rank, + self.num_heads * self.qk_head_dim, + bias=False, + quant_config=quant_config) + + self.kv_a_proj_with_mqa = ReplicatedLinear(self.hidden_size, + self.kv_lora_rank + + self.qk_rope_head_dim, + bias=False, + quant_config=quant_config) + self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, + eps=config.rms_norm_eps) + self.kv_b_proj = ColumnParallelLinear( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + quant_config=quant_config) + # O projection. + self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config) + + self.rotary_emb = get_rope( + self.qk_rope_head_dim, + rotary_dim=self.qk_rope_head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = Attention(self.num_local_heads, + self.qk_head_dim, + self.scaling, + num_kv_heads=self.num_local_heads) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + q = self.q_a_proj(hidden_states)[0] + q = self.q_a_layernorm(q) + q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, + self.qk_head_dim) + _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], + dim=-1) + latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0] + kv_a, _ = latent_cache.split( + [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + latent_cache = latent_cache.unsqueeze(1) + kv_a = self.kv_a_layernorm(kv_a.contiguous()) + kv = self.kv_b_proj(kv_a)[0] + kv = kv.view(-1, self.num_local_heads, + self.qk_nope_head_dim + self.v_head_dim) + k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + k_pe = latent_cache[:, :, self.kv_lora_rank:] + + q_pe, k_pe = self.rotary_emb( + positions, + q_pe.reshape(-1, self.num_local_heads * self.qk_rope_head_dim), + k_pe.reshape(-1, self.qk_rope_head_dim)) + q_pe = q_pe.view(-1, self.num_local_heads, self.qk_rope_head_dim) + k_pe = k_pe.view(-1, 1, self.qk_rope_head_dim) + + q[..., self.qk_nope_head_dim:] = q_pe + + k = torch.empty_like(q) + + k[..., :self.qk_nope_head_dim] = k_nope + k[..., self.qk_nope_head_dim:] = k_pe + + q = q.reshape(-1, self.num_local_heads * self.qk_head_dim) + k = k.view(-1, self.num_local_heads * self.qk_head_dim) + v = torch.nn.functional.pad( + v, [0, self.qk_head_dim - self.v_head_dim], + value=0).view(-1, self.num_local_heads * self.qk_head_dim) + + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = attn_output.view( + -1, self.num_local_heads, + self.qk_head_dim)[..., :self.v_head_dim].reshape( + -1, self.num_local_heads * self.v_head_dim) + + output, _ = self.o_proj(attn_output) + return output + + +class MiniCPMDecoderLayer(nn.Module): + + def __init__( + self, + config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + self.self_attn = MiniCPMAttention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + qk_nope_head_dim=config.qk_nope_head_dim, + qk_rope_head_dim=config.qk_rope_head_dim, + v_head_dim=config.v_head_dim, + q_lora_rank=config.q_lora_rank, + kv_lora_rank=config.kv_lora_rank, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + cache_config=cache_config, + quant_config=quant_config, + ) + self.num_experts = getattr(self.config, "num_experts", 0) + if self.num_experts == 0: + self.mlp = MiniCPMMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + ) + else: + self.mlp = MiniCPMMoE(num_experts=config.num_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + hidden_states = residual + hidden_states * \ + (self.config.scale_depth / math.sqrt(self.config.num_hidden_layers)) + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states * \ + (self.config.scale_depth / math.sqrt(self.config.num_hidden_layers)) + + return hidden_states, None + + +class MiniCPMModel(nn.Module): + + def __init__( + self, + config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + self.layers = nn.ModuleList([ + MiniCPMDecoderLayer(config, cache_config, quant_config) + for _ in range(config.num_hidden_layers) + ]) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + embedding = self.embed_tokens(input_ids) + return embedding * self.config.scale_emb + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states, residual = layer( + positions, + hidden_states, + kv_caches[i], + attn_metadata, + residual, + ) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class MiniCPM3ForCausalLM(nn.Module): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "o_proj", + "gate_up_proj", + "down_proj", + "embed_tokens", + "lm_head", + ] + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + super().__init__() + + self.config = config + self.lora_config = lora_config + + self.num_experts = getattr(self.config, "num_experts", 0) + self.quant_config = quant_config + self.model = MiniCPMModel(config, + cache_config, + quant_config, + lora_config=lora_config) + unpadded_vocab_size = config.vocab_size + if lora_config: + unpadded_vocab_size += lora_config.lora_extra_vocab_size + if not self.config.tie_word_embeddings: + self.lm_head = ParallelLMHead( + unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + quant_config=quant_config, + ) + self.scale_width = self.config.hidden_size / self.config.dim_model_base + + self.logits_processor = LogitsProcessor(unpadded_vocab_size, + config.vocab_size) + self.sampler = Sampler() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + hidden_states = hidden_states / self.scale_width + if self.config.tie_word_embeddings: + lm_head = self.model.embed_tokens + else: + lm_head = self.lm_head + logits = self.logits_processor(lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + expert_params_mapping = [ + # (param_name, weight_name, expert_id) + ("ws" if weight_name in ["w1", "w3"] else "w2s", + f"experts.{expert_id}.{weight_name}.weight", expert_id) + for expert_id in range(self.num_experts) + for weight_name in ["w1", "w2", "w3"] + ] + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + # With tie_word_embeddings, we can skip lm_head.weight + # The weight might appear unnecessarily in the files if the model is + # processed with quantization, LoRA, fine-tuning, etc. + if self.config.tie_word_embeddings and "lm_head.weight" in name: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for param_name, weight_name, expert_id in expert_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + weight_name, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) From 4ad985ba34afa0ccdbc5e2e8299a452f522764bc Mon Sep 17 00:00:00 2001 From: ywfang <1157670798@qq.com> Date: Fri, 13 Sep 2024 16:21:52 +0800 Subject: [PATCH 02/11] optimize code structure --- vllm/model_executor/models/minicpm3.py | 349 ++----------------------- 1 file changed, 22 insertions(+), 327 deletions(-) diff --git a/vllm/model_executor/models/minicpm3.py b/vllm/model_executor/models/minicpm3.py index be35b2e3d6c..73d7556fabc 100644 --- a/vllm/model_executor/models/minicpm3.py +++ b/vllm/model_executor/models/minicpm3.py @@ -22,7 +22,6 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only MiniCPM3 model compatible with HuggingFace weights.""" -import math from typing import Any, Dict, Iterable, List, Optional, Tuple import torch @@ -31,149 +30,25 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, LoRAConfig -from vllm.distributed import (get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) -from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, - MergedColumnParallelLinear, ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.models.minicpm import (MiniCPMMLP, MiniCPMMoE, + MiniCPMDecoderLayer, + MiniCPMModel, + MiniCPMForCausalLM) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.utils import set_weight_attrs -from vllm.sequence import IntermediateTensors -class MiniCPMMoE(nn.Module): - """A tensor-parallel MoE implementation that shards each expert - across all ranks. - - Each expert's weights are sharded across all ranks and a fused MoE - kernel is used for the forward pass, and finally we reduce the outputs - across ranks. - """ - - def __init__( - self, - num_experts: int, - top_k: int, - hidden_size: int, - intermediate_size: int, - params_dtype: Optional[torch.dtype] = None, - tp_size: Optional[int] = None, - ): - super().__init__() - self.tp_size = tp_size or get_tensor_model_parallel_world_size() - self.num_total_experts = num_experts - self.top_k = top_k - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size // self.tp_size - - if params_dtype is None: - params_dtype = torch.get_default_dtype() - self.params_dtype = params_dtype - - self.gate = ReplicatedLinear(self.hidden_size, - self.num_total_experts, - bias=False, - params_dtype=self.params_dtype, - quant_config=None) - - self.ws = nn.Parameter( - torch.empty(self.num_total_experts, - 2 * self.intermediate_size, - self.hidden_size, - device="cuda", - dtype=self.params_dtype)) - self.w2s = nn.Parameter( - torch.empty(self.num_total_experts, - self.hidden_size, - self.intermediate_size, - device="cuda", - dtype=self.params_dtype)) - - set_weight_attrs(self.ws, { - "weight_loader": self.weight_loader, - }) - set_weight_attrs(self.w2s, { - "weight_loader": self.weight_loader, - }) - - def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, - weight_name: str, expert_id: int): - tp_rank = get_tensor_model_parallel_rank() - param_data = param.data - shard_size = self.intermediate_size - shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) - if weight_name.endswith("w1.weight"): - param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :] - if weight_name.endswith("w3.weight"): - param_data[expert_id, - shard_size:2 * shard_size, :] = loaded_weight[shard, :] - if weight_name.endswith("w2.weight"): - param_data[expert_id, :, :] = loaded_weight[:, shard] - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - num_tokens, hidden_size = hidden_states.shape - hidden_states = hidden_states.view(-1, self.hidden_size) - # router_logits: (num_tokens, n_experts) - router_logits, _ = self.gate(hidden_states) - final_hidden_states = fused_moe(hidden_states, - self.ws, - self.w2s, - router_logits, - self.top_k, - renormalize=True, - inplace=True) - - if self.tp_size > 1: - final_hidden_states = tensor_model_parallel_all_reduce( - final_hidden_states) - - return final_hidden_states.view(num_tokens, hidden_size) - - -class MiniCPMMLP(nn.Module): - - def __init__( - self, - hidden_size: int, - intermediate_size: int, - hidden_act: str, - quant_config: Optional[QuantizationConfig] = None, - ) -> None: - super().__init__() - self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, - bias=False, - quant_config=quant_config) - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - quant_config=quant_config) - if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") - self.act_fn = SiluAndMul() - - def forward(self, x): - gate_up, _ = self.gate_up_proj(x) - x = self.act_fn(gate_up) - x, _ = self.down_proj(x) - return x - - -class MiniCPMAttention(nn.Module): +class MiniCPM3Attention(nn.Module): def __init__( self, @@ -247,7 +122,9 @@ def __init__( self.attn = Attention(self.num_local_heads, self.qk_head_dim, self.scaling, - num_kv_heads=self.num_local_heads) + num_kv_heads=self.num_local_heads, + cache_config=cache_config, + quant_config=quant_config) def forward( self, @@ -304,7 +181,7 @@ def forward( return output -class MiniCPMDecoderLayer(nn.Module): +class MiniCPM3DecoderLayer(MiniCPMDecoderLayer): def __init__( self, @@ -312,14 +189,14 @@ def __init__( cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: - super().__init__() + nn.Module.__init__(self) self.config = config self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) max_position_embeddings = getattr(config, "max_position_embeddings", 8192) - self.self_attn = MiniCPMAttention( + self.self_attn = MiniCPM3Attention( config=config, hidden_size=self.hidden_size, num_heads=config.num_attention_heads, @@ -352,37 +229,8 @@ def __init__( self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, - residual: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, torch.Tensor]: - # Self Attention - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - hidden_states = self.self_attn( - positions=positions, - hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, - ) - hidden_states = residual + hidden_states * \ - (self.config.scale_depth / math.sqrt(self.config.num_hidden_layers)) - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states * \ - (self.config.scale_depth / math.sqrt(self.config.num_hidden_layers)) - return hidden_states, None - - -class MiniCPMModel(nn.Module): +class MiniCPM3Model(MiniCPMModel): def __init__( self, @@ -391,7 +239,7 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: - super().__init__() + nn.Module.__init__(self) self.config = config self.padding_idx = config.pad_token_id lora_vocab = (lora_config.lora_extra_vocab_size * @@ -404,70 +252,13 @@ def __init__( org_num_embeddings=config.vocab_size, ) self.layers = nn.ModuleList([ - MiniCPMDecoderLayer(config, cache_config, quant_config) + MiniCPM3DecoderLayer(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - embedding = self.embed_tokens(input_ids) - return embedding * self.config.scale_emb - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - if inputs_embeds is not None: - hidden_states = inputs_embeds - else: - hidden_states = self.get_input_embeddings(input_ids) - residual = None - - for i in range(len(self.layers)): - layer = self.layers[i] - hidden_states, residual = layer( - positions, - hidden_states, - kv_caches[i], - attn_metadata, - residual, - ) - hidden_states = self.norm(hidden_states) - return hidden_states - -class MiniCPM3ForCausalLM(nn.Module): - packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], - } - - # LoRA specific attributes - supported_lora_modules = [ - "qkv_proj", - "o_proj", - "gate_up_proj", - "down_proj", - "embed_tokens", - "lm_head", - ] - embedding_modules = { - "embed_tokens": "input_embeddings", - "lm_head": "output_embeddings", - } - embedding_padding_modules = ["lm_head"] +class MiniCPM3ForCausalLM(MiniCPMForCausalLM): def __init__( self, @@ -476,17 +267,17 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: - super().__init__() + nn.Module.__init__(self) self.config = config self.lora_config = lora_config self.num_experts = getattr(self.config, "num_experts", 0) self.quant_config = quant_config - self.model = MiniCPMModel(config, - cache_config, - quant_config, - lora_config=lora_config) + self.model = MiniCPM3Model(config, + cache_config, + quant_config, + lora_config=lora_config) unpadded_vocab_size = config.vocab_size if lora_config: unpadded_vocab_size += lora_config.lora_extra_vocab_size @@ -506,99 +297,3 @@ def __init__( self.logits_processor = LogitsProcessor(unpadded_vocab_size, config.vocab_size) self.sampler = Sampler() - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) - return hidden_states - - def compute_logits( - self, - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[torch.Tensor]: - hidden_states = hidden_states / self.scale_width - if self.config.tie_word_embeddings: - lm_head = self.model.embed_tokens - else: - lm_head = self.lm_head - logits = self.logits_processor(lm_head, hidden_states, - sampling_metadata) - return logits - - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - expert_params_mapping = [ - # (param_name, weight_name, expert_id) - ("ws" if weight_name in ["w1", "w3"] else "w2s", - f"experts.{expert_id}.{weight_name}.weight", expert_id) - for expert_id in range(self.num_experts) - for weight_name in ["w1", "w2", "w3"] - ] - params_dict = dict(self.named_parameters()) - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - if ("rotary_emb.cos_cached" in name - or "rotary_emb.sin_cached" in name): - # Models trained using ColossalAI may include these tensors in - # the checkpoint. Skip them. - continue - # With tie_word_embeddings, we can skip lm_head.weight - # The weight might appear unnecessarily in the files if the model is - # processed with quantization, LoRA, fine-tuning, etc. - if self.config.tie_word_embeddings and "lm_head.weight" in name: - continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - for param_name, weight_name, expert_id in expert_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - weight_name, - expert_id=expert_id) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) From 915381743b2a3833addd72aa89a5cb900314e596 Mon Sep 17 00:00:00 2001 From: ywfang <1157670798@qq.com> Date: Fri, 13 Sep 2024 16:29:21 +0800 Subject: [PATCH 03/11] fix format --- vllm/model_executor/models/minicpm3.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/models/minicpm3.py b/vllm/model_executor/models/minicpm3.py index 73d7556fabc..a1d15256c73 100644 --- a/vllm/model_executor/models/minicpm3.py +++ b/vllm/model_executor/models/minicpm3.py @@ -22,7 +22,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only MiniCPM3 model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, List, Optional, Tuple +from typing import Any, Dict, Optional import torch from torch import nn @@ -36,16 +36,15 @@ ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.models.minicpm import (MiniCPMMLP, MiniCPMMoE, - MiniCPMDecoderLayer, - MiniCPMModel, - MiniCPMForCausalLM) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.models.minicpm import (MiniCPMDecoderLayer, + MiniCPMForCausalLM, MiniCPMMLP, + MiniCPMModel, MiniCPMMoE) class MiniCPM3Attention(nn.Module): From 5f54971466e830f5e72a89b986b8c29ffc1cf3ad Mon Sep 17 00:00:00 2001 From: ywfang <1157670798@qq.com> Date: Fri, 13 Sep 2024 17:26:00 +0800 Subject: [PATCH 04/11] optimize code structure of minicpm series --- vllm/model_executor/models/minicpm.py | 79 ++++++++------ vllm/model_executor/models/minicpm3.py | 136 +++++-------------------- 2 files changed, 76 insertions(+), 139 deletions(-) diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index a135118bc74..963ad7553fe 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -270,38 +270,47 @@ def __init__( ) -> None: super().__init__() self.config = config + self.cache_config = cache_config + self.quant_config = quant_config self.hidden_size = config.hidden_size - rope_theta = getattr(config, "rope_theta", 10000) - rope_scaling = getattr(config, "rope_scaling", None) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) + self.rope_theta = getattr(config, "rope_theta", 10000) + self.rope_scaling = getattr(config, "rope_scaling", None) + self.max_position_embeddings = getattr(config, + "max_position_embeddings", 8192) + self._init_attn_block() + self._init_ffn_block() + + def _init_attn_block(self): + self.input_layernorm = RMSNorm(self.config.hidden_size, + eps=self.config.rms_norm_eps) self.self_attn = MiniCPMAttention( hidden_size=self.hidden_size, - num_heads=config.num_attention_heads, - num_kv_heads=config.num_key_value_heads, - rope_theta=rope_theta, - rope_scaling=rope_scaling, - max_position_embeddings=max_position_embeddings, - cache_config=cache_config, - quant_config=quant_config, + num_heads=self.config.num_attention_heads, + num_kv_heads=self.config.num_key_value_heads, + rope_theta=self.rope_theta, + rope_scaling=self.rope_scaling, + max_position_embeddings=self.max_position_embeddings, + cache_config=self.cache_config, + quant_config=self.quant_config, ) + + def _init_ffn_block(self): + self.post_attention_layernorm = RMSNorm(self.config.hidden_size, + eps=self.config.rms_norm_eps) self.num_experts = getattr(self.config, "num_experts", 0) if self.num_experts == 0: self.mlp = MiniCPMMLP( hidden_size=self.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - quant_config=quant_config, + intermediate_size=self.config.intermediate_size, + hidden_act=self.config.hidden_act, + quant_config=self.quant_config, ) else: - self.mlp = MiniCPMMoE(num_experts=config.num_experts, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.mlp = MiniCPMMoE( + num_experts=self.config.num_experts, + top_k=self.config.num_experts_per_tok, + hidden_size=self.config.hidden_size, + intermediate_size=self.config.intermediate_size) def forward( self, @@ -344,6 +353,8 @@ def __init__( ) -> None: super().__init__() self.config = config + self.cache_config = cache_config + self.quant_config = quant_config self.padding_idx = config.pad_token_id lora_vocab = (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) if lora_config else 0 @@ -354,11 +365,15 @@ def __init__( config.hidden_size, org_num_embeddings=config.vocab_size, ) + self._init_layers() + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def _init_layers(self): self.layers = nn.ModuleList([ - MiniCPMDecoderLayer(config, cache_config, quant_config) - for _ in range(config.num_hidden_layers) + MiniCPMDecoderLayer(self.config, self.cache_config, + self.quant_config) + for _ in range(self.config.num_hidden_layers) ]) - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: embedding = self.embed_tokens(input_ids) @@ -431,13 +446,11 @@ def __init__( self.config = config self.lora_config = lora_config + self.cache_config = cache_config + self.quant_config = quant_config self.num_experts = getattr(self.config, "num_experts", 0) - self.quant_config = quant_config - self.model = MiniCPMModel(config, - cache_config, - quant_config, - lora_config=lora_config) + self._init_model() unpadded_vocab_size = config.vocab_size if lora_config: unpadded_vocab_size += lora_config.lora_extra_vocab_size @@ -458,6 +471,12 @@ def __init__( config.vocab_size) self.sampler = Sampler() + def _init_model(self): + self.model = MiniCPMModel(config=self.config, + cache_config=self.cache_config, + quant_config=self.quant_config, + lora_config=self.lora_config) + def forward( self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/minicpm3.py b/vllm/model_executor/models/minicpm3.py index a1d15256c73..74feaf45a2f 100644 --- a/vllm/model_executor/models/minicpm3.py +++ b/vllm/model_executor/models/minicpm3.py @@ -26,25 +26,20 @@ import torch from torch import nn -from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata -from vllm.config import CacheConfig, LoRAConfig +from vllm.config import CacheConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, ReplicatedLinear, RowParallelLinear) -from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler -from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.models.minicpm import (MiniCPMDecoderLayer, - MiniCPMForCausalLM, MiniCPMMLP, - MiniCPMModel, MiniCPMMoE) + MiniCPMForCausalLM, + MiniCPMModel) class MiniCPM3Attention(nn.Module): @@ -182,117 +177,40 @@ def forward( class MiniCPM3DecoderLayer(MiniCPMDecoderLayer): - def __init__( - self, - config, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - ) -> None: - nn.Module.__init__(self) - self.config = config - self.hidden_size = config.hidden_size - rope_theta = getattr(config, "rope_theta", 10000) - rope_scaling = getattr(config, "rope_scaling", None) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) + def _init_attn_block(self): + self.input_layernorm = RMSNorm(self.config.hidden_size, + eps=self.config.rms_norm_eps) self.self_attn = MiniCPM3Attention( - config=config, + config=self.config, hidden_size=self.hidden_size, - num_heads=config.num_attention_heads, - qk_nope_head_dim=config.qk_nope_head_dim, - qk_rope_head_dim=config.qk_rope_head_dim, - v_head_dim=config.v_head_dim, - q_lora_rank=config.q_lora_rank, - kv_lora_rank=config.kv_lora_rank, - rope_theta=rope_theta, - rope_scaling=rope_scaling, - max_position_embeddings=max_position_embeddings, - cache_config=cache_config, - quant_config=quant_config, + num_heads=self.config.num_attention_heads, + qk_nope_head_dim=self.config.qk_nope_head_dim, + qk_rope_head_dim=self.config.qk_rope_head_dim, + v_head_dim=self.config.v_head_dim, + q_lora_rank=self.config.q_lora_rank, + kv_lora_rank=self.config.kv_lora_rank, + rope_theta=self.rope_theta, + rope_scaling=self.rope_scaling, + max_position_embeddings=self.max_position_embeddings, + cache_config=self.cache_config, + quant_config=self.quant_config, ) - self.num_experts = getattr(self.config, "num_experts", 0) - if self.num_experts == 0: - self.mlp = MiniCPMMLP( - hidden_size=self.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - quant_config=quant_config, - ) - else: - self.mlp = MiniCPMMoE(num_experts=config.num_experts, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) class MiniCPM3Model(MiniCPMModel): - def __init__( - self, - config, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, - ) -> None: - nn.Module.__init__(self) - self.config = config - self.padding_idx = config.pad_token_id - lora_vocab = (lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0 - self.vocab_size = config.vocab_size + lora_vocab - self.org_vocab_size = config.vocab_size - self.embed_tokens = VocabParallelEmbedding( - self.vocab_size, - config.hidden_size, - org_num_embeddings=config.vocab_size, - ) + def _init_layers(self): self.layers = nn.ModuleList([ - MiniCPM3DecoderLayer(config, cache_config, quant_config) - for _ in range(config.num_hidden_layers) + MiniCPM3DecoderLayer(self.config, self.cache_config, + self.quant_config) + for _ in range(self.config.num_hidden_layers) ]) - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) class MiniCPM3ForCausalLM(MiniCPMForCausalLM): - def __init__( - self, - config: PretrainedConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, - ) -> None: - nn.Module.__init__(self) - - self.config = config - self.lora_config = lora_config - - self.num_experts = getattr(self.config, "num_experts", 0) - self.quant_config = quant_config - self.model = MiniCPM3Model(config, - cache_config, - quant_config, - lora_config=lora_config) - unpadded_vocab_size = config.vocab_size - if lora_config: - unpadded_vocab_size += lora_config.lora_extra_vocab_size - if not self.config.tie_word_embeddings: - self.lm_head = ParallelLMHead( - unpadded_vocab_size, - config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config else lora_config.lora_vocab_padding_size, - quant_config=quant_config, - ) - self.scale_width = self.config.hidden_size / self.config.dim_model_base - - self.logits_processor = LogitsProcessor(unpadded_vocab_size, - config.vocab_size) - self.sampler = Sampler() + def _init_model(self): + self.model = MiniCPM3Model(config=self.config, + cache_config=self.cache_config, + quant_config=self.quant_config, + lora_config=self.lora_config) From 927adaee473227a01ed396d77f5691cc10953fd3 Mon Sep 17 00:00:00 2001 From: ywfang <1157670798@qq.com> Date: Sat, 14 Sep 2024 16:25:41 +0800 Subject: [PATCH 05/11] * fix tupe unpack * add minicpm3 model test --- .../decoder_only/language/test_minicpm3.py | 52 +++++++++++++++++++ vllm/model_executor/models/minicpm3.py | 6 +-- 2 files changed, 55 insertions(+), 3 deletions(-) create mode 100644 tests/models/decoder_only/language/test_minicpm3.py diff --git a/tests/models/decoder_only/language/test_minicpm3.py b/tests/models/decoder_only/language/test_minicpm3.py new file mode 100644 index 00000000000..d624f1918b9 --- /dev/null +++ b/tests/models/decoder_only/language/test_minicpm3.py @@ -0,0 +1,52 @@ +"""Compare the outputs of HF and vLLM when using greedy sampling. + +This tests bigger models and use half precision. + +Run `pytest tests/models/decoder_only/language/test_minicpm3.py`. +""" +import pytest + +from ...utils import check_outputs_equal + +MODELS = [ + "openbmb/MiniCPM3-4B", +] + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("max_tokens", [32]) +def test_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, +) -> None: + with hf_runner(model, dtype=dtype) as hf_model: + hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) + + with vllm_runner(model, dtype=dtype) as vllm_model: + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) + + check_outputs_equal( + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +def test_model_print( + vllm_runner, + model: str, + dtype: str, +) -> None: + with vllm_runner(model, dtype=dtype) as vllm_model: + # This test is for verifying whether the model's extra_repr + # can be printed correctly. + print(vllm_model.model.llm_engine.model_executor.driver_worker. + model_runner.model) diff --git a/vllm/model_executor/models/minicpm3.py b/vllm/model_executor/models/minicpm3.py index 74feaf45a2f..cf62b68ac32 100644 --- a/vllm/model_executor/models/minicpm3.py +++ b/vllm/model_executor/models/minicpm3.py @@ -127,18 +127,18 @@ def forward( kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, ) -> torch.Tensor: - q = self.q_a_proj(hidden_states)[0] + q, _ = self.q_a_proj(hidden_states) q = self.q_a_layernorm(q) q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim) _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) - latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0] + latent_cache, _ = self.kv_a_proj_with_mqa(hidden_states) kv_a, _ = latent_cache.split( [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) latent_cache = latent_cache.unsqueeze(1) kv_a = self.kv_a_layernorm(kv_a.contiguous()) - kv = self.kv_b_proj(kv_a)[0] + kv, _ = self.kv_b_proj(kv_a) kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim) k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) From 6d0f2be170785bcdd647d56102c335378653d099 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 14 Sep 2024 09:57:57 +0000 Subject: [PATCH 06/11] tuple unpacking --- vllm/model_executor/models/minicpm3.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/minicpm3.py b/vllm/model_executor/models/minicpm3.py index cf62b68ac32..a048a3dba04 100644 --- a/vllm/model_executor/models/minicpm3.py +++ b/vllm/model_executor/models/minicpm3.py @@ -129,8 +129,8 @@ def forward( ) -> torch.Tensor: q, _ = self.q_a_proj(hidden_states) q = self.q_a_layernorm(q) - q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, - self.qk_head_dim) + q, _ = self.q_b_proj(q) + q = q.view(-1, self.num_local_heads, self.qk_head_dim) _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) latent_cache, _ = self.kv_a_proj_with_mqa(hidden_states) From 6fd8c636b0d702d9f873c213d118564595b87b3b Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 14 Sep 2024 09:58:01 +0000 Subject: [PATCH 07/11] Fix test dependency --- requirements-test.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements-test.txt b/requirements-test.txt index ca3bfa7aff6..16a883b81ce 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -21,6 +21,7 @@ compressed-tensors==0.4.0 # required for compressed-tensors timm # required for internvl test transformers_stream_generator # required for qwen-vl test matplotlib # required for qwen-vl test +datamodel_code_generator # required for minicpm3 test # TODO: Add this after fully implementing llava(mantis) # git+https://github.com/TIGER-AI-Lab/Mantis.git # required for llava(mantis) test From 403aef24aa0eb62beb5524b306525b5514907018 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 14 Sep 2024 10:08:54 +0000 Subject: [PATCH 08/11] Consolidate tests; use eager mode to avoid OOM --- .../decoder_only/language/test_big_models.py | 5 +- .../decoder_only/language/test_minicpm3.py | 52 ------------------- 2 files changed, 3 insertions(+), 54 deletions(-) delete mode 100644 tests/models/decoder_only/language/test_minicpm3.py diff --git a/tests/models/decoder_only/language/test_big_models.py b/tests/models/decoder_only/language/test_big_models.py index c5783fe19dd..c5c10eeecde 100644 --- a/tests/models/decoder_only/language/test_big_models.py +++ b/tests/models/decoder_only/language/test_big_models.py @@ -17,6 +17,7 @@ "EleutherAI/gpt-j-6b", # "mosaicml/mpt-7b", # Broken # "Qwen/Qwen1.5-0.5B" # Broken, + "openbmb/MiniCPM3-4B", ] #TODO: remove this after CPU float16 support ready @@ -39,7 +40,7 @@ def test_models( with hf_runner(model, dtype=dtype) as hf_model: hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) - with vllm_runner(model, dtype=dtype) as vllm_model: + with vllm_runner(model, dtype=dtype, enforce_eager=True) as vllm_model: vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) check_outputs_equal( @@ -57,7 +58,7 @@ def test_model_print( model: str, dtype: str, ) -> None: - with vllm_runner(model, dtype=dtype) as vllm_model: + with vllm_runner(model, dtype=dtype, enforce_eager=True) as vllm_model: # This test is for verifying whether the model's extra_repr # can be printed correctly. print(vllm_model.model.llm_engine.model_executor.driver_worker. diff --git a/tests/models/decoder_only/language/test_minicpm3.py b/tests/models/decoder_only/language/test_minicpm3.py deleted file mode 100644 index d624f1918b9..00000000000 --- a/tests/models/decoder_only/language/test_minicpm3.py +++ /dev/null @@ -1,52 +0,0 @@ -"""Compare the outputs of HF and vLLM when using greedy sampling. - -This tests bigger models and use half precision. - -Run `pytest tests/models/decoder_only/language/test_minicpm3.py`. -""" -import pytest - -from ...utils import check_outputs_equal - -MODELS = [ - "openbmb/MiniCPM3-4B", -] - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["bfloat16"]) -@pytest.mark.parametrize("max_tokens", [32]) -def test_models( - hf_runner, - vllm_runner, - example_prompts, - model: str, - dtype: str, - max_tokens: int, -) -> None: - with hf_runner(model, dtype=dtype) as hf_model: - hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) - - with vllm_runner(model, dtype=dtype) as vllm_model: - vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) - - check_outputs_equal( - outputs_0_lst=hf_outputs, - outputs_1_lst=vllm_outputs, - name_0="hf", - name_1="vllm", - ) - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["bfloat16"]) -def test_model_print( - vllm_runner, - model: str, - dtype: str, -) -> None: - with vllm_runner(model, dtype=dtype) as vllm_model: - # This test is for verifying whether the model's extra_repr - # can be printed correctly. - print(vllm_model.model.llm_engine.model_executor.driver_worker. - model_runner.model) From e64b8ece5075adea6b84682562aa057454c30d14 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 14 Sep 2024 10:39:12 +0000 Subject: [PATCH 09/11] Add missing dependency --- .buildkite/run-cpu-test.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.buildkite/run-cpu-test.sh b/.buildkite/run-cpu-test.sh index f4ead8d2777..73ce82c5857 100644 --- a/.buildkite/run-cpu-test.sh +++ b/.buildkite/run-cpu-test.sh @@ -22,7 +22,7 @@ docker exec cpu-test-avx2 bash -c "python3 examples/offline_inference.py" # Run basic model test docker exec cpu-test bash -c " - pip install pytest matplotlib einops transformers_stream_generator + pip install pytest matplotlib einops transformers_stream_generator datamodel_code_generator pytest -v -s tests/models/decoder_only/language \ --ignore=tests/models/test_fp8.py \ --ignore=tests/models/decoder_only/language/test_jamba.py \ From 13ba77ba62e1a6293c2bc5ecea9a4f7faa56d2bd Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 14 Sep 2024 12:28:39 +0000 Subject: [PATCH 10/11] Skip minicpm3 in cpu; update docs --- docs/source/models/supported_models.rst | 2 +- tests/models/decoder_only/language/test_big_models.py | 10 ++++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 11c122f839b..3dcc2428037 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -109,7 +109,7 @@ Decoder-only Language Models - * - :code:`MiniCPM3ForCausalLM` - MiniCPM3 - - :code:`openbmb/MiniCPM3-4B` + - :code:`openbmb/MiniCPM3-4B`, etc. - * - :code:`MistralForCausalLM` - Mistral, Mistral-Instruct diff --git a/tests/models/decoder_only/language/test_big_models.py b/tests/models/decoder_only/language/test_big_models.py index c5c10eeecde..5168210f150 100644 --- a/tests/models/decoder_only/language/test_big_models.py +++ b/tests/models/decoder_only/language/test_big_models.py @@ -5,7 +5,8 @@ Run `pytest tests/models/test_big_models.py`. """ import pytest -import torch + +from vllm.platforms import current_platform from ...utils import check_outputs_equal @@ -21,9 +22,7 @@ ] #TODO: remove this after CPU float16 support ready -target_dtype = "float" -if torch.cuda.is_available(): - target_dtype = "half" +target_dtype = "float" if current_platform.is_cpu() else "half" @pytest.mark.parametrize("model", MODELS) @@ -37,6 +36,9 @@ def test_models( dtype: str, max_tokens: int, ) -> None: + if model.startswith("openbmb/MiniCPM3") and current_platform.is_cpu(): + pytest.skip("MiniCPM requires fused_moe which is not supported by CPU") + with hf_runner(model, dtype=dtype) as hf_model: hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) From 2b14b04afc61bcc0e6b26b4f083b8e0a33318e36 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 14 Sep 2024 12:52:08 +0000 Subject: [PATCH 11/11] Skip for real --- tests/models/decoder_only/language/test_big_models.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/models/decoder_only/language/test_big_models.py b/tests/models/decoder_only/language/test_big_models.py index 5168210f150..fcc15863974 100644 --- a/tests/models/decoder_only/language/test_big_models.py +++ b/tests/models/decoder_only/language/test_big_models.py @@ -18,9 +18,12 @@ "EleutherAI/gpt-j-6b", # "mosaicml/mpt-7b", # Broken # "Qwen/Qwen1.5-0.5B" # Broken, - "openbmb/MiniCPM3-4B", ] +if not current_platform.is_cpu(): + # MiniCPM requires fused_moe which is not supported by CPU + MODELS.append("openbmb/MiniCPM3-4B") + #TODO: remove this after CPU float16 support ready target_dtype = "float" if current_platform.is_cpu() else "half" @@ -36,9 +39,6 @@ def test_models( dtype: str, max_tokens: int, ) -> None: - if model.startswith("openbmb/MiniCPM3") and current_platform.is_cpu(): - pytest.skip("MiniCPM requires fused_moe which is not supported by CPU") - with hf_runner(model, dtype=dtype) as hf_model: hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)