@@ -29,7 +29,7 @@

 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig, LoRAConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (QKVParallelLinear,
@@ -48,6 +48,7 @@
 from vllm.sequence import IntermediateTensors, SamplerOutput

 from .interfaces import SupportsLoRA
+from .utils import is_pp_missing_parameter, make_layers


 class MixtralMoE(nn.Module):
@@ -255,12 +256,11 @@ def __init__(
             config.hidden_size,
             org_num_embeddings=config.vocab_size,
         )
-        self.layers = nn.ModuleList([
-            MixtralDecoderLayer(config,
-                                cache_config,
-                                quant_config=quant_config)
-            for _ in range(config.num_hidden_layers)
-        ])
+
+        self.start_layer, self.end_layer, self.layers = make_layers(
+            config.num_hidden_layers, lambda: MixtralDecoderLayer(
+                config, cache_config, quant_config=quant_config))
+
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

     def forward(
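The make_layers helper (from .utils) replaces the plain nn.ModuleList so that each pipeline-parallel rank builds only its own contiguous slice of decoder layers and records the slice bounds as start_layer/end_layer. A minimal sketch of the idea, assuming an even split across ranks and a cheap placeholder for layers owned elsewhere (an illustration only, not vLLM's actual implementation):

import torch.nn as nn

def make_layers_sketch(num_layers: int, layer_fn, pp_rank: int, pp_size: int):
    # Evenly partition [0, num_layers) across pipeline ranks.
    per_rank = num_layers // pp_size
    start = pp_rank * per_rank
    end = num_layers if pp_rank == pp_size - 1 else start + per_rank
    # Materialize only this rank's layers; keep placeholders at the other
    # indices so that self.layers[i] indexing still lines up globally.
    layers = nn.ModuleList([
        layer_fn() if start <= i < end else nn.Identity()
        for i in range(num_layers)
    ])
    return start, end, layers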
@@ -269,14 +269,25 @@ def forward(
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors],
     ) -> torch.Tensor:
-        hidden_states = self.embed_tokens(input_ids)
-        residual = None
-        for i in range(len(self.layers)):
+        if get_pp_group().is_first_rank:
+            hidden_states = self.embed_tokens(input_ids)
+            residual = None
+        else:
+            assert intermediate_tensors is not None
+            hidden_states = intermediate_tensors["hidden_states"]
+            residual = intermediate_tensors["residual"]
+        for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states, residual = layer(positions, hidden_states,
-                                            kv_caches[i], attn_metadata,
-                                            residual)
+                                            kv_caches[i - self.start_layer],
+                                            attn_metadata, residual)
+        if not get_pp_group().is_last_rank:
+            return IntermediateTensors({
+                "hidden_states": hidden_states,
+                "residual": residual
+            })
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
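The new control flow is the usual pipeline-parallel hand-off: the first rank embeds tokens, every rank runs only its [start_layer, end_layer) slice, and non-last ranks return an IntermediateTensors bundle instead of final hidden states (note also that kv_caches is indexed with i - self.start_layer, since each rank allocates caches only for its own layers). A self-contained toy version of the same pattern, with illustrative names and a plain dict standing in for IntermediateTensors:

import torch

def stage_forward(x, start, end, layers, is_first, is_last, intermediate=None):
    # The first rank starts from the embeddings; later ranks resume from
    # whatever the previous stage shipped over.
    hidden = x if is_first else intermediate["hidden_states"]
    for i in range(start, end):
        hidden = layers[i](hidden)
    if not is_last:
        # Hand off to the next pipeline stage.
        return {"hidden_states": hidden}
    return hidden

layers = [torch.nn.Linear(8, 8) for _ in range(4)]
x = torch.randn(2, 8)
sent = stage_forward(x, 0, 2, layers, is_first=True, is_last=False)
out = stage_forward(None, 2, 4, layers, is_first=False, is_last=True,
                    intermediate=sent)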
@@ -347,7 +358,7 @@ def forward(
         intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata)
+                                   attn_metadata, intermediate_tensors)
         return hidden_states

     def compute_logits(self, hidden_states: torch.Tensor,
@@ -356,6 +367,20 @@ def compute_logits(self, hidden_states: torch.Tensor,
                                        sampling_metadata)
         return logits

+    def make_empty_intermediate_tensors(
+            self, batch_size: int, dtype: torch.dtype,
+            device: torch.device) -> IntermediateTensors:
+        return IntermediateTensors({
+            "hidden_states":
+            torch.zeros((batch_size, self.config.hidden_size),
+                        dtype=dtype,
+                        device=device),
+            "residual":
+            torch.zeros((batch_size, self.config.hidden_size),
+                        dtype=dtype,
+                        device=device),
+        })
+
     def sample(
         self,
         logits: Optional[torch.Tensor],
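make_empty_intermediate_tensors exists so the runner on a non-first rank can allocate correctly shaped, zeroed buffers to receive the previous stage's hidden_states and residual before calling forward. Conceptually (a sketch with made-up sizes, not vLLM's runner code):

import torch

hidden_size = 4096   # stands in for config.hidden_size
num_tokens = 16      # flattened batch of token positions in this step

# Zeroed landing buffers; the runner fills them with data received from the
# previous pipeline rank, then passes them in as intermediate_tensors.
placeholders = {
    "hidden_states": torch.zeros((num_tokens, hidden_size), dtype=torch.float16),
    "residual": torch.zeros((num_tokens, hidden_size), dtype=torch.float16),
}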
@@ -392,6 +417,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                # Skip layers on other devices.
+                if is_pp_missing_parameter(name, self):
+                    continue
+
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
@@ -402,6 +431,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                     if weight_name not in name:
                         continue
                     name = name.replace(weight_name, param_name)
+                    # Skip layers on other devices.
+                    if is_pp_missing_parameter(name, self):
+                        continue
                     param = params_dict[name]
                     weight_loader = param.weight_loader
                     weight_loader(param,
@@ -414,6 +446,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                     # Skip loading extra bias for GPTQ models.
                     if name.endswith(".bias") and name not in params_dict:
                         continue
+                    # Skip layers on other devices.
+                    if is_pp_missing_parameter(name, self):
+                        continue
                     # Remapping the name of FP8 kv-scale.
                     name = maybe_remap_kv_scale_name(name, params_dict)
                     if name is None:
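The is_pp_missing_parameter checks keep weight loading consistent with the sliced layer list: checkpoint tensors that belong to decoder layers owned by another pipeline rank are simply skipped. A simplified sketch of such a check, assuming parameter names of the form model.layers.<idx>.... (the real helper in vLLM's model utils is more general):

import re

def is_pp_missing_parameter_sketch(name: str, start_layer: int,
                                   end_layer: int) -> bool:
    # Parameters outside the decoder stack (embeddings, final norm, lm_head)
    # are not treated as missing in this simplified version.
    match = re.search(r"layers\.(\d+)\.", name)
    if match is None:
        return False
    layer_idx = int(match.group(1))
    # Missing if the layer falls outside this rank's [start_layer, end_layer).
    return not (start_layer <= layer_idx < end_layer)

# On a rank that owns layers [16, 32), layer 3's weights are skipped.
assert is_pp_missing_parameter_sketch(
    "model.layers.3.self_attn.qkv_proj.weight", 16, 32)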