 from vllm.sequence import IntermediateTensors

 from .interfaces import SupportsPP, SupportsQuant, SupportsV0Only
-from .utils import (is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)

@@ -278,6 +278,41 @@ def forward(
         hidden_states = self.ln_f(hidden_states)
         return hidden_states

+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        loaded_params: set[str] = set()
+        for name, loaded_weight in weights:
+            if name == "lm_head.weight":
+                continue
+            if not name.startswith("transformer."):
+                name = "transformer." + name
+            if is_pp_missing_parameter(name, self):
+                continue
+            param = params_dict[name]
+
+            if "query_key_value" in name:
+                # NOTE: BLOOM's fused QKV's output_dim has the shape of
+                # (num_heads * 3 * head_size), while the
+                # required shape is (3 * num_heads * head_size).
+                # Thus, we need weight conversion.
+                output_dim = getattr(param, "output_dim", None)
+                num_heads = self.config.num_attention_heads
+                if output_dim is not None:
+                    loaded_weight_shape = loaded_weight.shape
+                    loaded_weight = loaded_weight.view(
+                        loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
+                        loaded_weight_shape[output_dim + 1:])
+                    loaded_weight = loaded_weight.transpose(
+                        output_dim, output_dim + 1)
+                    loaded_weight = loaded_weight.reshape(loaded_weight_shape)
+
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+

 class BloomForCausalLM(nn.Module, SupportsPP, SupportsV0Only, SupportsQuant):

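Note: the QKV weight conversion above can be hard to picture from the comment alone. The following standalone sketch mirrors the same view/transpose/reshape sequence on a dummy tensor; the sizes are made up for illustration, and `output_dim` is assumed to be 0 (the usual case for a column-parallel weight of shape `[output_size, input_size]`).

```python
# Illustrative only: dummy sizes, output_dim assumed to be 0.
import torch

num_heads, head_size, hidden_size = 4, 8, 32
output_dim = 0

# Checkpoint layout: heads are the outer grouping -> (num_heads, 3, head_size).
loaded_weight = torch.randn(num_heads * 3 * head_size, hidden_size)

shape = loaded_weight.shape
converted = loaded_weight.view(shape[:output_dim] + (num_heads, 3, -1) +
                               shape[output_dim + 1:])       # (4, 3, 8, 32)
converted = converted.transpose(output_dim, output_dim + 1)  # (3, 4, 8, 32)
converted = converted.reshape(shape)                         # back to (96, 32)

# Same overall shape, but rows are now grouped as (3, num_heads, head_size),
# i.e. all Q rows first, then all K rows, then all V rows.
assert converted.shape == loaded_weight.shape
```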
@@ -325,35 +360,5 @@ def compute_logits(

     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        params_dict = dict(self.named_parameters(remove_duplicate=False))
-        loaded_params: set[str] = set()
-        for name, loaded_weight in weights:
-            if name == "lm_head.weight":
-                continue
-            if not name.startswith("transformer."):
-                name = "transformer." + name
-            if is_pp_missing_parameter(name, self):
-                continue
-            param = params_dict[name]
-
-            if "query_key_value" in name:
-                # NOTE: BLOOM's fused QKV's output_dim has the shape of
-                # (num_heads * 3 * head_size), while the
-                # required shape is (3 * num_heads * head_size).
-                # Thus, we need weight conversion.
-                output_dim = getattr(param, "output_dim", None)
-                num_heads = self.config.num_attention_heads
-                if output_dim is not None:
-                    loaded_weight_shape = loaded_weight.shape
-                    loaded_weight = loaded_weight.view(
-                        loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
-                        loaded_weight_shape[output_dim + 1:])
-                    loaded_weight = loaded_weight.transpose(
-                        output_dim, output_dim + 1)
-                    loaded_weight = loaded_weight.reshape(loaded_weight_shape)
-
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
+        loader = AutoWeightsLoader(self, skip_prefixes=["lm_head.weight"])
+        return loader.load_weights(weights)
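For context on the replacement: `AutoWeightsLoader` walks the model's submodules and, roughly speaking, routes each checkpoint tensor by the first component of its name, dropping anything matching `skip_prefixes` (here `lm_head.weight`, which BLOOM ties to the word embeddings) and deferring to a submodule's own `load_weights` when it defines one, as `BloomModel` now does. The sketch below is a simplified conceptual stand-in for that dispatch, not vLLM's actual implementation.

```python
# Conceptual stand-in only -- not vLLM's AutoWeightsLoader. It just shows the
# routing idea: skip listed prefixes, group the rest by top-level module name,
# and hand each group to that submodule's load_weights().
from collections import defaultdict
from typing import Iterable

import torch
import torch.nn as nn


def route_weights(model: nn.Module,
                  weights: Iterable[tuple[str, torch.Tensor]],
                  skip_prefixes: list[str]) -> set[str]:
    children = dict(model.named_children())
    grouped: dict[str, list[tuple[str, torch.Tensor]]] = defaultdict(list)
    for name, tensor in weights:
        if any(name.startswith(p) for p in skip_prefixes):
            continue
        child, _, rest = name.partition(".")
        grouped[child].append((rest, tensor))

    loaded: set[str] = set()
    for child, child_weights in grouped.items():
        if child in children and hasattr(children[child], "load_weights"):
            loaded |= {f"{child}.{n}"
                       for n in children[child].load_weights(child_weights)}
    return loaded
```

The net effect of the change is that the per-model boilerplate shrinks to the two-line `AutoWeightsLoader` call, while model-specific quirks such as the fused-QKV reordering stay local to the module that owns the affected parameters.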