                                                UnquantizedLinearMethod)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
+from vllm.model_executor.layers.quantization.fp8 import cutlass_fp8_supported
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
     apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
@@ -72,6 +73,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
 
     def __init__(self, quant_config: FBGEMMFp8Config):
         self.quant_config = quant_config
+        self.cutlass_fp8_supported = cutlass_fp8_supported()
 
     def create_weights(
         self,
@@ -139,11 +141,12 @@ def apply(self,
                 size_k=layer.input_size_per_partition,
                 bias=bias)
 
-        return apply_fp8_linear(input=x,
-                                weight=layer.weight,
-                                weight_scale=layer.weight_scale,
-                                input_scale=None,
-                                input_scale_ub=layer.input_scale_ub,
-                                bias=bias,
-                                cutlass_fp8_supported=True,
-                                use_per_token_if_dynamic=True)
+        return apply_fp8_linear(
+            input=x,
+            weight=layer.weight,
+            weight_scale=layer.weight_scale,
+            input_scale=None,
+            input_scale_ub=layer.input_scale_ub,
+            bias=bias,
+            cutlass_fp8_supported=self.cutlass_fp8_supported,
+            use_per_token_if_dynamic=True)
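For context: the diff replaces the hard-coded cutlass_fp8_supported=True with a flag computed once in __init__ and reused on every apply() call. The real helper lives in vllm.model_executor.layers.quantization.fp8; the sketch below is only an illustration of what such a capability check could look like, and the compute-capability threshold is an assumption, not vLLM's exact logic.

import torch

def cutlass_fp8_supported_sketch() -> bool:
    # Hypothetical stand-in for cutlass_fp8_supported(): FP8 CUTLASS
    # kernels typically need a recent GPU (assumed here: compute
    # capability >= 8.9, i.e. Ada/Hopper-class hardware).
    if not torch.cuda.is_available():
        return False
    major, minor = torch.cuda.get_device_capability()
    return (major, minor) >= (8, 9)

Caching the result on the layer method object avoids re-querying the device on every forward pass and lets apply_fp8_linear fall back to a non-CUTLASS path on GPUs where the kernel is unavailable, instead of unconditionally assuming support.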