
Commit 4029180

mgoin authored and LeiWang1999 committed
[Misc] Support FP8 MoE for compressed-tensors (vllm-project#8588)
Signed-off-by: LeiWang1999 <leiwang1999@outlook.com>
1 parent f61c984 commit 4029180

File tree: 5 files changed (+226, -8 lines)

tests/weight_loading/models-large.txt

Lines changed: 1 addition & 0 deletions

@@ -1,4 +1,5 @@
 compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main
 compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main
 compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W8A16-quantized, main
+compressed-tensors, mgoin/DeepSeek-Coder-V2-Lite-Instruct-FP8, main
 gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, main

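This new entry exercises an FP8 compressed-tensors MoE checkpoint in the weight-loading tests. As a rough usage sketch (not part of the commit), the same checkpoint can be loaded with the offline LLM API; the explicit quantization argument and trust_remote_code flag are assumptions about typical usage, since compressed-tensors checkpoints are normally auto-detected from their config:

from vllm import LLM, SamplingParams

# Sketch only: load the FP8 compressed-tensors MoE checkpoint added above.
llm = LLM(model="mgoin/DeepSeek-Coder-V2-Lite-Instruct-FP8",
          quantization="compressed-tensors",  # optional, normally auto-detected
          trust_remote_code=True)             # DeepSeek-V2 models use remote code
out = llm.generate(["def fibonacci(n):"], SamplingParams(max_tokens=32))
print(out[0].outputs[0].text)
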
vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 7 additions & 2 deletions

@@ -323,10 +323,12 @@ def weight_loader(self, param: torch.nn.Parameter,
                       loaded_weight: torch.Tensor, weight_name: str,
                       shard_id: str, expert_id: int) -> None:
 
-        # compressed-tensors represents weights on disk which are flipped
+        # compressed-tensors checkpoints with packed weights are stored flipped
+        # TODO (mgoin): check self.quant_method.quant_config.quant_format
+        # against known CompressionFormat enum values that have this quality
         loaded_weight = loaded_weight.t().contiguous() if (
             self.quant_method.__class__.__name__
-            == "CompressedTensorsMoEMethod") else loaded_weight
+            == "CompressedTensorsWNA16MoEMethod") else loaded_weight
 
         if shard_id not in ("w1", "w2", "w3"):
             raise ValueError(f"shard_id must be ['w1','w2','w3'] but "

@@ -353,6 +355,9 @@ def weight_loader(self, param: torch.nn.Parameter,
 
         # Case input scale: input_scale loading is only supported for fp8
         if "input_scale" in weight_name:
+            # this is needed for compressed-tensors only
+            loaded_weight = loaded_weight.to(param.data.device)
+
             if param.data[expert_id] != 1 and (param.data[expert_id] -
                                                loaded_weight).abs() > 1e-5:
                 raise ValueError(

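The renamed check narrows the on-disk transpose to the packed WNA16 MoE method; FP8 compressed-tensors weights now load without flipping. A minimal sketch of that condition with a hypothetical helper name (the class names come from the diff):

import torch

def maybe_unflip(loaded_weight: torch.Tensor, quant_method_name: str) -> torch.Tensor:
    # Hypothetical helper mirroring the condition above: packed
    # compressed-tensors (WNA16) checkpoints store MoE weights transposed,
    # so only that method flips them back; the FP8 path loads them as-is.
    if quant_method_name == "CompressedTensorsWNA16MoEMethod":
        return loaded_weight.t().contiguous()
    return loaded_weight
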
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py

Lines changed: 1 addition & 1 deletion

@@ -73,7 +73,7 @@ def get_quant_method(
         if isinstance(layer, Attention):
             return CompressedTensorsKVCacheMethod(self)
         if isinstance(layer, FusedMoE):
-            return CompressedTensorsMoEMethod(self)
+            return CompressedTensorsMoEMethod.get_moe_method(self)
         return None
 
     @classmethod

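get_quant_method now defers to a factory that inspects the checkpoint's "Linear" target scheme instead of always returning one MoE method. A toy sketch of the decision it makes, using mock scheme objects rather than the real CompressedTensorsConfig; the actual predicates (_is_wNa16_group_channel, _is_fp8_w8a8) check more fields than shown here:

from dataclasses import dataclass
from typing import Optional

@dataclass
class MockScheme:          # stand-in for a compressed-tensors quantization scheme
    type: str              # "int" for WNA16, "float" for FP8
    num_bits: int
    strategy: str          # "tensor", "channel", or "group"

def moe_method_name(weights: MockScheme, acts: Optional[MockScheme]) -> str:
    # Roughly: weight-only int schemes (group/channel) -> WNA16 MoE method,
    # per-tensor FP8 weights + activations -> the new FP8 MoE method.
    if weights.type == "int" and acts is None and weights.strategy in ("group", "channel"):
        return "CompressedTensorsWNA16MoEMethod"
    if weights.type == "float" and weights.num_bits == 8 and acts is not None:
        return "CompressedTensorsW8A8Fp8MoEMethod"
    raise RuntimeError("Unsupported FusedMoe scheme")
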
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py

Lines changed: 215 additions & 3 deletions

@@ -5,24 +5,236 @@
 import torch
 
 from vllm import _custom_ops as ops
-from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase
+from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
+                                                   FusedMoeWeightScaleSupported)
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     WNA16_SUPPORTED_BITS)
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
-    CompressionFormat)
+    CompressionFormat, QuantizationStrategy)
+from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+    all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
 from vllm.model_executor.utils import set_weight_attrs
+from vllm.utils import is_hip, print_warning_once
 
 
 class GPTQMarlinState(Enum):
     REPACK = enum.auto()
     READY = enum.auto()
 
 
-__all__ = ["CompressedTensorsMoEMethod"]
+__all__ = [
+    "CompressedTensorsMoEMethod", "CompressedTensorsW8A8Fp8MoEMethod",
+    "CompressedTensorsWNA16MoEMethod"
+]
 
 
 class CompressedTensorsMoEMethod(FusedMoEMethodBase):
 
+    @staticmethod
+    def get_moe_method(
+        quant_config: "CompressedTensorsConfig"  # type: ignore # noqa E501
+    ) -> "CompressedTensorsMoEMethod":
+        # TODO: @dsikka: refactor this to use schemes as other kernels
+        # are supported + check if the layer is being ignored.
+        weight_quant = quant_config.target_scheme_map["Linear"].get("weights")
+        input_quant = quant_config.target_scheme_map["Linear"].get(
+            "input_activations")
+
+        if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
+            return CompressedTensorsWNA16MoEMethod(quant_config)
+        elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
+            return CompressedTensorsW8A8Fp8MoEMethod(quant_config)
+        else:
+            raise RuntimeError(
+                f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}")
+
+
+class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
+
+    def __init__(
+            self,
+            quant_config: "CompressedTensorsConfig"  # type: ignore # noqa E501
+    ):
+        self.quant_config = quant_config
+        self.weight_quant = self.quant_config.target_scheme_map["Linear"].get(
+            "weights")
+        self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
+            "input_activations")
+
+        if not (self.weight_quant.strategy == QuantizationStrategy.TENSOR
+                and self.input_quant.strategy == QuantizationStrategy.TENSOR):
+            raise ValueError(
+                "For FP8 Fused MoE layers, only per-tensor scales"
+                "for weights and activations are supported. Found "
+                f"{self.weight_quant}, {self.input_quant}")
+
+        self.static_input_scales = not self.input_quant.dynamic
+
+    def create_weights(self, layer: torch.nn.Module, num_experts: int,
+                       hidden_size: int, intermediate_size: int,
+                       params_dtype: torch.dtype, **extra_weight_attrs):
+
+        params_dtype = torch.float8_e4m3fn
+
+        # WEIGHTS
+        w13_weight = torch.nn.Parameter(torch.empty(num_experts,
+                                                    2 * intermediate_size,
+                                                    hidden_size,
+                                                    dtype=params_dtype),
+                                        requires_grad=False)
+        layer.register_parameter("w13_weight", w13_weight)
+        set_weight_attrs(w13_weight, extra_weight_attrs)
+
+        w2_weight = torch.nn.Parameter(torch.empty(num_experts,
+                                                   hidden_size,
+                                                   intermediate_size,
+                                                   dtype=params_dtype),
+                                       requires_grad=False)
+        layer.register_parameter("w2_weight", w2_weight)
+        set_weight_attrs(w2_weight, extra_weight_attrs)
+
+        # WEIGHT_SCALES
+        # Allocate 2 scales for w1 and w3 respectively.
+        # They will be combined to a single scale after weight loading.
+        w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
+                                                         2,
+                                                         dtype=torch.float32),
+                                              requires_grad=False)
+        layer.register_parameter("w13_weight_scale", w13_weight_scale)
+
+        w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
+                                                        dtype=torch.float32),
+                                             requires_grad=False)
+        layer.register_parameter("w2_weight_scale", w2_weight_scale)
+        # Add the quantization method used (per tensor/grouped/channel)
+        # to ensure the weight scales are loaded in properly
+        extra_weight_attrs.update(
+            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
+        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+        set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+
+        # INPUT_SCALES
+        if self.static_input_scales:
+            w13_input_scale = torch.nn.Parameter(torch.ones(
+                num_experts, dtype=torch.float32),
+                                                 requires_grad=False)
+            layer.register_parameter("w13_input_scale", w13_input_scale)
+            set_weight_attrs(w13_input_scale, extra_weight_attrs)
+
+            w2_input_scale = torch.nn.Parameter(torch.ones(
+                num_experts, dtype=torch.float32),
+                                                requires_grad=False)
+            layer.register_parameter("w2_input_scale", w2_input_scale)
+            set_weight_attrs(w2_input_scale, extra_weight_attrs)
+        else:
+            layer.w13_input_scale = None
+            layer.w2_input_scale = None
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        # Fp8 moe kernels require a single activation scale.
+        # We take the max of all the scales in case they differ.
+        if self.static_input_scales:
+            if (layer.w13_input_scale is None or layer.w2_input_scale is None):
+                raise ValueError(
+                    "QuantConfig has static quantization, but found "
+                    "activation scales are None.")
+            if (not all_close_1d(layer.w13_input_scale)
+                    or not all_close_1d(layer.w2_input_scale)):
+                print_warning_once(
+                    "Found input_scales that are not equal for "
+                    "fp8 MoE layer. Using the maximum across experts "
+                    "for each layer. ")
+            layer.w13_input_scale = torch.nn.Parameter(
+                layer.w13_input_scale.max(), requires_grad=False)
+            layer.w2_input_scale = torch.nn.Parameter(
+                layer.w2_input_scale.max(), requires_grad=False)
+
+        # If rocm, normalize the weights and scales to e4m3fnuz
+        if is_hip():
+            # Normalize the weights and scales
+            w13_weight, w13_weight_scale, w13_input_scale = \
+                normalize_e4m3fn_to_e4m3fnuz(
+                    layer.w13_weight, layer.w13_weight_scale,
+                    layer.w13_input_scale)
+            w2_weight, w2_weight_scale, w2_input_scale = \
+                normalize_e4m3fn_to_e4m3fnuz(
+                    layer.w2_weight, layer.w2_weight_scale,
+                    layer.w2_input_scale)
+            # Reset the parameter
+            layer.w13_weight = torch.nn.Parameter(w13_weight,
+                                                  requires_grad=False)
+            layer.w13_weight_scale = torch.nn.Parameter(w13_weight_scale,
+                                                        requires_grad=False)
+            if w13_input_scale is not None:
+                layer.w13_input_scale = torch.nn.Parameter(w13_input_scale,
+                                                           requires_grad=False)
+            layer.w2_weight = torch.nn.Parameter(w2_weight,
+                                                 requires_grad=False)
+            layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale,
+                                                       requires_grad=False)
+            if w2_input_scale is not None:
+                layer.w2_input_scale = torch.nn.Parameter(w2_input_scale,
+                                                          requires_grad=False)
+
+        # Fp8 moe kernel needs single weight scale for w13 per expert.
+        # We take the max then dequant and requant each expert.
+        assert layer.w13_weight_scale is not None
+        shard_size = layer.intermediate_size_per_partition
+        max_w13_scales = layer.w13_weight_scale.max(dim=1).values
+        for expert_id in range(layer.num_experts):
+            start = 0
+            for shard_id in range(2):
+                dq_weight = per_tensor_dequantize(
+                    layer.w13_weight[expert_id][start:start + shard_size, :],
+                    layer.w13_weight_scale[expert_id][shard_id])
+                layer.w13_weight[expert_id][
+                    start:start + shard_size, :], _ = ops.scaled_fp8_quant(
+                        dq_weight, max_w13_scales[expert_id])
+                start += shard_size
+
+        layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
+                                                    requires_grad=False)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool = True,
+        use_grouped_topk: bool = False,
+        num_expert_group: Optional[int] = None,
+        topk_group: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+    ) -> torch.Tensor:
+
+        from vllm.model_executor.layers.fused_moe import fused_experts
+
+        topk_weights, topk_ids = FusedMoE.select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function)
+
+        return fused_experts(x,
+                             layer.w13_weight,
+                             layer.w2_weight,
+                             topk_weights=topk_weights,
+                             topk_ids=topk_ids,
+                             inplace=True,
+                             use_fp8_w8a8=True,
+                             w1_scale=layer.w13_weight_scale,
+                             w2_scale=layer.w2_weight_scale,
+                             a1_scale=layer.w13_input_scale,
+                             a2_scale=layer.w2_input_scale)
+
+
+class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
+
     def __init__(
             self,
             quant_config: "CompressedTensorsConfig"  # type: ignore # noqa E501

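The core of the new FP8 path is process_weights_after_loading: w1 and w3 arrive with separate per-tensor scales, but the fused kernel wants one scale per expert, so each shard is dequantized with its own scale and requantized against the max of the two. A self-contained sketch of that step in plain torch (FP8_MAX and requantize_to_shared_scale are illustrative stand-ins for vLLM's per_tensor_dequantize and ops.scaled_fp8_quant):

import torch

FP8_MAX = 448.0  # largest magnitude representable in float8_e4m3fn

def requantize_to_shared_scale(w1_q, s1, w3_q, s3):
    # Dequantize each FP8 shard with its own per-tensor scale, then
    # requantize both against the shared (max) scale the fused kernel expects.
    shared = torch.maximum(s1, s3)
    shards = []
    for w_q, s in ((w1_q, s1), (w3_q, s3)):
        dq = w_q.to(torch.float32) * s                    # dequantize
        rq = torch.clamp(dq / shared, -FP8_MAX, FP8_MAX)  # requantize
        shards.append(rq.to(torch.float8_e4m3fn))
    return torch.cat(shards, dim=0), shared

# Toy usage: two shards quantized with different scales end up sharing one.
w1, w3 = torch.randn(8, 16), 3.0 * torch.randn(8, 16)
s1, s3 = w1.abs().max() / FP8_MAX, w3.abs().max() / FP8_MAX
w13_q, s13 = requantize_to_shared_scale((w1 / s1).to(torch.float8_e4m3fn), s1,
                                        (w3 / s3).to(torch.float8_e4m3fn), s3)
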
vllm/model_executor/models/phimoe.py

Lines changed: 2 additions & 2 deletions

@@ -321,13 +321,13 @@ def __init__(
             self.total_num_heads,
             self.total_num_kv_heads,
             bias=True,
-            quant_config=None,
+            quant_config=quant_config,
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=True,
-            quant_config=None,
+            quant_config=quant_config,
         )
         self.rotary_emb = get_rope(
             self.head_dim,
