Commit 1d094fd

[Distributed][PP] only create embedding & lm head when necessary (#6455)
original title: [Distributed][Model] Rank-based Component Creation for Pipeline Parallelism Memory Optimization
1 parent: ce37be7

1 file changed: vllm/model_executor/models/llama.py (+38, −27 lines)
@@ -50,7 +50,7 @@
 from vllm.utils import is_hip
 
 from .interfaces import SupportsLoRA
-from .utils import is_pp_missing_parameter, make_layers
+from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
 
 
 class LlamaMLP(nn.Module):
@@ -257,17 +257,24 @@ def __init__(
                       (lora_config.max_loras or 1)) if lora_config else 0
         self.vocab_size = config.vocab_size + lora_vocab
         self.org_vocab_size = config.vocab_size
-        self.embed_tokens = VocabParallelEmbedding(
-            self.vocab_size,
-            config.hidden_size,
-            org_num_embeddings=config.vocab_size,
-        )
+        if get_pp_group().is_first_rank or (config.tie_word_embeddings
+                                            and get_pp_group().is_last_rank):
+            self.embed_tokens = VocabParallelEmbedding(
+                self.vocab_size,
+                config.hidden_size,
+                org_num_embeddings=config.vocab_size,
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
             lambda: LlamaDecoderLayer(config=config,
                                       cache_config=cache_config,
                                       quant_config=quant_config))
-        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        if get_pp_group().is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer()
 
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
@@ -360,26 +367,30 @@ def __init__(
                                 cache_config,
                                 quant_config,
                                 lora_config=lora_config)
-        self.unpadded_vocab_size = config.vocab_size
-        if lora_config:
-            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
-        self.lm_head = ParallelLMHead(
-            self.unpadded_vocab_size,
-            config.hidden_size,
-            org_num_embeddings=config.vocab_size,
-            padding_size=DEFAULT_VOCAB_PADDING_SIZE
-            # We need bigger padding if using lora for kernel
-            # compatibility
-            if not lora_config else lora_config.lora_vocab_padding_size,
-            quant_config=quant_config,
-        )
-        if config.tie_word_embeddings:
-            self.lm_head.weight = self.model.embed_tokens.weight
-
-        logit_scale = getattr(config, "logit_scale", 1.0)
-        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
-                                                config.vocab_size, logit_scale)
-        self.sampler = Sampler()
+        if get_pp_group().is_last_rank:
+            self.unpadded_vocab_size = config.vocab_size
+            if lora_config:
+                self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
+            self.lm_head = ParallelLMHead(
+                self.unpadded_vocab_size,
+                config.hidden_size,
+                org_num_embeddings=config.vocab_size,
+                padding_size=DEFAULT_VOCAB_PADDING_SIZE
+                # We need bigger padding if using lora for kernel
+                # compatibility
+                if not lora_config else lora_config.lora_vocab_padding_size,
+                quant_config=quant_config,
+            )
+            if config.tie_word_embeddings:
+                self.lm_head.weight = self.model.embed_tokens.weight
+
+            logit_scale = getattr(config, "logit_scale", 1.0)
+            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
+                                                    config.vocab_size,
+                                                    logit_scale)
+            self.sampler = Sampler()
+        else:
+            self.lm_head = PPMissingLayer()
 
     def forward(
         self,
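The diff gates module construction on the pipeline-parallel rank: only the first rank builds the input embedding (the last rank also builds it when config.tie_word_embeddings is set, since the LM head reuses its weight), and only the last rank builds the final norm, LM head, logits processor, and sampler; every other rank gets a PPMissingLayer placeholder. A minimal, self-contained sketch of the idea follows; ToyPPModel, the identity-based placeholder, and the explicit rank flags are illustrative stand-ins, not vLLM's actual classes.

# Illustrative sketch only: ToyPPModel, the identity-based placeholder, and
# the explicit rank flags are hypothetical stand-ins, not vLLM's real code.
import torch.nn as nn


class PPMissingLayer(nn.Identity):
    """Placeholder for a module that lives on a different pipeline rank."""


class ToyPPModel(nn.Module):

    def __init__(self, vocab_size: int, hidden_size: int,
                 is_first_rank: bool, is_last_rank: bool,
                 tie_word_embeddings: bool = False):
        super().__init__()
        # Input embeddings are only consumed on the first rank (or on the
        # last rank when the LM head reuses the embedding weight).
        if is_first_rank or (tie_word_embeddings and is_last_rank):
            self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        else:
            self.embed_tokens = PPMissingLayer()
        # The final norm and LM head only produce output on the last rank.
        if is_last_rank:
            self.norm = nn.LayerNorm(hidden_size)
            self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
            if tie_word_embeddings:
                self.lm_head.weight = self.embed_tokens.weight
        else:
            self.norm = PPMissingLayer()
            self.lm_head = PPMissingLayer()


# A middle rank now allocates no embedding or LM-head parameters at all.
middle = ToyPPModel(32000, 4096, is_first_rank=False, is_last_rank=False)
print(sum(p.numel() for p in middle.parameters()))  # prints 0

With this gating, intermediate ranks never allocate the vocabulary-sized tensors, and the weight loader can skip the parameters that do not exist on a given rank, which is what the is_pp_missing_parameter helper imported in the diff checks for.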
