 from vllm.utils import is_hip

 from .interfaces import SupportsLoRA
-from .utils import is_pp_missing_parameter, make_layers
+from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers


 class LlamaMLP(nn.Module):

@@ -257,17 +257,24 @@ def __init__(
                       (lora_config.max_loras or 1)) if lora_config else 0
         self.vocab_size = config.vocab_size + lora_vocab
         self.org_vocab_size = config.vocab_size
-        self.embed_tokens = VocabParallelEmbedding(
-            self.vocab_size,
-            config.hidden_size,
-            org_num_embeddings=config.vocab_size,
-        )
+        if get_pp_group().is_first_rank or (config.tie_word_embeddings
+                                            and get_pp_group().is_last_rank):
+            self.embed_tokens = VocabParallelEmbedding(
+                self.vocab_size,
+                config.hidden_size,
+                org_num_embeddings=config.vocab_size,
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
             lambda: LlamaDecoderLayer(config=config,
                                       cache_config=cache_config,
                                       quant_config=quant_config))
-        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        if get_pp_group().is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer()

     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)

@@ -360,26 +367,30 @@ def __init__(
                                 cache_config,
                                 quant_config,
                                 lora_config=lora_config)
-        self.unpadded_vocab_size = config.vocab_size
-        if lora_config:
-            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
-        self.lm_head = ParallelLMHead(
-            self.unpadded_vocab_size,
-            config.hidden_size,
-            org_num_embeddings=config.vocab_size,
-            padding_size=DEFAULT_VOCAB_PADDING_SIZE
-            # We need bigger padding if using lora for kernel
-            # compatibility
-            if not lora_config else lora_config.lora_vocab_padding_size,
-            quant_config=quant_config,
-        )
-        if config.tie_word_embeddings:
-            self.lm_head.weight = self.model.embed_tokens.weight
-
-        logit_scale = getattr(config, "logit_scale", 1.0)
-        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
-                                                config.vocab_size, logit_scale)
-        self.sampler = Sampler()
+        if get_pp_group().is_last_rank:
+            self.unpadded_vocab_size = config.vocab_size
+            if lora_config:
+                self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
+            self.lm_head = ParallelLMHead(
+                self.unpadded_vocab_size,
+                config.hidden_size,
+                org_num_embeddings=config.vocab_size,
+                padding_size=DEFAULT_VOCAB_PADDING_SIZE
+                # We need bigger padding if using lora for kernel
+                # compatibility
+                if not lora_config else lora_config.lora_vocab_padding_size,
+                quant_config=quant_config,
+            )
+            if config.tie_word_embeddings:
+                self.lm_head.weight = self.model.embed_tokens.weight
+
+            logit_scale = getattr(config, "logit_scale", 1.0)
+            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
+                                                    config.vocab_size,
+                                                    logit_scale)
+            self.sampler = Sampler()
+        else:
+            self.lm_head = PPMissingLayer()

     def forward(
         self,
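The change above follows a single pattern: under pipeline parallelism, each rank constructs only the modules it actually runs, and every other attribute slot is filled with the PPMissingLayer placeholder imported from .utils, so the attribute names stay present on every rank. The sketch below is a minimal, self-contained illustration of that idea, not vLLM's actual implementation: _PlaceholderLayer, the is_first_rank/is_last_rank flags, and ToyStage are hypothetical stand-ins for PPMissingLayer, get_pp_group(), and the Llama model classes.

import torch.nn as nn


class _PlaceholderLayer(nn.Module):
    """Toy stand-in occupying a module slot that another pipeline rank owns."""

    def forward(self, *args, **kwargs):
        raise RuntimeError("This layer lives on a different pipeline rank.")


class ToyStage(nn.Module):
    """Toy model stage mirroring the rank checks in the diff above."""

    def __init__(self, vocab_size: int, hidden_size: int, is_first_rank: bool,
                 is_last_rank: bool, tie_word_embeddings: bool = False):
        super().__init__()
        # Token embeddings are needed where input IDs enter the pipeline
        # (first rank), or on the last rank when the LM head ties its
        # weights to them.
        if is_first_rank or (tie_word_embeddings and is_last_rank):
            self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        else:
            self.embed_tokens = _PlaceholderLayer()
        # Only the last rank produces logits, so only it needs the final
        # norm and the LM head.
        if is_last_rank:
            self.norm = nn.LayerNorm(hidden_size)
            self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        else:
            self.norm = _PlaceholderLayer()
            self.lm_head = _PlaceholderLayer()


# A middle stage allocates neither the embedding table nor the LM head,
# so its parameter count (and the checkpoint weights it must load) shrinks.
middle = ToyStage(vocab_size=32, hidden_size=16,
                  is_first_rank=False, is_last_rank=False)
print(type(middle.embed_tokens).__name__)  # -> _PlaceholderLayer

Read this way, gating lm_head, logits_processor, and sampler behind get_pp_group().is_last_rank in the diff is the same idea applied to the output side: only the final pipeline stage converts hidden states into token logits.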