
Commit ad6b8bc

mgoin authored and mawong-amd committed
Fix FBGEMM integration (vllm-project#18002)
Signed-off-by: mgoin <mgoin64@gmail.com>
1 parent 09adc4a · commit ad6b8bc

2 files changed: +11 −13 lines changed


vllm/model_executor/layers/quantization/fbgemm_fp8.py

Lines changed: 3 additions & 1 deletion
@@ -63,7 +63,9 @@ def from_config(cls, config: Dict[str, Any]) -> "FBGEMMFp8Config":
     def get_quant_method(self, layer: torch.nn.Module,
                          prefix: str) -> Optional["QuantizeMethodBase"]:
         if isinstance(layer, LinearBase):
-            if is_layer_skipped(prefix, self.ignore_list):
+            if is_layer_skipped(prefix=prefix,
+                                ignored_layers=self.ignore_list,
+                                fused_mapping=self.packed_modules_mapping):
                 return UnquantizedLinearMethod()
             return FBGEMMFp8LinearMethod(self)
         return None
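
What the new fused_mapping argument buys: when modules such as q_proj/k_proj/v_proj are fused into a single qkv_proj at load time, the checkpoint's ignore list still names the unfused shards, so the skip check has to translate the fused prefix before matching. A minimal sketch of that idea follows; the helper name and matching rules are illustrative assumptions, not vLLM's actual is_layer_skipped implementation.

from typing import Mapping, Sequence


def is_layer_skipped_sketch(prefix: str,
                            ignored_layers: Sequence[str],
                            fused_mapping: Mapping[str, Sequence[str]]) -> bool:
    """Hypothetical fused-aware skip check (illustration only)."""
    module_name = prefix.split(".")[-1]
    if module_name in fused_mapping:
        # Map the fused module back to its unfused shard names and treat the
        # layer as skipped only if every shard appears in the ignore list.
        shard_prefixes = [
            prefix.replace(module_name, shard)
            for shard in fused_mapping[module_name]
        ]
        return all(
            any(ignored in shard for ignored in ignored_layers)
            for shard in shard_prefixes)
    return any(ignored in prefix for ignored in ignored_layers)


# Example: the fused qkv_proj is skipped because all three shards are ignored.
print(is_layer_skipped_sketch(
    "model.layers.0.self_attn.qkv_proj",
    ignored_layers=["model.layers.0.self_attn.q_proj",
                    "model.layers.0.self_attn.k_proj",
                    "model.layers.0.self_attn.v_proj"],
    fused_mapping={"qkv_proj": ["q_proj", "k_proj", "v_proj"]}))  # True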

vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py

Lines changed: 8 additions & 12 deletions
@@ -86,6 +86,7 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module,

     part_size_n = layer.output_size_per_partition
     part_size_k = layer.input_size_per_partition
+    weight_block_size = getattr(layer, "weight_block_size", None)

     if size_k_first:
         assert layer.weight.shape == (part_size_k, part_size_n)
@@ -119,14 +120,11 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module,
         scales = layer.weight_scale_inv.to(layer.orig_dtype)
         del layer.weight_scale_inv

-    if layer.weight_block_size is None:
-        group_size = -1
-    else:
-        group_size = layer.weight_block_size[1]
+    group_size = -1 if weight_block_size is None else weight_block_size[1]

     # marlin kernel only support channel-wise and group-wise quantization
     # we need to convert the scales
-    if layer.weight_block_size is None:
+    if weight_block_size is None:
         if scales.nelement() == 1:
             # tensor-wise quantization -> channel-wise quantization
             # (1, 1) =>(repeat)=> (1, size_n)
@@ -149,7 +147,7 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module,
         # =>(repeat)=> (size_k // block_size[1], size_n)
         if not size_k_first:
             scales = scales.T.contiguous()
-        block_n = layer.weight_block_size[0]
+        block_n = weight_block_size[0]
         scales = scales.repeat_interleave(block_n, 1)
         # size_n may not divisible by block_size[0]
         scales = scales[:, :part_size_n]
@@ -173,6 +171,7 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module,
     e = layer.num_experts
     k = layer.hidden_size
     n = layer.intermediate_size_per_partition
+    weight_block_size = getattr(layer, "weight_block_size", None)

     # WORKSPACE
     device = layer.w13_weight.device
@@ -213,10 +212,7 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module,

     # WEIGHT SCALES
     # Permute scales
-    if layer.weight_block_size is None:
-        group_size = -1
-    else:
-        group_size = layer.weight_block_size[1]
+    group_size = -1 if weight_block_size is None else weight_block_size[1]

     for name in ["w13", "w2"]:
         if name + "_weight_scale" in dir(layer):
@@ -236,7 +232,7 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module,

         # marlin kernel only support channel-wise and group-wise quantization
         # we need to convert the scales
-        if layer.weight_block_size is None:
+        if weight_block_size is None:
             if scales.nelement() == e:
                 # tensor-wise quantization -> channel-wise quantization
                 # (e, 1, 1) =>(repeat)=> (e, 1, size_n)
@@ -259,7 +255,7 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module,
             # =>(repeat)=> (e, size_k // block_size[1], size_n)
             if not size_k_first:
                 scales = scales.permute(0, 2, 1)
-            block_n = layer.weight_block_size[0]
+            block_n = weight_block_size[0]
             scales = scales.repeat_interleave(block_n, 2)
             # size_n may not divisible by block_size[0]
             scales = scales[..., :size_n].contiguous()
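
Two small patterns do the work in this file, and a standalone sketch of them is below. The getattr default protects against layers that never define weight_block_size (which, judging by the fix, is the case for FBGEMM FP8 layers), and repeat_interleave expands block-wise scales to the per-group layout the Marlin kernel expects. Shapes and values here are made up for illustration.

import torch

layer = torch.nn.Linear(8, 4)  # stand-in for a quantized linear layer

# Guarded lookup: no AttributeError when weight_block_size was never set.
weight_block_size = getattr(layer, "weight_block_size", None)
group_size = -1 if weight_block_size is None else weight_block_size[1]
print(group_size)  # -1, i.e. channel-wise quantization

# Block-wise case: expand (size_k // block_k, size_n // block_n) scales along
# the n dimension, mirroring scales.repeat_interleave(block_n, 1) above.
block_n, part_size_n = 2, 7
scales = torch.rand(3, 4)                      # hypothetical block scales
scales = scales.repeat_interleave(block_n, 1)  # -> shape (3, 8)
scales = scales[:, :part_size_n]               # trim: size_n may not divide evenly
print(scales.shape)                            # torch.Size([3, 7])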
