@@ -86,6 +86,7 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module,
 
     part_size_n = layer.output_size_per_partition
     part_size_k = layer.input_size_per_partition
+    weight_block_size = getattr(layer, "weight_block_size", None)
 
     if size_k_first:
         assert layer.weight.shape == (part_size_k, part_size_n)
@@ -119,14 +120,11 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module,
         scales = layer.weight_scale_inv.to(layer.orig_dtype)
         del layer.weight_scale_inv
 
-    if layer.weight_block_size is None:
-        group_size = -1
-    else:
-        group_size = layer.weight_block_size[1]
+    group_size = -1 if weight_block_size is None else weight_block_size[1]
 
     # marlin kernel only support channel-wise and group-wise quantization
     # we need to convert the scales
-    if layer.weight_block_size is None:
+    if weight_block_size is None:
         if scales.nelement() == 1:
             # tensor-wise quantization -> channel-wise quantization
             # (1, 1) =>(repeat)=> (1, size_n)
@@ -149,7 +147,7 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module,
         # =>(repeat)=> (size_k // block_size[1], size_n)
         if not size_k_first:
             scales = scales.T.contiguous()
-        block_n = layer.weight_block_size[0]
+        block_n = weight_block_size[0]
         scales = scales.repeat_interleave(block_n, 1)
         # size_n may not divisible by block_size[0]
         scales = scales[:, :part_size_n]
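
Not part of the diff: a minimal standalone sketch (class name and shapes are made up) of the two patterns the dense-layer change above relies on — `getattr(layer, "weight_block_size", None)` so that layers which never set the attribute fall back to the channel-wise path, and `repeat_interleave` expanding block-wise scales along the n dimension.

```python
import torch


class DummyLayer(torch.nn.Module):
    """Hypothetical layer that never defines weight_block_size."""


layer = DummyLayer()
# getattr with a default avoids AttributeError when the attribute is absent.
weight_block_size = getattr(layer, "weight_block_size", None)
group_size = -1 if weight_block_size is None else weight_block_size[1]
assert group_size == -1

# Block-wise -> group-wise scale conversion, mirroring the hunk above:
# each per-(block_k, block_n) scale is repeated block_n times along dim 1.
block_n, part_size_n = 128, 512
scales = torch.rand(4, part_size_n // block_n)   # (size_k // block_k, size_n // block_n)
scales = scales.repeat_interleave(block_n, 1)    # (size_k // block_k, size_n)
scales = scales[:, :part_size_n]                 # trim when size_n % block_n != 0
assert scales.shape == (4, part_size_n)
```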
@@ -173,6 +171,7 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module,
     e = layer.num_experts
     k = layer.hidden_size
     n = layer.intermediate_size_per_partition
+    weight_block_size = getattr(layer, "weight_block_size", None)
 
     # WORKSPACE
     device = layer.w13_weight.device
@@ -213,10 +212,7 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module,
 
     # WEIGHT SCALES
     # Permute scales
-    if layer.weight_block_size is None:
-        group_size = -1
-    else:
-        group_size = layer.weight_block_size[1]
+    group_size = -1 if weight_block_size is None else weight_block_size[1]
 
     for name in ["w13", "w2"]:
         if name + "_weight_scale" in dir(layer):
@@ -236,7 +232,7 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module,
 
         # marlin kernel only support channel-wise and group-wise quantization
        # we need to convert the scales
-        if layer.weight_block_size is None:
+        if weight_block_size is None:
             if scales.nelement() == e:
                 # tensor-wise quantization -> channel-wise quantization
                 # (e, 1, 1) =>(repeat)=> (e, 1, size_n)
@@ -259,7 +255,7 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module,
             # =>(repeat)=> (e, size_k // block_size[1], size_n)
             if not size_k_first:
                 scales = scales.permute(0, 2, 1)
-            block_n = layer.weight_block_size[0]
+            block_n = weight_block_size[0]
             scales = scales.repeat_interleave(block_n, 2)
             # size_n may not divisible by block_size[0]
             scales = scales[..., :size_n].contiguous()
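
Likewise not part of the diff: the MoE path applies the same block-wise → group-wise conversion per expert, expanding dim 2 instead of dim 1. A minimal sketch with made-up shapes:

```python
import torch

e, block_n, size_n = 8, 128, 384
scales = torch.rand(e, 3, size_n // block_n)    # (e, size_k // block_k, size_n // block_n)
scales = scales.repeat_interleave(block_n, 2)   # (e, size_k // block_k, size_n)
scales = scales[..., :size_n].contiguous()      # trim when size_n % block_n != 0
assert scales.shape == (e, 3, size_n)
```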