
Commit 57651c7

mgoin authored and mawong-amd committed
Update CT WNA16MarlinMoE integration (vllm-project#16666)
Signed-off-by: mgoin <mgoin64@gmail.com>
1 parent 2ef4bb8 commit 57651c7

File tree: 1 file changed, +38 -81 lines changed


vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py

Lines changed: 38 additions & 81 deletions
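Context for the diff below: this code path is exercised whenever vLLM loads a compressed-tensors WNA16 (e.g. W4A16) MoE checkpoint. A minimal sketch of how to hit it, assuming such a checkpoint is on hand (the model name is a placeholder, not something referenced by this commit):

from vllm import LLM

# Any compressed-tensors W4A16 MoE checkpoint routes through the
# CompressedTensorsMoEMethod.get_moe_method() dispatch changed below.
llm = LLM(model="org/placeholder-w4a16-compressed-tensors-moe")
outputs = llm.generate("Hello, Marlin MoE!")
print(outputs[0].outputs[0].text)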
@@ -2,7 +2,7 @@

 import enum
 from enum import Enum
-from typing import Callable, List, Optional
+from typing import Callable, Optional

 import torch
 from compressed_tensors import CompressionFormat
@@ -14,9 +14,12 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
                                                   FusedMoeWeightScaleSupported)
-from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
-    WNA16_SUPPORTED_BITS)
+from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import (  # noqa
+    WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_TYPES_MAP)
 from vllm.model_executor.layers.quantization.utils import replace_parameter
+from vllm.model_executor.layers.quantization.utils.marlin_utils import (
+    check_moe_marlin_supports_layer, marlin_make_workspace_new,
+    marlin_moe_permute_scales)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
 from vllm.model_executor.utils import set_weight_attrs
@@ -54,18 +57,19 @@ def get_moe_method(
             "input_activations")

         if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
-            # Prefer to use the non-marlin kernel when:
-            # 1. Many experts (MarlinMoE gives poor performance when >= 16)
-            # 2. Non-FP16 dtype (MarlinMoE only supports FP16)
-            # 3. Actorder is not group/dynamic (g_idx is unsupported)
-            # 4. Scaled are grouped (channelwise is unsupported)
-            if ((layer.local_num_experts >= 16
-                 or layer.params_dtype != torch.float16) and
-                    weight_quant.actorder not in (ActivationOrdering.GROUP,
-                                                  ActivationOrdering.DYNAMIC)
-                    and weight_quant.strategy in QuantizationStrategy.GROUP):
+            # Prefer to use the MarlinMoE kernel when it is supported.
+            if not check_moe_marlin_supports_layer(layer,
+                                                   weight_quant.group_size):
+                if (weight_quant.strategy in QuantizationStrategy.GROUP and
+                        weight_quant.actorder in (ActivationOrdering.GROUP,
+                                                  ActivationOrdering.DYNAMIC)):
+                    raise ValueError(
+                        "WNA16MoE is not supported with actorder=group/dynamic."
+                    )
+                logger.info_once("Using CompressedTensorsWNA16MoEMethod")
                 return CompressedTensorsWNA16MoEMethod(quant_config)
             else:
+                logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod")
                 return CompressedTensorsWNA16MarlinMoEMethod(quant_config)
         elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)
               and layer.activation == "silu"):
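To make the hunk above easier to follow: the dispatch now asks the shared check_moe_marlin_supports_layer helper whether Marlin can handle the layer, instead of using hand-rolled expert-count/dtype checks. A condensed sketch of the resulting control flow (identifiers other than the wrapper name are taken from the diff; the surrounding class and config plumbing are elided, so treat this as an illustration rather than the exact vLLM source):

def _pick_wna16_moe_method(layer, weight_quant, quant_config):
    # Marlin is preferred whenever the layer's shapes and group size allow it.
    if check_moe_marlin_supports_layer(layer, weight_quant.group_size):
        return CompressedTensorsWNA16MarlinMoEMethod(quant_config)
    # Fallback kernel: it cannot consume g_idx, so activation reordering
    # (actorder=group/dynamic) is rejected up front.
    if (weight_quant.strategy in QuantizationStrategy.GROUP
            and weight_quant.actorder in (ActivationOrdering.GROUP,
                                          ActivationOrdering.DYNAMIC)):
        raise ValueError(
            "WNA16MoE is not supported with actorder=group/dynamic.")
    return CompressedTensorsWNA16MoEMethod(quant_config)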
@@ -705,15 +709,12 @@ def __init__(
                 f"{CompressionFormat.pack_quantized.value} ",
                 "is supported for the following bits: ",
                 f"{WNA16_SUPPORTED_BITS}")
+        self.quant_type = WNA16_SUPPORTED_TYPES_MAP[self.num_bits]

     def create_weights(self, layer: torch.nn.Module, num_experts: int,
                        hidden_size: int, intermediate_size_per_partition: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):

-        assert params_dtype == torch.float16, (
-            "float16 is required for MoE compressed models. Set dtype=torch.float16"  # noqa: E501
-        )
-
         intermediate_size_full = extra_weight_attrs.pop(
             "intermediate_size_full")

@@ -837,50 +838,6 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
         layer.marlin_state = GPTQMarlinState.REPACK

     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-
-        def replace_tensor(name, new_t):
-            # It is important to use resize_() here since it ensures
-            # the same buffer is reused
-            getattr(layer, name).resize_(new_t.shape)
-            getattr(layer, name).copy_(new_t)
-            del new_t
-
-        def get_scale_perms(num_bits: int):
-            scale_perm: List[int] = []
-            for i in range(8):
-                scale_perm.extend([i + 8 * j for j in range(8)])
-            scale_perm_single: List[int] = []
-            for i in range(4):
-                scale_perm_single.extend(
-                    [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
-            return scale_perm, scale_perm_single
-
-        def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
-                                  group_size: int, num_bits: int):
-            scale_perm, scale_perm_single = get_scale_perms(num_bits)
-            if group_size < size_k and group_size != -1:
-                s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
-            else:
-                s = s.reshape((-1, len(scale_perm_single)))[:,
-                                                            scale_perm_single]
-            s = s.reshape((-1, size_n)).contiguous()
-            return s
-
-        def marlin_moe_permute_scales(s: torch.Tensor, size_k: int,
-                                      size_n: int, group_size: int,
-                                      num_bits: int):
-            num_experts = s.shape[0]
-            output = torch.empty((num_experts, s.shape[1], s.shape[2]),
-                                 device=s.device,
-                                 dtype=s.dtype)
-            for e in range(num_experts):
-                output[e] = marlin_permute_scales(s[e], size_k, size_n,
-                                                  group_size, num_bits)
-            return output
-
-        size_k2 = layer.w2_weight_packed.shape[2]
-        size_k13 = layer.w13_weight_packed.shape[2]
-
         num_experts = layer.w13_weight_g_idx.shape[0]
         device = layer.w13_weight_g_idx.device

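The local helpers deleted above are not reimplemented in this file: the import hunk near the top now pulls marlin_moe_permute_scales (alongside check_moe_marlin_supports_layer and marlin_make_workspace_new) from marlin_utils, and the call sites in the next hunk switch to its keyword signature. Roughly, the shared helper is used like this (a sketch based on the call shown in the following hunk, not on the marlin_utils internals):

# Per-expert permutation of group scales into Marlin's tile layout,
# now performed by the shared utility instead of a local closure.
marlin_w13_scales = marlin_moe_permute_scales(
    s=layer.w13_weight_scale,
    size_k=layer.w13_weight_packed.shape[2],
    size_n=layer.w13_weight_scale.shape[2],
    group_size=self.group_size,
)
replace_parameter(layer, "w13_weight_scale", marlin_w13_scales)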
@@ -938,33 +895,33 @@ def marlin_moe_permute_scales(s: torch.Tensor, size_k: int,
             layer.w13_weight_packed.shape[2],
             self.num_bits,
         )
-        replace_tensor("w13_weight_packed", marlin_w13_qweight)
+        replace_parameter(layer, "w13_weight_packed", marlin_w13_qweight)
         marlin_w2_qweight = ops.gptq_marlin_moe_repack(
             layer.w2_weight_packed,
             layer.w2_g_idx_sort_indices,
             layer.w2_weight_packed.shape[1] * self.packed_factor,
             layer.w2_weight_packed.shape[2],
             self.num_bits,
         )
-        replace_tensor("w2_weight_packed", marlin_w2_qweight)
+        replace_parameter(layer, "w2_weight_packed", marlin_w2_qweight)
         # Repack scales
         marlin_w13_scales = marlin_moe_permute_scales(
-            layer.w13_weight_scale,
-            size_k13,
-            layer.w13_weight_scale.shape[2],
-            self.group_size,
-            self.num_bits,
+            s=layer.w13_weight_scale,
+            size_k=layer.w13_weight_packed.shape[2],
+            size_n=layer.w13_weight_scale.shape[2],
+            group_size=self.group_size,
         )
-        replace_tensor("w13_weight_scale", marlin_w13_scales)
+        replace_parameter(layer, "w13_weight_scale", marlin_w13_scales)
         marlin_w2_scales = marlin_moe_permute_scales(
-            layer.w2_weight_scale,
-            layer.w2_weight_scale.shape[1] *
+            s=layer.w2_weight_scale,
+            size_k=layer.w2_weight_scale.shape[1] *
             (self.group_size if self.group_size != -1 else self.packed_factor),
-            size_k2,
-            self.group_size,
-            self.num_bits,
+            size_n=layer.w2_weight_scale.shape[2],
+            group_size=self.group_size,
         )
-        replace_tensor("w2_weight_scale", marlin_w2_scales)
+        replace_parameter(layer, "w2_weight_scale", marlin_w2_scales)
+
+        layer.workspace = marlin_make_workspace_new(device, 4)

     def apply(
         self,
@@ -985,10 +942,6 @@ def apply(
         activation: str = "silu",
     ) -> torch.Tensor:
         assert activation == "silu", "Only SiLU activation is supported."
-        if expert_map is not None:
-            raise NotImplementedError(
-                "Expert Parallelism is not supported for "
-                "fused Marlin MoE method.")
         if apply_router_weight_on_input:
             raise NotImplementedError(
                 "Apply router weight on input is not supported for "
@@ -1015,11 +968,14 @@ def apply(
             router_logits,
             topk_weights,
             topk_ids,
+            quant_type_id=self.quant_type.id,
+            global_num_experts=global_num_experts,
+            expert_map=expert_map,
             g_idx1=layer.w13_weight_g_idx,
             g_idx2=layer.w2_weight_g_idx,
             sort_indices1=layer.w13_g_idx_sort_indices,
             sort_indices2=layer.w2_g_idx_sort_indices,
-            num_bits=self.num_bits,
+            workspace=layer.workspace,
             is_k_full=self.is_k_full)


@@ -1203,7 +1159,7 @@ def apply(
         activation: str = "silu",
     ) -> torch.Tensor:
         from vllm.model_executor.layers.fused_moe import fused_experts
-        assert activation == "silu", "Only SiLU activation is supported."
+
         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
             router_logits=router_logits,
@@ -1223,6 +1179,7 @@ def apply(
             topk_weights=topk_weights,
             topk_ids=topk_ids,
             inplace=True,
+            activation=activation,
             use_int4_w4a16=self.num_bits == 4,
             use_int8_w8a16=self.num_bits == 8,
             global_num_experts=global_num_experts,
