@@ -3,12 +3,16 @@
 
 import torch
 
-# TODO: use deep_gemm masked kernel after low latency dispatch
-# import deep_gemm
-# from deep_gemm import (
-#     get_col_major_tma_aligned_tensor,
-#     m_grouped_gemm_fp8_fp8_bf16_nt_masked,
-# )
+try:
+    from deep_gemm import (
+        get_col_major_tma_aligned_tensor,
+        m_grouped_gemm_fp8_fp8_bf16_nt_masked,
+    )
+
+    use_deep_gemm = True
+except ImportError:
+    use_deep_gemm = False
+
 from torch.nn import Module
 
 from sglang.srt.custom_op import CustomOp
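The guard above keeps the module importable when deep_gemm is not installed; the flag is checked later when a layer is constructed in a mode that actually needs the masked kernels, so the failure is an explicit assertion rather than an import error. A minimal standalone sketch of the same pattern (the class name here is illustrative, not part of this diff):

```python
# Optional-dependency guard: importing this module never fails, but features
# that need deep_gemm assert on the flag when they are actually requested.
try:
    import deep_gemm  # noqa: F401

    use_deep_gemm = True
except ImportError:
    use_deep_gemm = False


class MaskedMoELayer:  # hypothetical stand-in for the real MoE layer class
    def __init__(self, deepep_mode: str = "auto"):
        if deepep_mode in ["low_latency", "auto"]:
            # Fail here with a clear message instead of at kernel-launch time.
            assert use_deep_gemm, f"DeepEP {deepep_mode} mode requires deep_gemm"
        self.deepep_mode = deepep_mode
```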
@@ -22,6 +26,7 @@
     post_reorder_triton_kernel,
     pre_reorder_triton_kernel,
     run_moe_ep_preproess,
+    silu_and_mul_masked_post_quant_fwd,
     silu_and_mul_triton_kernel,
 )
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
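The newly imported `silu_and_mul_masked_post_quant_fwd` is used in `forward_deepgemm_masked` below: it applies SiLU to the gate half of the grouped GEMM output, multiplies by the up half, and quantizes the result to FP8 with one float32 scale per 128-wide block of the hidden dimension, only the first `masked_m[g]` rows of each expert group being meaningful. A rough eager-PyTorch sketch of that intended behavior, written as an assumption from the call site below rather than from the Triton kernel itself:

```python
import torch
import torch.nn.functional as F


def silu_and_mul_masked_post_quant_ref(
    gateup_output: torch.Tensor,  # (num_groups, m, 2 * n), bf16 GEMM output
    masked_m: torch.Tensor,       # (num_groups,), valid rows per expert group
    scale_block_size: int = 128,  # quantization block along the hidden dim
):
    """Eager sketch: silu(gate) * up, then per-block FP8 quantization.

    Assumes n is a multiple of scale_block_size.
    """
    num_groups, m, two_n = gateup_output.shape
    n = two_n // 2
    gate = gateup_output[..., :n].float()
    up = gateup_output[..., n:].float()
    act = F.silu(gate) * up

    # Padding rows (index >= masked_m[g]) carry garbage; zero them here so the
    # reference output is deterministic (the real kernel can simply skip them).
    row_idx = torch.arange(m, device=gateup_output.device).unsqueeze(0)  # (1, m)
    valid = row_idx < masked_m.unsqueeze(1)                              # (num_groups, m)
    act = act * valid.unsqueeze(-1)

    # One absmax scale per (group, row, 128-wide block).
    blocks = act.reshape(num_groups, m, n // scale_block_size, scale_block_size)
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    scale = blocks.abs().amax(dim=-1).clamp(min=1e-12) / fp8_max
    q = (blocks / scale.unsqueeze(-1)).clamp(-fp8_max, fp8_max)
    down_input = q.reshape(num_groups, m, n).to(torch.float8_e4m3fn)
    return down_input, scale  # fp8 values, float32 per-block scales
```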
@@ -809,6 +814,7 @@ def __init__(
         correction_bias: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
         activation: str = "silu",
+        deepep_mode: str = "auto",
     ):
         super().__init__(
             num_experts,
@@ -827,21 +833,41 @@ def __init__(
             custom_routing_function,
             activation,
         )
+        self.deepep_mode = deepep_mode
+        if self.deepep_mode in ["low_latency", "auto"]:
+            assert use_deep_gemm, f"DeepEP {self.deepep_mode} mode requires deep_gemm"
+        self.w13_weight_fp8 = (
+            self.w13_weight,
+            (
+                self.w13_weight_scale_inv
+                if self.use_block_quant
+                else self.w13_weight_scale
+            ),
+        )
+        self.w2_weight_fp8 = (
+            self.w2_weight,
+            self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
+        )
 
     def forward(
         self,
         hidden_states: torch.Tensor,
         reorder_topk_ids: torch.Tensor,
         seg_indptr: torch.Tensor,
+        masked_m: torch.Tensor,
+        expected_m: int,
         forward_mode: ForwardMode,
     ):
-        # Todo: use m_grouped_gemm_fp8_fp8_bf16_nt_masked after low_latency dispatch (decode)
-        if True:  # not forward_mode.is_decode():
+        if self.deepep_mode == "normal" or (
+            self.deepep_mode == "auto" and not forward_mode.is_decode()
+        ):
             return self.forward_normal(hidden_states, reorder_topk_ids, seg_indptr)
+        elif self.deepep_mode == "low_latency" or (
+            self.deepep_mode == "auto" and forward_mode.is_decode()
+        ):
+            return self.forward_deepgemm_masked(hidden_states, masked_m, expected_m)
         else:
-            return self.forward_deepgemm_masked(
-                hidden_states, reorder_topk_ids, seg_indptr
-            )
+            raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
 
     def forward_normal(
         self,
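The new dispatch keeps the existing contiguous Triton path for prefill/extend batches and routes decode batches to the DeepGEMM masked path when the mode allows it; `"auto"` chooses per forward pass via `forward_mode.is_decode()`. A small standalone mirror of that decision logic, for illustration only:

```python
def pick_moe_path(deepep_mode: str, is_decode: bool) -> str:
    """Mirror of the forward() dispatch above: which kernel path runs."""
    if deepep_mode == "normal" or (deepep_mode == "auto" and not is_decode):
        return "forward_normal"           # existing contiguous grouped GEMM path
    elif deepep_mode == "low_latency" or (deepep_mode == "auto" and is_decode):
        return "forward_deepgemm_masked"  # new DeepGEMM masked path
    raise ValueError(f"Invalid deepep_mode: {deepep_mode}")


assert pick_moe_path("auto", is_decode=False) == "forward_normal"
assert pick_moe_path("auto", is_decode=True) == "forward_deepgemm_masked"
assert pick_moe_path("normal", is_decode=True) == "forward_normal"
assert pick_moe_path("low_latency", is_decode=False) == "forward_deepgemm_masked"
```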
@@ -958,89 +984,66 @@ def forward_normal(
 
     def forward_deepgemm_masked(
         self,
-        hidden_states: torch.Tensor,
-        reorder_topk_ids: torch.Tensor,
-        seg_indptr: torch.Tensor,
+        hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],
+        masked_m: torch.Tensor,
+        expected_m: int,
     ):
         assert self.quant_method is not None
         assert self.activation == "silu"
-
-        if self.activation_scheme == "dynamic" and not self.use_block_quant:
-            max_value = (
-                torch.max(hidden_states)
-                .repeat(self.num_experts_per_partition)
-                .to(torch.float32)
-            )
-            self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
+        assert (
+            hidden_states_fp8[0].size(0) % 4 == 0
+        ), f"TMA alignment error: {hidden_states_fp8[0].size(0)}"
 
         # GroupGemm-0
+        num_groups, m, k = hidden_states_fp8[0].size()
+        n = self.w13_weight.size(1)
+        expected_m = min(expected_m, m)
         gateup_output = torch.empty(
-            hidden_states.shape[0],
-            self.w13_weight.shape[1],
-            device=hidden_states.device,
-            dtype=hidden_states.dtype,
+            (num_groups, m, n), device=hidden_states_fp8[0].device, dtype=torch.bfloat16
+        )
+        m_grouped_gemm_fp8_fp8_bf16_nt_masked(
+            hidden_states_fp8, self.w13_weight_fp8, gateup_output, masked_m, expected_m
         )
-        if hidden_states.shape[0] > 0:
-            # Transpose earlier so that the testing will not trigger transposing kernels
-            hidden_states = (
-                hidden_states[0],
-                get_col_major_tma_aligned_tensor(hidden_states[1]),
-            )
-        """
-        gateup_output = deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
-            hidden_states, self.w13_weight, out, masked_m, expected_m
-        )
-        """
 
         # Act
         down_input = torch.empty(
-            gateup_output.shape[0],
-            gateup_output.shape[1] // 2,
-            device=gateup_output.device,
-            dtype=(
-                self.fp8_dtype
-                if (self.use_fp8_w8a8 and not self.use_block_quant)
-                else hidden_states.dtype
+            (
+                gateup_output.shape[0],
+                gateup_output.shape[1],
+                gateup_output.shape[2] // 2,
             ),
+            device=gateup_output.device,
+            dtype=self.fp8_dtype,
         )
-        if self.w2_input_scale is None and not self.use_block_quant:
-            self.w2_input_scale = torch.ones(
-                self.num_experts_per_partition,
-                dtype=torch.float32,
-                device=hidden_states.device,
-            )
-
-        if self.activation == "silu":
-            silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
-                gateup_output,
-                down_input,
+        scale_block_size = 128
+        down_input_scale = torch.empty(
+            (
+                gateup_output.shape[0],
                 gateup_output.shape[1],
-                reorder_topk_ids,
-                self.w2_input_scale,
-                0,
-                self.num_experts_per_partition - 1,
-                BLOCK_SIZE=512,
-            )
-        else:
-            raise ValueError(f"Unsupported activation: {self.activation=}")
+                gateup_output.shape[2] // 2 // scale_block_size,
+            ),
+            device=gateup_output.device,
+            dtype=torch.float32,
+        )
+        silu_and_mul_masked_post_quant_fwd(
+            gateup_output,
+            down_input,
+            down_input_scale,
+            scale_block_size,
+            masked_m,
+        )
 
         # GroupGemm-1
+        n = self.w2_weight.size(1)
+        down_input_fp8 = (
+            down_input,
+            get_col_major_tma_aligned_tensor(down_input_scale),
+        )
         down_output = torch.empty(
-            down_input.shape[0],
-            self.w2_weight.shape[1],
-            device=hidden_states.device,
-            dtype=hidden_states.dtype,
+            (num_groups, m, n), device=down_input.device, dtype=torch.bfloat16
+        )
+        m_grouped_gemm_fp8_fp8_bf16_nt_masked(
+            down_input_fp8, self.w2_weight_fp8, down_output, masked_m, expected_m
         )
-        if down_input.shape[0] > 0:
-            # Transpose earlier so that the testing will not trigger transposing kernels
-            down_input = (
-                down_input[0],
-                get_col_major_tma_aligned_tensor(down_input[1]),
-            )
-        """
-        down_output = deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
-            down_input, self.w2_weight, out, masked_m, expected_m
-        )
-        """
 
         return down_output
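For readers unfamiliar with `m_grouped_gemm_fp8_fp8_bf16_nt_masked`: each expert group `g` multiplies only its first `masked_m[g]` rows against that expert's weight, stored as `(n, k)` (hence the `nt` layout), and writes a bf16 output; `expected_m` is a scheduling hint rather than a correctness parameter. A rough eager-mode sketch of that contract as it is used above, ignoring FP8 scaling and the TMA alignment/column-major scale layout handled by `get_col_major_tma_aligned_tensor` (this is an assumption about the kernel's semantics, not DeepGEMM code):

```python
import torch


def masked_grouped_gemm_ref(
    x: torch.Tensor,         # (num_groups, m, k) activations (fp8 + scales in the real kernel)
    w: torch.Tensor,         # (num_groups, n, k) per-expert weights, "nt" layout
    masked_m: torch.Tensor,  # (num_groups,) number of valid rows per expert group
) -> torch.Tensor:
    """out[g, :masked_m[g]] = x[g, :masked_m[g]] @ w[g].T, in bf16.

    Rows past masked_m[g] are padding; the real kernel skips them, here they
    are simply left as zeros.
    """
    num_groups, m, _ = x.shape
    n = w.shape[1]
    out = torch.zeros(num_groups, m, n, dtype=torch.bfloat16)
    for g in range(num_groups):
        rows = int(masked_m[g])
        out[g, :rows] = (x[g, :rows].float() @ w[g].float().t()).to(torch.bfloat16)
    return out
```

This shape contract is what the two `torch.empty((num_groups, m, n), ...)` allocations in the diff reflect: `n` is `self.w13_weight.size(1)` (gate and up projections fused) for the first GEMM and `self.w2_weight.size(1)` for the second.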