from typing import Callable, List, Optional

import torch
+ from torch.nn import Parameter

from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
    CompressedTensorsScheme)
+ from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
+     QuantizationStrategy)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-     apply_fp8_linear, create_per_tensor_scale_param, cutlass_fp8_supported,
+     apply_fp8_linear, create_per_channel_scale_param,
+     create_per_tensor_scale_param, cutlass_fp8_supported,
    requantize_with_max_scale)
from vllm.model_executor.utils import set_weight_attrs


class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):

-     def __init__(self, input_dynamic: bool):
-         self.input_dynamic = input_dynamic
+     def __init__(self, strategy: str, is_static_input_scheme: bool):
+         self.strategy = strategy
+         self.is_static_input_scheme = is_static_input_scheme
        self.cutlass_fp8_supported = cutlass_fp8_supported()

-     # W8A8-Fp8 kernels support only per-tensor and per-channel cases.
-     # So if we have a fused module (QKV, MLP) with per tensor scales (thus N
-     # scales being passed to the kernel), we requantize with a single scale.
+         # On Lovelace, fail for now if channelwise.
+         # TODO: (@tms) fallback
+         if (not self.cutlass_fp8_supported
+                 and self.strategy == QuantizationStrategy.CHANNEL):
+             raise ValueError(
+                 "Channelwise fp8 quantization requires vLLM's custom "
+                 "cutlass kernels, which are not supported on your device. "
+                 "Consider quantizing with per tensor scales or upgrading "
+                 "to Hopper.")
+
    def process_weights_after_loading(self, layer) -> None:
-         # Dequant -> Quant with max scale.
-         max_w_scale, weight = requantize_with_max_scale(
-             weight=layer.weight,
-             weight_scale=layer.weight_scale,
-             logical_widths=layer.logical_widths,
-         )
-
-         # Update layer with new values.
-         layer.weight = torch.nn.Parameter(weight.t(), requires_grad=False)
-         layer.weight_scale = torch.nn.Parameter(max_w_scale,
-                                                 requires_grad=False)
-         if self.input_dynamic:
-             layer.input_scale = None
+         # If per tensor, when we have a fused module (e.g. QKV) with per
+         # tensor scales (thus N scales being passed to the kernel),
+         # requantize so we can always run per tensor
+         if self.strategy == QuantizationStrategy.TENSOR:
+             max_w_scale, weight = requantize_with_max_scale(
+                 weight=layer.weight,
+                 weight_scale=layer.weight_scale,
+                 logical_widths=layer.logical_widths,
+             )
+
+             layer.weight = Parameter(weight.t(), requires_grad=False)
+             layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
+
+         # If channelwise, scales are already lined up, so just transpose.
+         elif self.strategy == QuantizationStrategy.CHANNEL:
+             assert self.cutlass_fp8_supported
+             weight = layer.weight
+             layer.weight = Parameter(weight.t(), requires_grad=False)
+
+         else:
+             raise ValueError(f"Unknown quantization strategy {self.strategy}")
+
+         # INPUT SCALE
+         if self.is_static_input_scheme:
+             layer.input_scale = Parameter(layer.input_scale.max(),
+                                           requires_grad=False)
        else:
-             layer.input_scale = torch.nn.Parameter(layer.input_scale.max(),
-                                                    requires_grad=False)
+             layer.input_scale = None

    def create_weights(self, layer: torch.nn.Module,
                       output_partition_sizes: List[int],
                       input_size_per_partition: int,
                       params_dtype: torch.dtype, weight_loader: Callable,
                       **kwargs):
-
-         del params_dtype
-
        output_size_per_partition = sum(output_partition_sizes)
        layer.logical_widths = output_partition_sizes

@@ -63,12 +84,17 @@ def create_weights(self, layer: torch.nn.Module,
        })

        # WEIGHT SCALE
-         weight_scale = create_per_tensor_scale_param(
-             output_partition_sizes, weight_loader=weight_loader)
+         if self.strategy == QuantizationStrategy.CHANNEL:
+             weight_scale = create_per_channel_scale_param(
+                 output_partition_sizes, weight_loader=weight_loader)
+         else:
+             assert self.strategy == QuantizationStrategy.TENSOR
+             weight_scale = create_per_tensor_scale_param(
+                 output_partition_sizes, weight_loader=weight_loader)
        layer.register_parameter("weight_scale", weight_scale)

        # INPUT SCALE
-         if not self.input_dynamic:
+         if self.is_static_input_scheme:
            input_scale = create_per_tensor_scale_param(
                output_partition_sizes, weight_loader=weight_loader)
            layer.register_parameter("input_scale", input_scale)
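
For context (not part of the commit), a minimal sketch of how the reworked constructor might be exercised, assuming QuantizationStrategy is the str-backed enum imported above; the compressed-tensors config plumbing that normally selects and instantiates the scheme is omitted here:

    # Hypothetical example; in vLLM the scheme is normally created from the
    # parsed quantization config rather than instantiated by hand.
    scheme = CompressedTensorsW8A8Fp8(
        strategy=QuantizationStrategy.CHANNEL,  # or QuantizationStrategy.TENSOR
        is_static_input_scheme=False)           # dynamic activation scales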