Skip to content

Commit 8ea1861

Browse files
committed
[Model] use AutoWeightsLoader for bloom
Signed-off-by: calvin chen <120380290@qq.com>
1 parent 4ee4826 commit 8ea1861

File tree

1 file changed

+37
-33
lines changed

1 file changed

+37
-33
lines changed

vllm/model_executor/models/bloom.py

Lines changed: 37 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -43,7 +43,7 @@
4343
from vllm.sequence import IntermediateTensors
4444

4545
from .interfaces import SupportsPP, SupportsQuant, SupportsV0Only
46-
from .utils import (is_pp_missing_parameter,
46+
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
4747
make_empty_intermediate_tensors_factory, make_layers,
4848
maybe_prefix)
4949

@@ -229,6 +229,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
229229
config = vllm_config.model_config.hf_config
230230
cache_config = vllm_config.cache_config
231231
quant_config = vllm_config.quant_config
232+
self.config = config
232233

233234
self.embed_dim = config.hidden_size
234235

@@ -278,6 +279,39 @@ def forward(
278279
hidden_states = self.ln_f(hidden_states)
279280
return hidden_states
280281

282+
def load_weights(self, weights: Iterable[tuple[str,
283+
torch.Tensor]]) -> set[str]:
284+
params_dict = dict(self.named_parameters(remove_duplicate=False))
285+
loaded_params: set[str] = set()
286+
for name, loaded_weight in weights:
287+
if not name.startswith("transformer."):
288+
name = "transformer." + name
289+
if is_pp_missing_parameter(name, self):
290+
continue
291+
param = params_dict[name]
292+
293+
if "query_key_value" in name:
294+
# NOTE: BLOOM's fused QKV's output_dim has the shape of
295+
# (num_heads * 3 * head_size), while the
296+
# required shape is (3 * num_heads * head_size).
297+
# Thus, we need weight conversion.
298+
output_dim = getattr(param, "output_dim", None)
299+
num_heads = self.config.num_attention_heads
300+
if output_dim is not None:
301+
loaded_weight_shape = loaded_weight.shape
302+
loaded_weight = loaded_weight.view(
303+
loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
304+
loaded_weight_shape[output_dim + 1:])
305+
loaded_weight = loaded_weight.transpose(
306+
output_dim, output_dim + 1)
307+
loaded_weight = loaded_weight.reshape(loaded_weight_shape)
308+
309+
weight_loader = getattr(param, "weight_loader",
310+
default_weight_loader)
311+
weight_loader(param, loaded_weight)
312+
loaded_params.add(name)
313+
return loaded_params
314+
281315

282316
class BloomForCausalLM(nn.Module, SupportsPP, SupportsV0Only, SupportsQuant):
283317

@@ -325,35 +359,5 @@ def compute_logits(
325359

326360
def load_weights(self, weights: Iterable[tuple[str,
327361
torch.Tensor]]) -> set[str]:
328-
params_dict = dict(self.named_parameters(remove_duplicate=False))
329-
loaded_params: set[str] = set()
330-
for name, loaded_weight in weights:
331-
if name == "lm_head.weight":
332-
continue
333-
if not name.startswith("transformer."):
334-
name = "transformer." + name
335-
if is_pp_missing_parameter(name, self):
336-
continue
337-
param = params_dict[name]
338-
339-
if "query_key_value" in name:
340-
# NOTE: BLOOM's fused QKV's output_dim has the shape of
341-
# (num_heads * 3 * head_size), while the
342-
# required shape is (3 * num_heads * head_size).
343-
# Thus, we need weight conversion.
344-
output_dim = getattr(param, "output_dim", None)
345-
num_heads = self.config.num_attention_heads
346-
if output_dim is not None:
347-
loaded_weight_shape = loaded_weight.shape
348-
loaded_weight = loaded_weight.view(
349-
loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
350-
loaded_weight_shape[output_dim + 1:])
351-
loaded_weight = loaded_weight.transpose(
352-
output_dim, output_dim + 1)
353-
loaded_weight = loaded_weight.reshape(loaded_weight_shape)
354-
355-
weight_loader = getattr(param, "weight_loader",
356-
default_weight_loader)
357-
weight_loader(param, loaded_weight)
358-
loaded_params.add(name)
359-
return loaded_params
362+
loader = AutoWeightsLoader(self, skip_prefixes=["lm_head.weight"])
363+
return loader.load_weights(weights)

0 commit comments

Comments
 (0)