    QuantizationConfig)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
    W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS,
-    CompressedTensorsScheme, CompressedTensorsW4A16Sparse24,
-    CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
-    CompressedTensorsWNA16)
+    CompressedTensorsScheme, CompressedTensorsUnquantized,
+    CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8,
+    CompressedTensorsW8A8Int8, CompressedTensorsWNA16)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
    CompressionFormat, QuantizationArgs, QuantizationStrategy,
-    QuantizationType, find_first_name_or_class_match,
-    is_activation_quantization_format)
+    QuantizationType, find_matched_target, is_activation_quantization_format,
+    should_ignore_layer)
from vllm.platforms import current_platform


class CompressedTensorsConfig(QuantizationConfig):

-    def __init__(self, layer_quant_details: Dict[str, Any], ignore: List[str],
+    def __init__(self, target_scheme_map: Dict[str, Any], ignore: List[str],
                 quant_format: str):
+
        self.ignore = ignore
-        self.layer_quant_details = layer_quant_details
        self.quant_format = quant_format
+        # Map from [target -> scheme]
+        self.target_scheme_map = target_scheme_map

    def get_linear_method(self) -> "CompressedTensorsLinearMethod":
        return CompressedTensorsLinearMethod(self)
@@ -51,7 +53,7 @@ def get_quant_method(

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
-        layer_quant_details: Dict[str, Any] = dict()
+        target_scheme_map: Dict[str, Any] = dict()
        ignore: List[str] = config.get("ignore", None)
        quant_format: str = config.get("format", None)

@@ -63,21 +65,21 @@ def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
        # details follow the structure defined by the QuantizationArgs
        # pydantic model, which is used to verify the structure of the
        # quant_config and also store the details for later use.
-        for key, quant_config in config["config_groups"].items():
+        for _, quant_config in config["config_groups"].items():
            targets = quant_config.get("targets")
            for target in targets:
-                layer_quant_details[target] = {}
-                layer_quant_details[target][
+                target_scheme_map[target] = {}
+                target_scheme_map[target][
                    "weights"] = QuantizationArgs.parse_obj(
                        quant_config.get("weights"))
                try:
-                    layer_quant_details[target][
+                    target_scheme_map[target][
                        "input_activations"] = QuantizationArgs.parse_obj(
                            quant_config.get("input_activations"))
                except Exception:
-                    layer_quant_details[target]["input_activations"] = None
+                    target_scheme_map[target]["input_activations"] = None

-        return cls(layer_quant_details=layer_quant_details,
+        return cls(target_scheme_map=target_scheme_map,
                   ignore=ignore,
                   quant_format=quant_format)

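For context, a minimal sketch of the kind of `quantization_config` dict that `from_config` consumes and how it becomes `target_scheme_map`. The concrete field values below are illustrative assumptions (not taken from this PR); only the overall shape (`format`, `ignore`, `config_groups` with `targets`/`weights`/`input_activations`) follows the parsing code above, and the import path is assumed from the file under review.

```python
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
    CompressedTensorsConfig)

# Hypothetical config: one config_group quantizing every Linear layer to
# static, symmetric 8-bit weights and activations, with lm_head left alone.
config = {
    "format": "int-quantized",
    "ignore": ["lm_head"],
    "config_groups": {
        "group_0": {
            # targets may be nn.Module names, full layer names, or regexes
            "targets": ["Linear"],
            "weights": {"num_bits": 8, "type": "int",
                        "symmetric": True, "strategy": "tensor"},
            "input_activations": {"num_bits": 8, "type": "int",
                                  "symmetric": True, "strategy": "tensor"},
        },
    },
}

quant_config = CompressedTensorsConfig.from_config(config)
# Each target listed in a config_group becomes a key of target_scheme_map:
# {"Linear": {"weights": QuantizationArgs(...), "input_activations": QuantizationArgs(...)}}
print(quant_config.target_scheme_map)
```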
@@ -167,8 +169,9 @@ def _is_wNa16_group_channel(self, weight_quant: BaseModel,
        return (is_channel_group and input_quant_none and is_symmetric
                and is_static)

-    def _get_schema(self, weight_quant: BaseModel,
-                    input_quant: BaseModel) -> "CompressedTensorsScheme":
+    def _get_scheme_from_parts(
+            self, weight_quant: BaseModel,
+            input_quant: BaseModel) -> "CompressedTensorsScheme":

        # Detect If Mixed Precision
        if self._is_wNa16_group_channel(weight_quant, input_quant):
@@ -205,26 +208,47 @@ def _get_schema(self, weight_quant: BaseModel,
        raise NotImplementedError(
            "No compressed-tensors compatible scheme was found.")

-    def get_scheme(self, layer: torch.nn.Module) -> "CompressedTensorsScheme":
+    def get_scheme(
+            self,
+            layer: torch.nn.Module,
+            layer_name: Optional[str] = None) -> "CompressedTensorsScheme":
+        """
+        compressed-tensors supports non-uniform quantization in the following way:
+
+        ignore: List of layer_names or nn.Module names to be ignored.
+        targets of config_groups: There can be N config_groups, each with
+            its own quantization scheme. Each config_group has a list of
+            targets, which can be a full layer_name, a regex for a
+            layer_name, or an nn.Module name.

-        layer_type_name = find_first_name_or_class_match(
-            name="",
-            module=layer,
-            targets=self.layer_quant_details.keys(),
-            check_contains=True)
+        We first check whether a layer is in the ignore list and, if so, use
+        the CompressedTensorsUnquantized (i.e. fp16/bf16) scheme for that layer.

-        if layer_type_name is None:
-            raise ValueError(f"Could not matching target for layer {layer}")
+        We then detect whether the layer_name is found in any target and use
+        the quantization scheme corresponding to the matched target to select
+        the CompressedTensorsScheme used for inference.
+        """
+
+        # Check if the layer is skipped for quantization.
+        # TODO (@robertgshaw2): support module names
+        if should_ignore_layer(layer_name, ignore=self.ignore):
+            return CompressedTensorsUnquantized()
+
+        # Find the "target" in the compressed-tensors config
+        # that our layer conforms to.
+        # TODO (@robertgshaw): add compressed-tensors as dep
+        # so we do not have to re-write these functions
+        matched_target = find_matched_target(
+            layer_name=layer_name,
+            module=layer,
+            targets=self.target_scheme_map.keys())

-        layer_quant_details: Dict[str, Any] = self.layer_quant_details.get(
-            layer_type_name, None)
-        if layer_quant_details is None:
-            raise ValueError(
-                f"Could not find quantization details for {layer}.")
+        # Find the quant_scheme
+        scheme = self.target_scheme_map[matched_target]

-        scheme = self._get_schema(
-            weight_quant=layer_quant_details["weights"],
-            input_quant=layer_quant_details["input_activations"])
+        return self._get_scheme_from_parts(
+            weight_quant=scheme["weights"],
+            input_quant=scheme["input_activations"])

        # Raise error if device does not support the scheme
        # (e.g. fp8 needs ada lovelace)
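To make the resolution order described in the docstring concrete, here is a hedged usage sketch that reuses the illustrative `quant_config` from the earlier snippet. The layer names are made up; the point is that the ignore list is consulted first, and only then is the layer matched against the config_group targets.

```python
import torch

linear = torch.nn.Linear(4096, 4096)

# "lm_head" is in the illustrative ignore list, so we expect the unquantized
# (fp16/bf16) CompressedTensorsUnquantized scheme to be returned.
print(type(quant_config.get_scheme(linear, layer_name="lm_head")).__name__)

# Any other Linear layer matches the "Linear" target, so the scheme is built
# from group_0's weights/input_activations via _get_scheme_from_parts.
print(type(quant_config.get_scheme(
    linear, layer_name="model.layers.0.self_attn.q_proj")).__name__)
```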
@@ -250,11 +274,11 @@ def create_weights(self, layer: torch.nn.Module,
        Use the CompressedTensorsScheme associated with each layer to create
        the necessary parameters for the layer. See LinearMethodBase for param
        details
-
        """
        weight_loader = extra_weight_attrs.get("weight_loader")
+        layer_name = extra_weight_attrs.get("prefix")

-        scheme = self.quantization_config.get_scheme(layer=layer)
+        scheme = self.quantization_config.get_scheme(layer, layer_name)
        scheme.create_weights(
            layer=layer,
            input_size=input_size,
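The `layer_name` used for scheme lookup arrives through `extra_weight_attrs`: whatever dotted module name the caller passes as `prefix` when creating the layer's weights is what `get_scheme` sees. A rough sketch of that plumbing with a stand-in builder; the helper name and argument values are assumptions, and only the fact that `weight_loader` and `prefix` are read back out of `extra_weight_attrs` comes from the code above.

```python
from typing import Callable, List

import torch


def build_quantized_linear(linear_method, layer: torch.nn.Module, prefix: str,
                           weight_loader: Callable, input_size: int,
                           output_partition_sizes: List[int], output_size: int):
    # Hypothetical stand-in for a linear layer constructing its weights:
    # the dotted module name is forwarded as "prefix" and surfaces inside
    # CompressedTensorsLinearMethod.create_weights as layer_name.
    linear_method.create_weights(
        layer=layer,
        input_size_per_partition=input_size,
        output_partition_sizes=output_partition_sizes,
        input_size=input_size,
        output_size=output_size,
        params_dtype=torch.float16,
        weight_loader=weight_loader,  # read back via extra_weight_attrs["weight_loader"]
        prefix=prefix,                # read back via extra_weight_attrs["prefix"]
    )
```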