 from vllm.transformers_utils.configs.dbrx import DbrxConfig
 
 from .interfaces import SupportsPP
-from .utils import (is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
 
@@ -319,6 +319,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
 
+        self.quant_config = quant_config
         self.wte = VocabParallelEmbedding(
             config.vocab_size,
             config.d_model,
@@ -364,6 +365,55 @@ def forward(
         hidden_states = self.norm_f(hidden_states)
         return hidden_states
 
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        expert_params_mapping = [(
+            "w13" if weight_name in ["w1", "v1"] else "w2",
+            f"mlp.{weight_name}",
+        ) for weight_name in ["w1", "v1", "w2"]]
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        loaded_params: set[str] = set()
+
+        for name, loaded_weight in weights:
+            if (self.quant_config is not None and
+                (scale_name := self.quant_config.get_cache_scale(name))):
+                # Loading kv cache quantization scales
+                param = params_dict[scale_name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
+                                 loaded_weight[0])
+                weight_loader(param, loaded_weight)
+                loaded_params.add(scale_name)
+                continue
+
+            if name.endswith(("w1", "w2", "v1")):
+                name = name + "_weight"
+            for param_name, weight_name in expert_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, weight_name, name)
+                break
+
+            else:
+                if is_pp_missing_parameter(name, self):
+                    continue
+                # Remapping the name of FP8 kv-scale.
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+                loaded_params.add(name)
+        return loaded_params
+
 
 class DbrxForCausalLM(nn.Module, SupportsPP):
 
@@ -417,49 +467,5 @@ def compute_logits(
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        expert_params_mapping = [(
-            "w13" if weight_name in ["w1", "v1"] else "w2",
-            f"mlp.{weight_name}",
-        ) for weight_name in ["w1", "v1", "w2"]]
-        params_dict = dict(self.named_parameters(remove_duplicate=False))
-        loaded_params: set[str] = set()
-
-        for name, loaded_weight in weights:
-            if (self.quant_config is not None and
-                (scale_name := self.quant_config.get_cache_scale(name))):
-                # Loading kv cache quantization scales
-                param = params_dict[scale_name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
-                                 loaded_weight[0])
-                weight_loader(param, loaded_weight)
-                loaded_params.add(scale_name)
-                continue
-
-            if name.endswith(("w1", "w2", "v1")):
-                name = name + "_weight"
-            for param_name, weight_name in expert_params_mapping:
-                if weight_name not in name:
-                    continue
-                name = name.replace(weight_name, param_name)
-                if is_pp_missing_parameter(name, self):
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, weight_name, name)
-                break
-
-            else:
-                if is_pp_missing_parameter(name, self):
-                    continue
-                # Remapping the name of FP8 kv-scale.
-                name = maybe_remap_kv_scale_name(name, params_dict)
-                if name is None:
-                    continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
-                loaded_params.add(name)
-        return loaded_params
+        loader = AutoWeightsLoader(self)
+        return loader.load_weights(weights)
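For context on the change above: the DBRX-specific checkpoint handling (expert w1/v1/w2 remapping and KV-cache scale loading) moves from DbrxForCausalLM into DbrxModel.load_weights, and the top-level module now delegates to AutoWeightsLoader, which routes each weight to the submodule that owns it. The snippet below is a minimal, self-contained sketch of that delegation pattern; InnerModel and OuterModel are illustrative stand-ins, not vLLM code, and the prefix routing is inlined where AutoWeightsLoader would normally do it generically.

# Sketch of the delegation pattern adopted above (stand-in classes, not vLLM code).
from collections.abc import Iterable

import torch
from torch import nn


class InnerModel(nn.Module):
    """Stand-in for DbrxModel: owns the model-specific loading logic."""

    def __init__(self) -> None:
        super().__init__()
        self.wte = nn.Embedding(8, 4)

    def load_weights(
            self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            # Model-specific remapping (e.g. w1/v1 -> w13) would happen here.
            params_dict[name].data.copy_(loaded_weight)
            loaded_params.add(name)
        return loaded_params


class OuterModel(nn.Module):
    """Stand-in for DbrxForCausalLM: strips its prefix and delegates."""

    def __init__(self) -> None:
        super().__init__()
        self.transformer = InnerModel()

    def load_weights(
            self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        # AutoWeightsLoader handles this prefix bookkeeping generically; it is
        # inlined here only to keep the sketch dependency-free.
        stripped = ((name.removeprefix("transformer."), tensor)
                    for name, tensor in weights)
        return {
            f"transformer.{name}"
            for name in self.transformer.load_weights(stripped)
        }


if __name__ == "__main__":
    model = OuterModel()
    checkpoint = [("transformer.wte.weight", torch.randn(8, 4))]
    assert model.load_weights(checkpoint) == {"transformer.wte.weight"}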