
Commit fdc0b8b

use AutoWeightsLoader for dbrx
Signed-off-by: learner0810 <zhongjun.li@daocloud.io>
1 parent a8f5aec commit fdc0b8b

File tree

1 file changed: +53 −47 lines


vllm/model_executor/models/dbrx.py

Lines changed: 53 additions & 47 deletions
@@ -26,7 +26,7 @@
 from vllm.transformers_utils.configs.dbrx import DbrxConfig

 from .interfaces import SupportsPP
-from .utils import (is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)

@@ -319,6 +319,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config

+        self.quant_config = quant_config
         self.wte = VocabParallelEmbedding(
             config.vocab_size,
             config.d_model,
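
Storing quant_config on DbrxModel here is what lets the load_weights method added in the next hunk resolve kv-cache quantization scales via self.quant_config.get_cache_scale(name). A minimal standalone sketch of that check, with a toy config class standing in for vLLM's quantization config (class and checkpoint names are illustrative, not real DBRX keys):

# Minimal sketch of the kv-cache scale check performed by the relocated
# load_weights; ToyQuantConfig is a hypothetical stand-in, not vLLM code.
from typing import Optional


class ToyQuantConfig:

    def get_cache_scale(self, name: str) -> Optional[str]:
        # Hypothetical mapping: a checkpoint ".k_scale" entry resolves to the
        # attention module's kv-cache scale parameter.
        if name.endswith(".k_scale"):
            return name.replace(".k_scale", ".attn.k_scale")
        return None


quant_config: Optional[ToyQuantConfig] = ToyQuantConfig()
name = "blocks.0.norm_attn_norm.k_scale"  # illustrative checkpoint key

if quant_config is not None and (scale_name := quant_config.get_cache_scale(name)):
    print("route to kv-cache scale parameter:", scale_name)
else:
    print("fall through to regular weight loading")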
@@ -364,6 +365,55 @@ def forward(
         hidden_states = self.norm_f(hidden_states)
         return hidden_states

+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        expert_params_mapping = [(
+            "w13" if weight_name in ["w1", "v1"] else "w2",
+            f"mlp.{weight_name}",
+        ) for weight_name in ["w1", "v1", "w2"]]
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        loaded_params: set[str] = set()
+
+        for name, loaded_weight in weights:
+            if (self.quant_config is not None and
+                (scale_name := self.quant_config.get_cache_scale(name))):
+                # Loading kv cache quantization scales
+                param = params_dict[scale_name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
+                                 loaded_weight[0])
+                weight_loader(param, loaded_weight)
+                loaded_params.add(scale_name)
+                continue
+
+            if name.endswith(("w1", "w2", "v1")):
+                name = name + "_weight"
+            for param_name, weight_name in expert_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, weight_name, name)
+                break
+
+            else:
+                if is_pp_missing_parameter(name, self):
+                    continue
+                # Remapping the name of FP8 kv-scale.
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+

 class DbrxForCausalLM(nn.Module, SupportsPP):
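
The relocated loader above relies on Python's for/else: the else branch of the inner loop runs only when the loop finishes without hitting break, i.e. when none of the expert-weight patterns matched the checkpoint name. A standalone illustration of that control flow, using toy names rather than real DBRX checkpoint keys:

# Standalone illustration of the for/else idiom used in load_weights above.
# The else clause runs only when the inner loop completes without break.
expert_params_mapping = [("w13", "mlp.w1"), ("w13", "mlp.v1"), ("w2", "mlp.w2")]

names = [
    "blocks.0.ffn.experts.mlp.w1_weight",        # toy expert-weight key
    "blocks.0.norm_attn_norm.attn.Wqkv.weight",  # toy non-expert key
]
for name in names:
    for param_name, weight_name in expert_params_mapping:
        if weight_name not in name:
            continue
        print(f"{name}: expert branch (maps to {param_name})")
        break
    else:
        print(f"{name}: generic branch (default weight loading)")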

@@ -417,49 +467,5 @@ def compute_logits(

     def load_weights(self, weights: Iterable[tuple[str,
                                               torch.Tensor]]) -> set[str]:
-        expert_params_mapping = [(
-            "w13" if weight_name in ["w1", "v1"] else "w2",
-            f"mlp.{weight_name}",
-        ) for weight_name in ["w1", "v1", "w2"]]
-        params_dict = dict(self.named_parameters(remove_duplicate=False))
-        loaded_params: set[str] = set()
-
-        for name, loaded_weight in weights:
-            if (self.quant_config is not None and
-                (scale_name := self.quant_config.get_cache_scale(name))):
-                # Loading kv cache quantization scales
-                param = params_dict[scale_name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
-                                 loaded_weight[0])
-                weight_loader(param, loaded_weight)
-                loaded_params.add(scale_name)
-                continue
-
-            if name.endswith(("w1", "w2", "v1")):
-                name = name + "_weight"
-            for param_name, weight_name in expert_params_mapping:
-                if weight_name not in name:
-                    continue
-                name = name.replace(weight_name, param_name)
-                if is_pp_missing_parameter(name, self):
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, weight_name, name)
-                break
-
-            else:
-                if is_pp_missing_parameter(name, self):
-                    continue
-                # Remapping the name of FP8 kv-scale.
-                name = maybe_remap_kv_scale_name(name, params_dict)
-                if name is None:
-                    continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
+        loader = AutoWeightsLoader(self)
+        return loader.load_weights(weights)
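
With this change DbrxForCausalLM no longer walks the checkpoint itself: AutoWeightsLoader routes each weight to the submodule that owns it, which is why the detailed MoE and kv-scale handling now lives on DbrxModel.load_weights. A rough sketch of that delegation idea is below; it is simplified and is not the actual AutoWeightsLoader implementation in vllm/model_executor/models/utils.py, which additionally handles things like skipped prefixes and parameters not covered by a child's own load_weights.

# Simplified, illustrative sketch of the delegation idea behind
# AutoWeightsLoader; not the real vLLM implementation.
from typing import Iterable

import torch
from torch import nn


class ToyAutoWeightsLoader:

    def __init__(self, module: nn.Module):
        self.module = module

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        loaded: set[str] = set()
        grouped: dict[str, list[tuple[str, torch.Tensor]]] = {}
        # Group checkpoint entries by their top-level submodule name,
        # e.g. "transformer.blocks.0...." -> child "transformer".
        for name, tensor in weights:
            child_name, _, rest = name.partition(".")
            grouped.setdefault(child_name, []).append((rest, tensor))
        for child_name, child_weights in grouped.items():
            child = getattr(self.module, child_name)
            if hasattr(child, "load_weights"):
                # Delegate to the submodule's own loader (as DbrxModel now
                # defines one) and record the fully qualified names it loaded.
                for sub_name in child.load_weights(child_weights):
                    loaded.add(f"{child_name}.{sub_name}")
        return loaded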
