 from vllm.transformers_utils.configs import NemotronConfig

 from .interfaces import SupportsLoRA, SupportsPP
-from .utils import (PPMissingLayer, is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)

@@ -300,6 +300,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         lora_config = vllm_config.lora_config

         self.config = config
+        self.quant_config = quant_config
         lora_vocab = (lora_config.lora_extra_vocab_size *
                       (lora_config.max_loras or 1)) if lora_config else 0
         self.vocab_size = config.vocab_size + lora_vocab
@@ -362,6 +363,63 @@ def forward(
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states

+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+        ]
+        params_dict = dict(self.named_parameters())
+        loaded_params: set[str] = set()
+        for name, loaded_weight in weights:
+            if (self.quant_config is not None and
+                (scale_name := self.quant_config.get_cache_scale(name))):
+                # Loading kv cache quantization scales
+                param = params_dict[scale_name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
+                                 loaded_weight[0])
+                weight_loader(param, loaded_weight)
+                loaded_params.add(scale_name)
+                continue
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+
+                if is_pp_missing_parameter(name, self):
+                    continue
+
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                # Remapping the name of FP8 kv-scale.
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
+
+                if is_pp_missing_parameter(name, self):
+                    continue
+
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+

 class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     packed_modules_mapping = {
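
For context on the stacked_params_mapping loop added above: Nemotron checkpoints store separate q_proj/k_proj/v_proj tensors, while the model keeps a single fused qkv_proj parameter, so checkpoint names are rewritten and a shard id is passed to the parameter's weight_loader. Below is a minimal standalone sketch of just that name-remapping step, assuming nothing from vLLM; the helper remap_checkpoint_name is illustrative only, not part of this PR or the vLLM API.

# Illustrative sketch only -- mirrors the stacked_params_mapping idea above
# without any vLLM types. Names containing .q_proj/.k_proj/.v_proj are mapped
# onto the fused .qkv_proj parameter plus a shard id; everything else passes
# through unchanged.
from __future__ import annotations

STACKED_PARAMS_MAPPING = [
    # (param_name, shard_name, shard_id)
    (".qkv_proj", ".q_proj", "q"),
    (".qkv_proj", ".k_proj", "k"),
    (".qkv_proj", ".v_proj", "v"),
]


def remap_checkpoint_name(name: str) -> tuple[str, str | None]:
    """Return (model_param_name, shard_id); shard_id is None for unfused params."""
    for param_name, shard_name, shard_id in STACKED_PARAMS_MAPPING:
        if shard_name in name:
            return name.replace(shard_name, param_name), shard_id
    return name, None


if __name__ == "__main__":
    print(remap_checkpoint_name("model.layers.0.self_attn.k_proj.weight"))
    # ('model.layers.0.self_attn.qkv_proj.weight', 'k')
    print(remap_checkpoint_name("model.layers.0.mlp.down_proj.weight"))
    # ('model.layers.0.mlp.down_proj.weight', None)
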
@@ -444,64 +502,14 @@ def compute_logits(

     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            (".qkv_proj", ".q_proj", "q"),
-            (".qkv_proj", ".k_proj", "k"),
-            (".qkv_proj", ".v_proj", "v"),
-        ]
-        params_dict = dict(self.named_parameters())
-        loaded_params: set[str] = set()
-        for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name:
-                continue
-            if ("rotary_emb.cos_cached" in name
-                    or "rotary_emb.sin_cached" in name):
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=([
+                "rotary_emb.inv_freq",
                 # Models trained using ColossalAI may include these tensors in
                 # the checkpoint. Skip them.
-                continue
-            if (self.quant_config is not None and
-                (scale_name := self.quant_config.get_cache_scale(name))):
-                # Loading kv cache quantization scales
-                param = params_dict[scale_name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
-                                 loaded_weight[0])
-                weight_loader(param, loaded_weight)
-                loaded_params.add(scale_name)
-                continue
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
-                if weight_name not in name:
-                    continue
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-
-                if is_pp_missing_parameter(name, self):
-                    continue
-
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-
-                break
-            else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                # Remapping the name of FP8 kv-scale.
-                name = maybe_remap_kv_scale_name(name, params_dict)
-                if name is None:
-                    continue
-
-                if is_pp_missing_parameter(name, self):
-                    continue
-
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
+                "rotary_emb.cos_cached",
+                "rotary_emb.sin_cached"
+            ]),
+        )
+        return loader.load_weights(weights)
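
For readers new to AutoWeightsLoader: the replacement body above only declares which checkpoint names to ignore and delegates per-parameter dispatch (including the recursion into the NemotronModel.load_weights added earlier) to the shared loader. Below is a rough standalone sketch of what the skip_prefixes declaration buys over the hand-written continue branches this diff removes. It is purely illustrative: filter_skipped is not vLLM's AutoWeightsLoader, and it matches substrings, whereas the real loader applies its own matching rules while walking the module tree.

# Conceptual sketch only -- NOT vLLM's AutoWeightsLoader. Matching checkpoint
# entries are dropped before any weight loading happens.
from __future__ import annotations

from typing import Iterable, Iterator

SKIPPED_NAMES = (
    "rotary_emb.inv_freq",
    # Models trained using ColossalAI may include these tensors in the
    # checkpoint; they are not real parameters and can be dropped.
    "rotary_emb.cos_cached",
    "rotary_emb.sin_cached",
)


def filter_skipped(
    weights: Iterable[tuple[str, object]],
    skipped: tuple[str, ...] = SKIPPED_NAMES,
) -> Iterator[tuple[str, object]]:
    # Substring match here; the real loader resolves names differently.
    for name, tensor in weights:
        if any(fragment in name for fragment in skipped):
            continue
        yield name, tensor


if __name__ == "__main__":
    entries = [
        ("model.layers.0.self_attn.rotary_emb.inv_freq", None),
        ("model.layers.0.self_attn.qkv_proj.weight", None),
    ]
    print([name for name, _ in filter_skipped(entries)])
    # ['model.layers.0.self_attn.qkv_proj.weight']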