 import torch
 import torch_npu
 from vllm.config import get_current_vllm_config
-from vllm.distributed import tensor_model_parallel_all_reduce
+from vllm.distributed import (get_tensor_model_parallel_world_size,
+                              tensor_model_parallel_all_reduce)
 from vllm.distributed.parallel_state import get_dp_group
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map)
-from vllm.model_executor.layers.quantization.base_config import \
-    QuantizeMethodBase
+
+from vllm_ascend.utils import vllm_version_is
+
+if not (vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1")):
+    from vllm.model_executor.layers.fused_moe.layer import (
+        FusedMoEParallelConfig, MoEConfig)
+else:
+    MoEConfig = None
+
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig, QuantizeMethodBase)

 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group
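
Note on the version gate introduced above: the imports branch on vllm_ascend.utils.vllm_version_is so that MoEConfig and FusedMoEParallelConfig are only pulled in on vllm releases newer than 0.8.5/0.8.5.post1. A minimal sketch of what such a helper might look like is given below; this is an illustrative assumption, not the actual vllm_ascend.utils implementation.

    # Sketch only (assumption): report whether the installed vllm matches
    # an exact version string. The real helper lives in vllm_ascend.utils.
    from importlib.metadata import version

    def vllm_version_is(target: str) -> bool:
        # Exact string match against the installed vllm distribution version.
        return version("vllm") == target

    # Example: take the newer MoEConfig-based path off 0.8.5 / 0.8.5.post1.
    if not (vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1")):
        print("using MoEConfig / FusedMoEParallelConfig code path")
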
@@ -437,8 +447,11 @@ def select_experts(

 class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):

-    def __init__(self):
-        super().__init__()
+    def __init__(self, moe: MoEConfig = None):
+        if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
+            super().__init__()
+        else:
+            super().__init__(moe=moe)
         vllm_config = get_current_vllm_config()

         ep_group = get_ep_group()
@@ -535,37 +548,54 @@ def apply(

 class AscendFusedMoE(FusedMoE):

-    def __init__(self,
-                 num_experts,
-                 top_k,
-                 hidden_size,
-                 intermediate_size,
-                 params_dtype=None,
-                 reduce_results=False,
-                 renormalize=True,
-                 use_grouped_topk=False,
-                 num_expert_group=None,
-                 topk_group=None,
-                 quant_config=None,
-                 tp_size=None,
-                 ep_size=None,
-                 dp_size=None,
-                 prefix="",
-                 custom_routing_function=None,
-                 scoring_func="softmax",
-                 e_score_correction_bias=None,
-                 activation="silu"):
+    def __init__(
+        self,
+        num_experts: int,  # Global number of experts
+        top_k: int,
+        hidden_size: int,
+        intermediate_size: int,
+        params_dtype: Optional[torch.dtype] = None,
+        reduce_results: bool = False,
+        renormalize: bool = True,
+        use_grouped_topk: bool = False,
+        num_expert_group: Optional[int] = None,
+        topk_group: Optional[int] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        tp_size: Optional[int] = None,
+        ep_size: Optional[int] = None,
+        dp_size: Optional[int] = None,
+        prefix: str = "",
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
+    ):
+        # TODO: This does not initialize the FusedMoE base class;
+        # fix this and make AscendFusedMoE.__init__() clearer.
         super(FusedMoE, self).__init__()

         if params_dtype is None:
             params_dtype = torch.get_default_dtype()

-        self.ep_size = get_ep_group().world_size
-        self.tp_size = get_etp_group().world_size
-        self.dp_size = (dp_size
-                        if dp_size is not None else get_dp_group().world_size)
-        self.dp_rank = (0
-                        if self.dp_size == 1 else get_dp_group().rank_in_group)
+        vllm_config = get_current_vllm_config()
+        if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
+            self.ep_size = get_ep_group().world_size
+            self.tp_size = get_etp_group().world_size
+            self.dp_size = (dp_size if dp_size is not None else
+                            get_dp_group().world_size)
+            self.dp_rank = (0 if self.dp_size == 1 else
+                            get_dp_group().rank_in_group)
+        else:
+            self.moe_parallel_config: FusedMoEParallelConfig = (
+                FusedMoEParallelConfig.make(
+                    tp_size_=(tp_size if tp_size is not None else
                              get_tensor_model_parallel_world_size()),
+                    dp_size_=(dp_size if dp_size is not None else
                              get_dp_group().world_size),
+                    vllm_parallel_config=vllm_config.parallel_config))
+
+            self.moe_parallel_config.ep_size = get_ep_group().world_size

         self.top_k = top_k
         self.num_experts = num_experts
@@ -590,27 +620,55 @@ def __init__(self,
             self.local_num_experts, self.expert_map = determine_expert_map(
                 self.ep_size,
                 get_ep_group().rank_in_group, self.global_num_experts)
-            self.tp_rank = get_etp_group().rank_in_group
-            self.ep_rank = get_ep_group().rank_in_group
+            if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
+                self.tp_rank = get_etp_group().rank_in_group
+                self.ep_rank = get_ep_group().rank_in_group
+            else:
+                self.moe_parallel_config.tp_rank = get_etp_group(
+                ).rank_in_group
+                self.moe_parallel_config.ep_rank = get_ep_group().rank_in_group
+
         else:
             # Adjust TP size for DP attention
             # haven't tested its functionality yet, may remove in the future
-            self.tp_rank = self.tp_size * self.dp_rank
-            self.ep_rank = 0
-            self.tp_size = self.tp_size * self.dp_size
-            self.ep_size = 1
-            self.local_num_experts = self.global_num_experts
-            self.expert_map = None
-
+            if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
+                self.tp_rank = self.tp_size * self.dp_rank
+                self.ep_rank = 0
+                self.tp_size = self.tp_size * self.dp_size
+                self.ep_size = 1
+            else:
+                self.moe_parallel_config.tp_rank = self.tp_size * self.dp_rank
+                self.moe_parallel_config.ep_rank = 0
+                self.moe_parallel_config.tp_size = self.tp_size * self.dp_size
+                self.moe_parallel_config.ep_size = 1
+
+            self.local_num_experts, self.expert_map = (self.global_num_experts,
                                                        None)
         if self.scoring_func != "softmax" and not self.use_grouped_topk:
             raise ValueError("Only softmax scoring function is supported for "
                              "non-grouped topk.")
-
-        if quant_config is None:
-            self.quant_method: Optional[QuantizeMethodBase] = (
-                AscendUnquantizedFusedMoEMethod())
+        if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
+            if quant_config is None:
+                self.quant_method: Optional[QuantizeMethodBase] = (
+                    AscendUnquantizedFusedMoEMethod())
+            else:
+                self.quant_method = quant_config.get_quant_method(self, prefix)
         else:
-            self.quant_method = quant_config.get_quant_method(self, prefix)
+            moe = MoEConfig(
+                num_experts=self.global_num_experts,
+                experts_per_token=top_k,
+                hidden_dim=hidden_size,
+                num_local_experts=self.local_num_experts,
+                moe_parallel_config=self.moe_parallel_config,
+                # TODO (bnell): this needs to be fixed for quantized types.
+                in_dtype=params_dtype,
+            )
+
+            if quant_config is None:
+                self.quant_method = AscendUnquantizedFusedMoEMethod(moe)
+            else:
+                self.quant_method = quant_config.get_quant_method(self, prefix)
+
         assert self.quant_method is not None

         local_num_experts = torch.sum(self.expert_map != -1) \
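
For readers unfamiliar with the expert map referenced by determine_expert_map and by torch.sum(self.expert_map != -1) in the hunk above, the sketch below illustrates the idea. It assumes an even, contiguous split of global experts across EP ranks and is not vllm's actual determine_expert_map implementation.

    # Illustrative sketch only; assumes an even, contiguous split of experts.
    import torch

    def sketch_expert_map(ep_size: int, ep_rank: int, global_num_experts: int):
        local_num_experts = global_num_experts // ep_size
        # expert_map[g] = local index of global expert g on this rank, else -1.
        expert_map = torch.full((global_num_experts, ), -1, dtype=torch.int32)
        start = ep_rank * local_num_experts
        expert_map[start:start + local_num_experts] = torch.arange(
            local_num_experts, dtype=torch.int32)
        return local_num_experts, expert_map

    # 8 experts over 4 EP ranks: rank 1 owns global experts 2 and 3.
    n_local, emap = sketch_expert_map(ep_size=4, ep_rank=1, global_num_experts=8)
    assert n_local == 2
    assert int(torch.sum(emap != -1)) == n_local  # mirrors the sum above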