
Commit b3fca42

[Model] use AutoWeightsLoader for bloom
Signed-off-by: calvin chen <120380290@qq.com>
1 parent 4ee4826 commit b3fca42

File tree: 1 file changed, +38 −33 lines


vllm/model_executor/models/bloom.py

Lines changed: 38 additions & 33 deletions
@@ -43,7 +43,7 @@
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsPP, SupportsQuant, SupportsV0Only
-from .utils import (is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)

@@ -278,6 +278,41 @@ def forward(
         hidden_states = self.ln_f(hidden_states)
         return hidden_states
 
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        loaded_params: set[str] = set()
+        for name, loaded_weight in weights:
+            if name == "lm_head.weight":
+                continue
+            if not name.startswith("transformer."):
+                name = "transformer." + name
+            if is_pp_missing_parameter(name, self):
+                continue
+            param = params_dict[name]
+
+            if "query_key_value" in name:
+                # NOTE: BLOOM's fused QKV's output_dim has the shape of
+                # (num_heads * 3 * head_size), while the
+                # required shape is (3 * num_heads * head_size).
+                # Thus, we need weight conversion.
+                output_dim = getattr(param, "output_dim", None)
+                num_heads = self.config.num_attention_heads
+                if output_dim is not None:
+                    loaded_weight_shape = loaded_weight.shape
+                    loaded_weight = loaded_weight.view(
+                        loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
+                        loaded_weight_shape[output_dim + 1:])
+                    loaded_weight = loaded_weight.transpose(
+                        output_dim, output_dim + 1)
+                    loaded_weight = loaded_weight.reshape(loaded_weight_shape)
+
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
 
 class BloomForCausalLM(nn.Module, SupportsPP, SupportsV0Only, SupportsQuant):

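
As background for the NOTE in the block added above, here is a minimal, standalone sketch of the same view / transpose / reshape conversion. The tensor sizes and the output_dim value are made up for illustration; only the reshaping logic mirrors the diff.

import torch

# Hypothetical toy sizes; in the model they come from the config.
num_heads, head_size, hidden_size = 4, 8, 32
output_dim = 0  # the dimension of the weight that holds the fused QKV outputs

# Checkpoint layout along dim 0 is head-major:
# [h0_q, h0_k, h0_v, h1_q, h1_k, h1_v, ...]
loaded_weight = torch.randn(num_heads * 3 * head_size, hidden_size)

shape = loaded_weight.shape
# Split the output dim into (num_heads, 3, head_size) ...
w = loaded_weight.view(shape[:output_dim] + (num_heads, 3, -1) +
                       shape[output_dim + 1:])
# ... swap so the q/k/v index comes first: (3, num_heads, head_size) ...
w = w.transpose(output_dim, output_dim + 1)
# ... and flatten back, giving [q_h0, q_h1, ..., k_h0, ..., v_h0, ...]
converted = w.reshape(shape)
assert converted.shape == loaded_weight.shape

The converted tensor is then handed to the parameter's weight_loader exactly as before; the commit only moves where this logic lives, from BloomForCausalLM down to BloomModel.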

@@ -325,35 +360,5 @@ def compute_logits(
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        params_dict = dict(self.named_parameters(remove_duplicate=False))
-        loaded_params: set[str] = set()
-        for name, loaded_weight in weights:
-            if name == "lm_head.weight":
-                continue
-            if not name.startswith("transformer."):
-                name = "transformer." + name
-            if is_pp_missing_parameter(name, self):
-                continue
-            param = params_dict[name]
-
-            if "query_key_value" in name:
-                # NOTE: BLOOM's fused QKV's output_dim has the shape of
-                # (num_heads * 3 * head_size), while the
-                # required shape is (3 * num_heads * head_size).
-                # Thus, we need weight conversion.
-                output_dim = getattr(param, "output_dim", None)
-                num_heads = self.config.num_attention_heads
-                if output_dim is not None:
-                    loaded_weight_shape = loaded_weight.shape
-                    loaded_weight = loaded_weight.view(
-                        loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
-                        loaded_weight_shape[output_dim + 1:])
-                    loaded_weight = loaded_weight.transpose(
-                        output_dim, output_dim + 1)
-                    loaded_weight = loaded_weight.reshape(loaded_weight_shape)
-
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
+        loader = AutoWeightsLoader(self, skip_prefixes=["lm_head.weight"])
+        return loader.load_weights(weights)
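
For readers new to the helper: AutoWeightsLoader (imported in the first hunk) now handles the outer loop, skipping checkpoint entries that match skip_prefixes (the tied lm_head.weight here) and delegating the rest to the wrapped model, whose load_weights method added above still performs the BLOOM-specific QKV conversion. The class below is a deliberately naive sketch of that skip-and-delegate pattern, not vLLM's actual implementation; every name in it other than the constructor shape shown in the diff is invented for illustration.

from collections.abc import Iterable

import torch
import torch.nn as nn


class NaiveAutoWeightsLoader:
    """Toy illustration of the skip-and-delegate pattern (not vLLM's code)."""

    def __init__(self, module: nn.Module,
                 skip_prefixes: list[str] | None = None) -> None:
        self.module = module
        self.skip_prefixes = skip_prefixes or []

    def load_weights(
            self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        params = dict(self.module.named_parameters())
        loaded: set[str] = set()
        for name, tensor in weights:
            # Skip entries the model does not own, e.g. lm_head.weight when
            # the output projection is tied to the input embedding.
            if any(name.startswith(p) for p in self.skip_prefixes):
                continue
            # The real helper recurses into submodules and honors their own
            # load_weights / per-parameter weight_loader hooks; this toy
            # version just copies into a directly owned parameter.
            if name not in params:
                continue
            params[name].data.copy_(tensor)
            loaded.add(name)
        return loaded

Usage mirrors the diff: construct it around the top-level module with the prefixes to drop, then feed it the iterator of (name, tensor) pairs.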

0 commit comments