|
5 | 5 | import torch
|
6 | 6 |
|
7 | 7 | from vllm import _custom_ops as ops
|
8 |
| -from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase |
| 8 | +from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, |
| 9 | + FusedMoeWeightScaleSupported) |
9 | 10 | from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
10 | 11 | WNA16_SUPPORTED_BITS)
|
11 | 12 | from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
|
12 |
| - CompressionFormat) |
| 13 | + CompressionFormat, QuantizationStrategy) |
| 14 | +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( |
| 15 | + all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize) |
13 | 16 | from vllm.model_executor.utils import set_weight_attrs
|
| 17 | +from vllm.utils import is_hip, print_warning_once |
14 | 18 |
|
15 | 19 |
|
16 | 20 | class GPTQMarlinState(Enum):
|
17 | 21 | REPACK = enum.auto()
|
18 | 22 | READY = enum.auto()
|
19 | 23 |
|
20 | 24 |
|
21 |
| -__all__ = ["CompressedTensorsMoEMethod"] |
| 25 | +__all__ = [ |
| 26 | + "CompressedTensorsMoEMethod", "CompressedTensorsW8A8Fp8MoEMethod", |
| 27 | + "CompressedTensorsWNA16MoEMethod" |
| 28 | +] |
22 | 29 |
|
23 | 30 |
|
24 | 31 | class CompressedTensorsMoEMethod(FusedMoEMethodBase):
|
25 | 32 |
|
| 33 | + @staticmethod |
| 34 | + def get_moe_method( |
| 35 | + quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501 |
| 36 | + ) -> "CompressedTensorsMoEMethod": |
| 37 | + # TODO: @dsikka: refactor this to use schemes as other kernels |
| 38 | + # are supported + check if the layer is being ignored. |
| 39 | + weight_quant = quant_config.target_scheme_map["Linear"].get("weights") |
| 40 | + input_quant = quant_config.target_scheme_map["Linear"].get( |
| 41 | + "input_activations") |
| 42 | + |
| 43 | + if quant_config._is_wNa16_group_channel(weight_quant, input_quant): |
| 44 | + return CompressedTensorsWNA16MoEMethod(quant_config) |
| 45 | + elif quant_config._is_fp8_w8a8(weight_quant, input_quant): |
| 46 | + return CompressedTensorsW8A8Fp8MoEMethod(quant_config) |
| 47 | + else: |
| 48 | + raise RuntimeError( |
| 49 | + f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}") |
| 50 | + |
| 51 | + |
| 52 | +class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): |
| 53 | + |
| 54 | + def __init__( |
| 55 | + self, |
| 56 | + quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501 |
| 57 | + ): |
| 58 | + self.quant_config = quant_config |
| 59 | + self.weight_quant = self.quant_config.target_scheme_map["Linear"].get( |
| 60 | + "weights") |
| 61 | + self.input_quant = self.quant_config.target_scheme_map["Linear"].get( |
| 62 | + "input_activations") |
| 63 | + |
| 64 | + if not (self.weight_quant.strategy == QuantizationStrategy.TENSOR |
| 65 | + and self.input_quant.strategy == QuantizationStrategy.TENSOR): |
| 66 | + raise ValueError( |
| 67 | + "For FP8 Fused MoE layers, only per-tensor scales" |
| 68 | + "for weights and activations are supported. Found " |
| 69 | + f"{self.weight_quant}, {self.input_quant}") |
| 70 | + |
| 71 | + self.static_input_scales = not self.input_quant.dynamic |
| 72 | + |
| 73 | + def create_weights(self, layer: torch.nn.Module, num_experts: int, |
| 74 | + hidden_size: int, intermediate_size: int, |
| 75 | + params_dtype: torch.dtype, **extra_weight_attrs): |
| 76 | + |
| 77 | + params_dtype = torch.float8_e4m3fn |
| 78 | + |
| 79 | + # WEIGHTS |
| 80 | + w13_weight = torch.nn.Parameter(torch.empty(num_experts, |
| 81 | + 2 * intermediate_size, |
| 82 | + hidden_size, |
| 83 | + dtype=params_dtype), |
| 84 | + requires_grad=False) |
| 85 | + layer.register_parameter("w13_weight", w13_weight) |
| 86 | + set_weight_attrs(w13_weight, extra_weight_attrs) |
| 87 | + |
| 88 | + w2_weight = torch.nn.Parameter(torch.empty(num_experts, |
| 89 | + hidden_size, |
| 90 | + intermediate_size, |
| 91 | + dtype=params_dtype), |
| 92 | + requires_grad=False) |
| 93 | + layer.register_parameter("w2_weight", w2_weight) |
| 94 | + set_weight_attrs(w2_weight, extra_weight_attrs) |
| 95 | + |
| 96 | + # WEIGHT_SCALES |
| 97 | + # Allocate 2 scales for w1 and w3 respectively. |
| 98 | + # They will be combined to a single scale after weight loading. |
| 99 | + w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts, |
| 100 | + 2, |
| 101 | + dtype=torch.float32), |
| 102 | + requires_grad=False) |
| 103 | + layer.register_parameter("w13_weight_scale", w13_weight_scale) |
| 104 | + |
| 105 | + w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts, |
| 106 | + dtype=torch.float32), |
| 107 | + requires_grad=False) |
| 108 | + layer.register_parameter("w2_weight_scale", w2_weight_scale) |
| 109 | + # Add the quantization method used (per tensor/grouped/channel) |
| 110 | + # to ensure the weight scales are loaded in properly |
| 111 | + extra_weight_attrs.update( |
| 112 | + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) |
| 113 | + set_weight_attrs(w13_weight_scale, extra_weight_attrs) |
| 114 | + set_weight_attrs(w2_weight_scale, extra_weight_attrs) |
| 115 | + |
| 116 | + # INPUT_SCALES |
| 117 | + if self.static_input_scales: |
| 118 | + w13_input_scale = torch.nn.Parameter(torch.ones( |
| 119 | + num_experts, dtype=torch.float32), |
| 120 | + requires_grad=False) |
| 121 | + layer.register_parameter("w13_input_scale", w13_input_scale) |
| 122 | + set_weight_attrs(w13_input_scale, extra_weight_attrs) |
| 123 | + |
| 124 | + w2_input_scale = torch.nn.Parameter(torch.ones( |
| 125 | + num_experts, dtype=torch.float32), |
| 126 | + requires_grad=False) |
| 127 | + layer.register_parameter("w2_input_scale", w2_input_scale) |
| 128 | + set_weight_attrs(w2_input_scale, extra_weight_attrs) |
| 129 | + else: |
| 130 | + layer.w13_input_scale = None |
| 131 | + layer.w2_input_scale = None |
| 132 | + |
| 133 | + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: |
| 134 | + # Fp8 moe kernels require a single activation scale. |
| 135 | + # We take the max of all the scales in case they differ. |
| 136 | + if self.static_input_scales: |
| 137 | + if (layer.w13_input_scale is None or layer.w2_input_scale is None): |
| 138 | + raise ValueError( |
| 139 | + "QuantConfig has static quantization, but found " |
| 140 | + "activation scales are None.") |
| 141 | + if (not all_close_1d(layer.w13_input_scale) |
| 142 | + or not all_close_1d(layer.w2_input_scale)): |
| 143 | + print_warning_once( |
| 144 | + "Found input_scales that are not equal for " |
| 145 | + "fp8 MoE layer. Using the maximum across experts " |
| 146 | + "for each layer. ") |
| 147 | + layer.w13_input_scale = torch.nn.Parameter( |
| 148 | + layer.w13_input_scale.max(), requires_grad=False) |
| 149 | + layer.w2_input_scale = torch.nn.Parameter( |
| 150 | + layer.w2_input_scale.max(), requires_grad=False) |
| 151 | + |
| 152 | + # If rocm, normalize the weights and scales to e4m3fnuz |
| 153 | + if is_hip(): |
| 154 | + # Normalize the weights and scales |
| 155 | + w13_weight, w13_weight_scale, w13_input_scale = \ |
| 156 | + normalize_e4m3fn_to_e4m3fnuz( |
| 157 | + layer.w13_weight, layer.w13_weight_scale, |
| 158 | + layer.w13_input_scale) |
| 159 | + w2_weight, w2_weight_scale, w2_input_scale = \ |
| 160 | + normalize_e4m3fn_to_e4m3fnuz( |
| 161 | + layer.w2_weight, layer.w2_weight_scale, |
| 162 | + layer.w2_input_scale) |
| 163 | + # Reset the parameter |
| 164 | + layer.w13_weight = torch.nn.Parameter(w13_weight, |
| 165 | + requires_grad=False) |
| 166 | + layer.w13_weight_scale = torch.nn.Parameter(w13_weight_scale, |
| 167 | + requires_grad=False) |
| 168 | + if w13_input_scale is not None: |
| 169 | + layer.w13_input_scale = torch.nn.Parameter(w13_input_scale, |
| 170 | + requires_grad=False) |
| 171 | + layer.w2_weight = torch.nn.Parameter(w2_weight, |
| 172 | + requires_grad=False) |
| 173 | + layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale, |
| 174 | + requires_grad=False) |
| 175 | + if w2_input_scale is not None: |
| 176 | + layer.w2_input_scale = torch.nn.Parameter(w2_input_scale, |
| 177 | + requires_grad=False) |
| 178 | + |
| 179 | + # Fp8 moe kernel needs single weight scale for w13 per expert. |
| 180 | + # We take the max then dequant and requant each expert. |
| 181 | + assert layer.w13_weight_scale is not None |
| 182 | + shard_size = layer.intermediate_size_per_partition |
| 183 | + max_w13_scales = layer.w13_weight_scale.max(dim=1).values |
| 184 | + for expert_id in range(layer.num_experts): |
| 185 | + start = 0 |
| 186 | + for shard_id in range(2): |
| 187 | + dq_weight = per_tensor_dequantize( |
| 188 | + layer.w13_weight[expert_id][start:start + shard_size, :], |
| 189 | + layer.w13_weight_scale[expert_id][shard_id]) |
| 190 | + layer.w13_weight[expert_id][ |
| 191 | + start:start + shard_size, :], _ = ops.scaled_fp8_quant( |
| 192 | + dq_weight, max_w13_scales[expert_id]) |
| 193 | + start += shard_size |
| 194 | + |
| 195 | + layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, |
| 196 | + requires_grad=False) |
| 197 | + |
| 198 | + def apply( |
| 199 | + self, |
| 200 | + layer: torch.nn.Module, |
| 201 | + x: torch.Tensor, |
| 202 | + router_logits: torch.Tensor, |
| 203 | + top_k: int, |
| 204 | + renormalize: bool = True, |
| 205 | + use_grouped_topk: bool = False, |
| 206 | + num_expert_group: Optional[int] = None, |
| 207 | + topk_group: Optional[int] = None, |
| 208 | + custom_routing_function: Optional[Callable] = None, |
| 209 | + ) -> torch.Tensor: |
| 210 | + |
| 211 | + from vllm.model_executor.layers.fused_moe import fused_experts |
| 212 | + |
| 213 | + topk_weights, topk_ids = FusedMoE.select_experts( |
| 214 | + hidden_states=x, |
| 215 | + router_logits=router_logits, |
| 216 | + use_grouped_topk=use_grouped_topk, |
| 217 | + top_k=top_k, |
| 218 | + renormalize=renormalize, |
| 219 | + topk_group=topk_group, |
| 220 | + num_expert_group=num_expert_group, |
| 221 | + custom_routing_function=custom_routing_function) |
| 222 | + |
| 223 | + return fused_experts(x, |
| 224 | + layer.w13_weight, |
| 225 | + layer.w2_weight, |
| 226 | + topk_weights=topk_weights, |
| 227 | + topk_ids=topk_ids, |
| 228 | + inplace=True, |
| 229 | + use_fp8_w8a8=True, |
| 230 | + w1_scale=layer.w13_weight_scale, |
| 231 | + w2_scale=layer.w2_weight_scale, |
| 232 | + a1_scale=layer.w13_input_scale, |
| 233 | + a2_scale=layer.w2_input_scale) |
| 234 | + |
| 235 | + |
| 236 | +class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): |
| 237 | + |
26 | 238 | def __init__(
|
27 | 239 | self,
|
28 | 240 | quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
|
|
0 commit comments