import enum
from enum import Enum
-from typing import Callable, List, Optional
+from typing import Callable, Optional

import torch
from compressed_tensors import CompressionFormat
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
                                                  FusedMoeWeightScaleSupported)
-from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
-    WNA16_SUPPORTED_BITS)
+from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import (  # noqa
+    WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_TYPES_MAP)
from vllm.model_executor.layers.quantization.utils import replace_parameter
+from vllm.model_executor.layers.quantization.utils.marlin_utils import (
+    check_moe_marlin_supports_layer, marlin_make_workspace_new,
+    marlin_moe_permute_scales)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
from vllm.model_executor.utils import set_weight_attrs
@@ -54,18 +57,19 @@ def get_moe_method(
            "input_activations")

        if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
-            # Prefer to use the non-marlin kernel when:
-            # 1. Many experts (MarlinMoE gives poor performance when >= 16)
-            # 2. Non-FP16 dtype (MarlinMoE only supports FP16)
-            # 3. Actorder is not group/dynamic (g_idx is unsupported)
-            # 4. Scaled are grouped (channelwise is unsupported)
-            if ((layer.local_num_experts >= 16
-                 or layer.params_dtype != torch.float16) and
-                    weight_quant.actorder not in (ActivationOrdering.GROUP,
-                                                  ActivationOrdering.DYNAMIC)
-                    and weight_quant.strategy in QuantizationStrategy.GROUP):
+            # Prefer to use the MarlinMoE kernel when it is supported.
+            if not check_moe_marlin_supports_layer(layer,
+                                                   weight_quant.group_size):
+                if (weight_quant.strategy in QuantizationStrategy.GROUP and
+                        weight_quant.actorder in (ActivationOrdering.GROUP,
+                                                  ActivationOrdering.DYNAMIC)):
+                    raise ValueError(
+                        "WNA16MoE is not supported with actorder=group/dynamic."
+                    )
+                logger.info_once("Using CompressedTensorsWNA16MoEMethod")
                return CompressedTensorsWNA16MoEMethod(quant_config)
            else:
+                logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod")
                return CompressedTensorsWNA16MarlinMoEMethod(quant_config)
        elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)
              and layer.activation == "silu"):
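
For orientation, the consolidated result of the hunk above is shown below. This is a sketch assembled from the added lines, not the verbatim final file; `quant_config`, `layer`, `weight_quant`, `input_quant` and `logger` are the names already in scope in `get_moe_method`.

```python
# Sketch of the new WNA16 MoE dispatch (assembled from the added lines above).
if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
    # Prefer the Marlin kernel whenever the layer shape / group size allow it.
    if not check_moe_marlin_supports_layer(layer, weight_quant.group_size):
        # The non-Marlin fallback cannot handle g_idx reordering, so
        # actorder=group/dynamic is rejected outright.
        if (weight_quant.strategy in QuantizationStrategy.GROUP and
                weight_quant.actorder in (ActivationOrdering.GROUP,
                                          ActivationOrdering.DYNAMIC)):
            raise ValueError(
                "WNA16MoE is not supported with actorder=group/dynamic.")
        logger.info_once("Using CompressedTensorsWNA16MoEMethod")
        return CompressedTensorsWNA16MoEMethod(quant_config)
    else:
        logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod")
        return CompressedTensorsWNA16MarlinMoEMethod(quant_config)
```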
@@ -705,15 +709,12 @@ def __init__(
                             f"{CompressionFormat.pack_quantized.value} ",
                             "is supported for the following bits: ",
                             f"{WNA16_SUPPORTED_BITS}")
+        self.quant_type = WNA16_SUPPORTED_TYPES_MAP[self.num_bits]

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):

-        assert params_dtype == torch.float16, (
-            "float16 is required for MoE compressed models. Set dtype=torch.float16"  # noqa: E501
-        )
-
        intermediate_size_full = extra_weight_attrs.pop(
            "intermediate_size_full")

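The new `self.quant_type` attribute is what later feeds the `quant_type_id` argument added to the fused Marlin MoE call further down in this diff. Below is a minimal sketch of that lookup; the concrete map contents are an assumption about what `compressed_tensors_wNa16` exports (one 4-bit and one 8-bit scalar type), and only the lookup pattern plus the use of `.id` come from the diff itself. Dropping the `params_dtype == torch.float16` assert in the same hunk matches the removed "MarlinMoE only supports FP16" comment, suggesting the Marlin path now accepts other dtypes as well.

```python
# Minimal, self-contained sketch of the bits -> quant_type lookup added above.
# The concrete scalar types are an assumption (what compressed_tensors_wNa16
# plausibly maps 4 and 8 bits to); only the lookup pattern and the later use of
# quant_type.id come from this diff.
from vllm.scalar_type import scalar_types

ASSUMED_WNA16_TYPES_MAP = {
    4: scalar_types.uint4b8,    # assumed 4-bit weight type
    8: scalar_types.uint8b128,  # assumed 8-bit weight type
}

num_bits = 4
quant_type = ASSUMED_WNA16_TYPES_MAP[num_bits]  # stored as self.quant_type
quant_type_id = quant_type.id                   # later passed as quant_type_id
```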
@@ -837,50 +838,6 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
        layer.marlin_state = GPTQMarlinState.REPACK

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-
-        def replace_tensor(name, new_t):
-            # It is important to use resize_() here since it ensures
-            # the same buffer is reused
-            getattr(layer, name).resize_(new_t.shape)
-            getattr(layer, name).copy_(new_t)
-            del new_t
-
-        def get_scale_perms(num_bits: int):
-            scale_perm: List[int] = []
-            for i in range(8):
-                scale_perm.extend([i + 8 * j for j in range(8)])
-            scale_perm_single: List[int] = []
-            for i in range(4):
-                scale_perm_single.extend(
-                    [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
-            return scale_perm, scale_perm_single
-
-        def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
-                                  group_size: int, num_bits: int):
-            scale_perm, scale_perm_single = get_scale_perms(num_bits)
-            if group_size < size_k and group_size != -1:
-                s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
-            else:
-                s = s.reshape((-1, len(scale_perm_single)))[:,
-                                                            scale_perm_single]
-            s = s.reshape((-1, size_n)).contiguous()
-            return s
-
-        def marlin_moe_permute_scales(s: torch.Tensor, size_k: int,
-                                      size_n: int, group_size: int,
-                                      num_bits: int):
-            num_experts = s.shape[0]
-            output = torch.empty((num_experts, s.shape[1], s.shape[2]),
-                                 device=s.device,
-                                 dtype=s.dtype)
-            for e in range(num_experts):
-                output[e] = marlin_permute_scales(s[e], size_k, size_n,
-                                                  group_size, num_bits)
-            return output
-
-        size_k2 = layer.w2_weight_packed.shape[2]
-        size_k13 = layer.w13_weight_packed.shape[2]
-
        num_experts = layer.w13_weight_g_idx.shape[0]
        device = layer.w13_weight_g_idx.device

@@ -938,33 +895,33 @@ def marlin_moe_permute_scales(s: torch.Tensor, size_k: int,
            layer.w13_weight_packed.shape[2],
            self.num_bits,
        )
-        replace_tensor("w13_weight_packed", marlin_w13_qweight)
+        replace_parameter(layer, "w13_weight_packed", marlin_w13_qweight)
        marlin_w2_qweight = ops.gptq_marlin_moe_repack(
            layer.w2_weight_packed,
            layer.w2_g_idx_sort_indices,
            layer.w2_weight_packed.shape[1] * self.packed_factor,
            layer.w2_weight_packed.shape[2],
            self.num_bits,
        )
-        replace_tensor("w2_weight_packed", marlin_w2_qweight)
+        replace_parameter(layer, "w2_weight_packed", marlin_w2_qweight)
        # Repack scales
        marlin_w13_scales = marlin_moe_permute_scales(
-            layer.w13_weight_scale,
-            size_k13,
-            layer.w13_weight_scale.shape[2],
-            self.group_size,
-            self.num_bits,
+            s=layer.w13_weight_scale,
+            size_k=layer.w13_weight_packed.shape[2],
+            size_n=layer.w13_weight_scale.shape[2],
+            group_size=self.group_size,
        )
-        replace_tensor("w13_weight_scale", marlin_w13_scales)
+        replace_parameter(layer, "w13_weight_scale", marlin_w13_scales)
        marlin_w2_scales = marlin_moe_permute_scales(
-            layer.w2_weight_scale,
-            layer.w2_weight_scale.shape[1] *
+            s=layer.w2_weight_scale,
+            size_k=layer.w2_weight_scale.shape[1] *
            (self.group_size if self.group_size != -1 else self.packed_factor),
-            size_k2,
-            self.group_size,
-            self.num_bits,
+            size_n=layer.w2_weight_scale.shape[2],
+            group_size=self.group_size,
        )
-        replace_tensor("w2_weight_scale", marlin_w2_scales)
+        replace_parameter(layer, "w2_weight_scale", marlin_w2_scales)
+
+        layer.workspace = marlin_make_workspace_new(device, 4)

    def apply(
        self,
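
Because the hunk above interleaves the old positional calls with their keyword-argument replacements, here is the repacked-scales portion in one piece, followed by the new per-layer workspace allocation. This is a sketch assembled from the added lines, not the verbatim final file; `replace_parameter`, `marlin_moe_permute_scales` and `marlin_make_workspace_new` are the helpers imported at the top of this diff, and `layer`, `self` and `device` are already in scope in `process_weights_after_loading`.

```python
# Consolidated view of the new scale-repack path (assembled from the added
# lines above).
marlin_w13_scales = marlin_moe_permute_scales(
    s=layer.w13_weight_scale,
    size_k=layer.w13_weight_packed.shape[2],
    size_n=layer.w13_weight_scale.shape[2],
    group_size=self.group_size,
)
replace_parameter(layer, "w13_weight_scale", marlin_w13_scales)

marlin_w2_scales = marlin_moe_permute_scales(
    s=layer.w2_weight_scale,
    size_k=layer.w2_weight_scale.shape[1] *
    (self.group_size if self.group_size != -1 else self.packed_factor),
    size_n=layer.w2_weight_scale.shape[2],
    group_size=self.group_size,
)
replace_parameter(layer, "w2_weight_scale", marlin_w2_scales)

# One shared Marlin workspace per layer, consumed via layer.workspace in apply().
layer.workspace = marlin_make_workspace_new(device, 4)
```

Switching to the shared `marlin_utils` helpers is also what allows the local `get_scale_perms`/`marlin_permute_scales` copies deleted in the previous hunk to go away.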
@@ -985,10 +942,6 @@ def apply(
        activation: str = "silu",
    ) -> torch.Tensor:
        assert activation == "silu", "Only SiLU activation is supported."
-        if expert_map is not None:
-            raise NotImplementedError(
-                "Expert Parallelism is not supported for "
-                "fused Marlin MoE method.")
        if apply_router_weight_on_input:
            raise NotImplementedError(
                "Apply router weight on input is not supported for "
@@ -1015,11 +968,14 @@ def apply(
            router_logits,
            topk_weights,
            topk_ids,
+            quant_type_id=self.quant_type.id,
+            global_num_experts=global_num_experts,
+            expert_map=expert_map,
            g_idx1=layer.w13_weight_g_idx,
            g_idx2=layer.w2_weight_g_idx,
            sort_indices1=layer.w13_g_idx_sort_indices,
            sort_indices2=layer.w2_g_idx_sort_indices,
-            num_bits=self.num_bits,
+            workspace=layer.workspace,
            is_k_full=self.is_k_full)


@@ -1203,7 +1159,7 @@ def apply(
        activation: str = "silu",
    ) -> torch.Tensor:
        from vllm.model_executor.layers.fused_moe import fused_experts
-        assert activation == "silu", "Only SiLU activation is supported."
+
        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
@@ -1223,6 +1179,7 @@ def apply(
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=True,
+            activation=activation,
            use_int4_w4a16=self.num_bits == 4,
            use_int8_w8a16=self.num_bits == 8,
            global_num_experts=global_num_experts,