From cbb706410b2e2d71056884d3bf3a890063414772 Mon Sep 17 00:00:00 2001 From: calvin chen <120380290@qq.com> Date: Sat, 17 May 2025 16:14:06 +0800 Subject: [PATCH] [Model] use AutoWeightsLoader for bloom Signed-off-by: calvin chen <120380290@qq.com> --- vllm/model_executor/models/bloom.py | 79 +++++++++++++++++------------ 1 file changed, 46 insertions(+), 33 deletions(-) diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index eb1085d6b40..10424e218fb 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -43,7 +43,7 @@ from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP, SupportsQuant, SupportsV0Only -from .utils import (is_pp_missing_parameter, +from .utils import (AutoWeightsLoader, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -229,6 +229,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config + self.config = config self.embed_dim = config.hidden_size @@ -278,6 +279,38 @@ def forward( hidden_states = self.ln_f(hidden_states) return hidden_states + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + + if "query_key_value" in name: + # NOTE: BLOOM's fused QKV's output_dim has the shape of + # (num_heads * 3 * head_size), while the + # required shape is (3 * num_heads * head_size). + # Thus, we need weight conversion. + output_dim = getattr(param, "output_dim", None) + num_heads = self.config.num_attention_heads + if output_dim is not None: + loaded_weight_shape = loaded_weight.shape + loaded_weight = loaded_weight.view( + loaded_weight_shape[:output_dim] + (num_heads, 3, -1) + + loaded_weight_shape[output_dim + 1:]) + loaded_weight = loaded_weight.transpose( + output_dim, output_dim + 1) + loaded_weight = loaded_weight.reshape(loaded_weight_shape) + + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + + return loaded_params + class BloomForCausalLM(nn.Module, SupportsPP, SupportsV0Only, SupportsQuant): @@ -325,35 +358,15 @@ def compute_logits( def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - params_dict = dict(self.named_parameters(remove_duplicate=False)) - loaded_params: set[str] = set() - for name, loaded_weight in weights: - if name == "lm_head.weight": - continue - if not name.startswith("transformer."): - name = "transformer." + name - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - - if "query_key_value" in name: - # NOTE: BLOOM's fused QKV's output_dim has the shape of - # (num_heads * 3 * head_size), while the - # required shape is (3 * num_heads * head_size). - # Thus, we need weight conversion. - output_dim = getattr(param, "output_dim", None) - num_heads = self.config.num_attention_heads - if output_dim is not None: - loaded_weight_shape = loaded_weight.shape - loaded_weight = loaded_weight.view( - loaded_weight_shape[:output_dim] + (num_heads, 3, -1) + - loaded_weight_shape[output_dim + 1:]) - loaded_weight = loaded_weight.transpose( - output_dim, output_dim + 1) - loaded_weight = loaded_weight.reshape(loaded_weight_shape) - - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params + loader = AutoWeightsLoader(self, skip_prefixes=["lm_head.weight"]) + weights = _add_transformer_prefix(weights) + return loader.load_weights(weights) + + +def _add_transformer_prefix( + weights: Iterable[tuple[str, torch.Tensor]] +) -> Iterable[tuple[str, torch.Tensor]]: + for name, tensor in weights: + if not name.startswith('transformer.'): + name = 'transformer.' + name + yield name, tensor