
Commit 9214e60

[Model] use AutoWeightsLoader for solar (#18113)
1 parent f880d42 commit 9214e60

File tree: 5 files changed (+299, -260 lines)


vllm/model_executor/models/mixtral_quant.py

Lines changed: 53 additions & 47 deletions
@@ -51,7 +51,7 @@
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsPP
-from .utils import (is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
 
@@ -354,50 +354,6 @@ def forward(
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
-
-class MixtralForCausalLM(nn.Module, SupportsPP):
-    fall_back_to_pt_during_load = False
-
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        super().__init__()
-        config = vllm_config.model_config.hf_config
-        quant_config = vllm_config.quant_config
-        self.config = config
-        self.quant_config = quant_config
-        self.model = MixtralModel(vllm_config=vllm_config,
-                                  prefix=maybe_prefix(prefix, "model"))
-        self.lm_head = ParallelLMHead(config.vocab_size,
-                                      config.hidden_size,
-                                      quant_config=quant_config)
-        if self.config.tie_word_embeddings:
-            self.lm_head.weight = self.model.embed_tokens.weight
-        self.logits_processor = LogitsProcessor(config.vocab_size)
-        self.make_empty_intermediate_tensors = (
-            self.model.make_empty_intermediate_tensors)
-
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = self.model(input_ids, positions, intermediate_tensors,
-                                   inputs_embeds)
-        return hidden_states
-
-    def compute_logits(
-        self,
-        hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
-        return logits
-
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         stacked_params_mapping = [
@@ -410,8 +366,6 @@ def load_weights(self, weights: Iterable[tuple[str,
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
         for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name:
-                continue
             if name.endswith("scale"):
                 # Remapping the name of FP8 kv-scale.
                 name = maybe_remap_kv_scale_name(name, params_dict)
@@ -446,3 +400,55 @@ def load_weights(self, weights: Iterable[tuple[str,
                 weight_loader(param, loaded_weight)
             loaded_params.add(name)
         return loaded_params
+
+
+class MixtralForCausalLM(nn.Module, SupportsPP):
+    fall_back_to_pt_during_load = False
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+        config = vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config
+        self.config = config
+        self.quant_config = quant_config
+        self.model = MixtralModel(vllm_config=vllm_config,
+                                  prefix=maybe_prefix(prefix, "model"))
+        self.lm_head = ParallelLMHead(config.vocab_size,
+                                      config.hidden_size,
+                                      quant_config=quant_config)
+        if self.config.tie_word_embeddings:
+            self.lm_head.weight = self.model.embed_tokens.weight
+        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors)
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        hidden_states = self.model(input_ids, positions, intermediate_tensors,
+                                   inputs_embeds)
+        return hidden_states
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
+        logits = self.logits_processor(self.lm_head, hidden_states,
+                                       sampling_metadata)
+        return logits
+
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=(["rotary_emb.inv_freq"]),
+        )
+        return loader.load_weights(weights)

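The mixtral_quant.py change shows the pattern this commit applies: the name-mapping loop stays on the inner MixtralModel, while MixtralForCausalLM is moved below it and its load_weights shrinks to building an AutoWeightsLoader with skip_prefixes and delegating to it. The toy classes below are a minimal, self-contained sketch of that delegation idea under simplified assumptions (a plain substring check, mirroring the removed `if "rotary_emb.inv_freq" in name` guard); they are not vLLM's actual AutoWeightsLoader, whose matching and recursion rules are more involved.

# A minimal sketch (toy classes, not vLLM's AutoWeightsLoader) of the
# delegation pattern introduced above: the loader filters checkpoint names
# against the given skip prefixes and forwards everything else to the inner
# model's load_weights().
from collections.abc import Iterable

import torch


class ToyInnerModel:
    """Stand-in for MixtralModel: owns the real name-mapping logic."""

    def __init__(self) -> None:
        self.loaded: dict[str, torch.Tensor] = {}

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        loaded_params: set[str] = set()
        for name, tensor in weights:
            self.loaded[name] = tensor  # the real code remaps and shards here
            loaded_params.add(name)
        return loaded_params


class ToyAutoWeightsLoader:
    """Stand-in for AutoWeightsLoader: skip unwanted names, then delegate."""

    def __init__(self, model: ToyInnerModel, skip_prefixes: list[str]) -> None:
        self.model = model
        self.skip_prefixes = skip_prefixes

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        filtered = ((name, w) for name, w in weights
                    if not any(p in name for p in self.skip_prefixes))
        return self.model.load_weights(filtered)


if __name__ == "__main__":
    loader = ToyAutoWeightsLoader(ToyInnerModel(),
                                  skip_prefixes=["rotary_emb.inv_freq"])
    ckpt = [("layers.0.rotary_emb.inv_freq", torch.zeros(4)),
            ("layers.0.qkv_proj.weight", torch.zeros(4, 4))]
    print(loader.load_weights(ckpt))  # {'layers.0.qkv_proj.weight'}

Keeping the filtering in one reusable loader is what lets each outer model class drop its hand-written copy of the skip logic, as seen in the five-line load_weights at the end of the diff above.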
vllm/model_executor/models/nemotron.py

Lines changed: 68 additions & 60 deletions
@@ -48,7 +48,7 @@
 from vllm.transformers_utils.configs import NemotronConfig
 
 from .interfaces import SupportsLoRA, SupportsPP
-from .utils import (PPMissingLayer, is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
 
@@ -300,6 +300,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         lora_config = vllm_config.lora_config
 
         self.config = config
+        self.quant_config = quant_config
         lora_vocab = (lora_config.lora_extra_vocab_size *
                       (lora_config.max_loras or 1)) if lora_config else 0
         self.vocab_size = config.vocab_size + lora_vocab
@@ -362,6 +363,63 @@ def forward(
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+        ]
+        params_dict = dict(self.named_parameters())
+        loaded_params: set[str] = set()
+        for name, loaded_weight in weights:
+            if (self.quant_config is not None and
+                (scale_name := self.quant_config.get_cache_scale(name))):
+                # Loading kv cache quantization scales
+                param = params_dict[scale_name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
+                                 loaded_weight[0])
+                weight_loader(param, loaded_weight)
+                loaded_params.add(scale_name)
+                continue
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+
+                if is_pp_missing_parameter(name, self):
+                    continue
+
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                # Remapping the name of FP8 kv-scale.
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
+
+                if is_pp_missing_parameter(name, self):
+                    continue
+
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
 
 class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     packed_modules_mapping = {
@@ -444,64 +502,14 @@ def compute_logits(
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            (".qkv_proj", ".q_proj", "q"),
-            (".qkv_proj", ".k_proj", "k"),
-            (".qkv_proj", ".v_proj", "v"),
-        ]
-        params_dict = dict(self.named_parameters())
-        loaded_params: set[str] = set()
-        for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name:
-                continue
-            if ("rotary_emb.cos_cached" in name
-                    or "rotary_emb.sin_cached" in name):
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=([
+                "rotary_emb.inv_freq",
                 # Models trained using ColossalAI may include these tensors in
                 # the checkpoint. Skip them.
-                continue
-            if (self.quant_config is not None and
-                (scale_name := self.quant_config.get_cache_scale(name))):
-                # Loading kv cache quantization scales
-                param = params_dict[scale_name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
-                                 loaded_weight[0])
-                weight_loader(param, loaded_weight)
-                loaded_params.add(scale_name)
-                continue
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
-                if weight_name not in name:
-                    continue
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-
-                if is_pp_missing_parameter(name, self):
-                    continue
-
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-
-                break
-            else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                # Remapping the name of FP8 kv-scale.
-                name = maybe_remap_kv_scale_name(name, params_dict)
-                if name is None:
-                    continue
-
-                if is_pp_missing_parameter(name, self):
-                    continue
-
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
+                "rotary_emb.cos_cached",
+                "rotary_emb.sin_cached"
+            ]),
+        )
+        return loader.load_weights(weights)

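nemotron.py follows the same delegation, but the substantive loading logic (kv-cache scale handling and the fused qkv mapping) moves into NemotronModel.load_weights rather than being deleted; NemotronForCausalLM now only lists the rotary-embedding tensors to skip. The helper below is a small, hypothetical illustration (not vLLM code) of the for/else stacked-parameter routing the moved loop relies on: q/k/v checkpoint names are rewritten onto the fused qkv_proj parameter, and names that match no mapping fall through to the else branch.

# A small illustration (hypothetical route_weight helper, not vLLM code) of the
# stacked-parameter mapping used in NemotronModel.load_weights above.
from typing import Optional

STACKED_PARAMS_MAPPING = [
    # (param_name, shard_name, shard_id)
    (".qkv_proj", ".q_proj", "q"),
    (".qkv_proj", ".k_proj", "k"),
    (".qkv_proj", ".v_proj", "v"),
]


def route_weight(name: str) -> tuple[str, Optional[str]]:
    """Return the destination parameter name and the shard id (or None)."""
    for param_name, weight_name, shard_id in STACKED_PARAMS_MAPPING:
        if weight_name not in name:
            continue
        dest, shard = name.replace(weight_name, param_name), shard_id
        break  # matched one of the stacked q/k/v shards
    else:
        # Runs only when the loop finished without `break`: no mapping matched,
        # so the tensor is loaded under its own name (the generic branch).
        dest, shard = name, None
    return dest, shard


if __name__ == "__main__":
    print(route_weight("layers.0.attn.k_proj.weight"))
    # -> ('layers.0.attn.qkv_proj.weight', 'k')
    print(route_weight("layers.0.mlp.down_proj.weight"))
    # -> ('layers.0.mlp.down_proj.weight', None)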