Commit dbe5588

[ Misc ] non-uniform quantization via compressed-tensors for Llama (#6515)
1 parent d4201e0 commit dbe5588

11 files changed: +300 -90 lines changed
.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
+# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-nonuniform-test -b auto -l 1000 -f 5 -t 1
+model_name: "nm-testing/Meta-Llama-3-8B-Instruct-nonuniform-test"
+tasks:
+- name: "gsm8k"
+  metrics:
+  - name: "exact_match,strict-match"
+    value: 0.758
+  - name: "exact_match,flexible-extract"
+    value: 0.759
+limit: 1000
+num_fewshot: 5
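
For context, a minimal sketch (not part of this commit; the function name and RTOL value are assumptions) of how such a config is typically consumed: the harness runs the listed tasks with the given limit and num_fewshot, then compares each measured metric against the recorded ground-truth value within a relative tolerance.

# Hypothetical checker for the eval config above (names and RTOL assumed).
import numpy
import yaml

RTOL = 0.05  # assumed relative tolerance


def check_results(config_path: str, measured: dict) -> None:
    """Compare measured lm-eval metrics against the ground-truth YAML."""
    with open(config_path) as f:
        expected = yaml.safe_load(f)
    for task in expected["tasks"]:
        for metric in task["metrics"]:
            got = measured[task["name"]][metric["name"]]
            assert numpy.isclose(metric["value"], got, rtol=RTOL), (
                f"{task['name']}/{metric['name']}: "
                f"expected {metric['value']}, got {got}")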

.buildkite/lm-eval-harness/configs/models-small.txt

Lines changed: 1 addition & 0 deletions
@@ -2,4 +2,5 @@ Meta-Llama-3-8B-Instruct.yaml
 Meta-Llama-3-8B-Instruct-FP8.yaml
 Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml
 Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml
+Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml
 Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 1 addition & 0 deletions
@@ -158,6 +158,7 @@ def __init__(
         topk_group: Optional[int] = None,
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
+        prefix: str = "",
     ):
         super().__init__()

vllm/model_executor/layers/linear.py

Lines changed: 32 additions & 12 deletions
@@ -171,6 +171,8 @@ class ReplicatedLinear(LinearBase):
         skip_bias_add: If true, skip adding bias but instead return it.
         params_dtype: Data type for the parameters.
         quant_config: Quantization configure.
+        prefix: The name of the layer in the state dict, including all parents
+            (e.g. model.layers.0.qkv_proj)
     """

     def __init__(self,
@@ -179,15 +181,19 @@ def __init__(self,
                  bias: bool = True,
                  skip_bias_add: bool = False,
                  params_dtype: Optional[torch.dtype] = None,
-                 quant_config: Optional[QuantizationConfig] = None):
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: Optional[str] = None):
         super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                          quant_config)

         # All the linear layer supports quant method.
         assert self.quant_method is not None
-        self.quant_method.create_weights(self, self.input_size,
-                                         [self.output_size], self.input_size,
-                                         self.output_size, self.params_dtype)
+        self.quant_method.create_weights(self,
+                                         self.input_size, [self.output_size],
+                                         self.input_size,
+                                         self.output_size,
+                                         self.params_dtype,
+                                         prefix=prefix)

         if bias:
             self.bias = Parameter(
@@ -239,6 +245,8 @@ class ColumnParallelLinear(LinearBase):
         quant_config: Quantization configure.
         output_sizes: list of output sizes packed into one output, like for QKV
                        the list would be size 3.
+        prefix: The name of the layer in the state dict, including all parents
+            (e.g. model.layers.0.qkv_proj)
     """

     def __init__(self,
@@ -249,7 +257,8 @@ def __init__(self,
                  skip_bias_add: bool = False,
                  params_dtype: Optional[torch.dtype] = None,
                  quant_config: Optional[QuantizationConfig] = None,
-                 output_sizes: Optional[List[int]] = None):
+                 output_sizes: Optional[List[int]] = None,
+                 prefix: Optional[str] = None):
         super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                          quant_config)

@@ -276,7 +285,8 @@ def __init__(self,
             input_size=self.input_size,
             output_size=self.output_size,
             params_dtype=self.params_dtype,
-            weight_loader=self.weight_loader)
+            weight_loader=self.weight_loader,
+            prefix=prefix)
         if bias:
             self.bias = Parameter(
                 torch.empty(self.output_size_per_partition,
@@ -348,6 +358,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             skip adding bias but instead return it.
         params_dtype: Data type for the parameters.
         quant_config: Quantization configure.
+        prefix: The name of the layer in the state dict, including all parents
+            (e.g. model.layers.0.qkv_proj)
     """

     def __init__(self,
@@ -357,7 +369,8 @@ def __init__(self,
                  gather_output: bool = False,
                  skip_bias_add: bool = False,
                  params_dtype: Optional[torch.dtype] = None,
-                 quant_config: Optional[QuantizationConfig] = None):
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: Optional[str] = None):
         self.output_sizes = output_sizes
         tp_size = get_tensor_model_parallel_world_size()
         assert all(output_size % tp_size == 0 for output_size in output_sizes)
@@ -367,7 +380,8 @@ def __init__(self,
                          gather_output=gather_output,
                          skip_bias_add=skip_bias_add,
                          params_dtype=params_dtype,
-                         quant_config=quant_config)
+                         quant_config=quant_config,
+                         prefix=prefix)

     def weight_loader(self,
                       param: Parameter,
@@ -487,6 +501,8 @@ class QKVParallelLinear(ColumnParallelLinear):
             skip adding bias but instead return it.
         params_dtype: Data type for the parameters.
         quant_config: Quantization configure.
+        prefix: The name of the layer in the state dict, including all parents
+            (e.g. model.layers.0.qkv_proj)
     """

     def __init__(self,
@@ -497,7 +513,8 @@ def __init__(self,
                  bias: bool = True,
                  skip_bias_add: bool = False,
                  params_dtype: Optional[torch.dtype] = None,
-                 quant_config: Optional[QuantizationConfig] = None):
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: Optional[str] = None):
         self.hidden_size = hidden_size
         self.head_size = head_size
         self.total_num_heads = total_num_heads
@@ -529,7 +546,8 @@ def __init__(self,
                          gather_output=False,
                          skip_bias_add=skip_bias_add,
                          params_dtype=params_dtype,
-                         quant_config=quant_config)
+                         quant_config=quant_config,
+                         prefix=prefix)

     def weight_loader(self,
                       param: Parameter,
@@ -688,7 +706,8 @@ def __init__(self,
                  skip_bias_add: bool = False,
                  params_dtype: Optional[torch.dtype] = None,
                  reduce_results: bool = True,
-                 quant_config: Optional[QuantizationConfig] = None):
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: Optional[str] = None):
         super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                          quant_config)

@@ -706,7 +725,8 @@ def __init__(self,
             input_size=self.input_size,
             output_size=self.output_size,
             params_dtype=self.params_dtype,
-            weight_loader=self.weight_loader)
+            weight_loader=self.weight_loader,
+            prefix=prefix)
         if not reduce_results and (bias and not skip_bias_add):
             raise ValueError("When not reduce the results, adding bias to the "
                              "results can lead to incorrect results")

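To see where the new prefix argument comes from, here is a minimal sketch (class and layer names are assumptions that mirror how a Llama-style attention block in vLLM wires its projections) of a model passing the state-dict prefix down to its parallel linear layers so that quantization methods can look up per-layer schemes by full name:

# Hypothetical sketch: threading the state-dict prefix into linear layers.
import torch
from vllm.model_executor.layers.linear import (QKVParallelLinear,
                                               RowParallelLinear)


class ExampleAttention(torch.nn.Module):
    """Assumed example block; not part of this diff."""

    def __init__(self, hidden_size: int, head_size: int, total_num_heads: int,
                 quant_config=None, prefix: str = ""):
        super().__init__()
        # e.g. prefix == "model.layers.0.self_attn"
        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            head_size,
            total_num_heads,
            bias=False,
            quant_config=quant_config,
            # -> "model.layers.0.self_attn.qkv_proj"
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            total_num_heads * head_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

Per the diff above, the linear layers forward this prefix to quant_method.create_weights, where it arrives as the "prefix" extra weight attribute that the compressed-tensors method reads back as layer_name below.
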
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py

Lines changed: 58 additions & 34 deletions
@@ -8,23 +8,25 @@
     QuantizationConfig)
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS,
-    CompressedTensorsScheme, CompressedTensorsW4A16Sparse24,
-    CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
-    CompressedTensorsWNA16)
+    CompressedTensorsScheme, CompressedTensorsUnquantized,
+    CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8,
+    CompressedTensorsW8A8Int8, CompressedTensorsWNA16)
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
     CompressionFormat, QuantizationArgs, QuantizationStrategy,
-    QuantizationType, find_first_name_or_class_match,
-    is_activation_quantization_format)
+    QuantizationType, find_matched_target, is_activation_quantization_format,
+    should_ignore_layer)
 from vllm.platforms import current_platform


 class CompressedTensorsConfig(QuantizationConfig):

-    def __init__(self, layer_quant_details: Dict[str, Any], ignore: List[str],
+    def __init__(self, target_scheme_map: Dict[str, Any], ignore: List[str],
                  quant_format: str):
+
         self.ignore = ignore
-        self.layer_quant_details = layer_quant_details
         self.quant_format = quant_format
+        # Map from [target -> scheme]
+        self.target_scheme_map = target_scheme_map

     def get_linear_method(self) -> "CompressedTensorsLinearMethod":
         return CompressedTensorsLinearMethod(self)
@@ -51,7 +53,7 @@ def get_quant_method(

     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
-        layer_quant_details: Dict[str, Any] = dict()
+        target_scheme_map: Dict[str, Any] = dict()
         ignore: List[str] = config.get("ignore", None)
         quant_format: str = config.get("format", None)

@@ -63,21 +65,21 @@ def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
         # details follow the structure defined by the QuantizationArgs
         # pydantic model, which is used to verify the structure of the
         # quant_config and also store the details for later use.
-        for key, quant_config in config["config_groups"].items():
+        for _, quant_config in config["config_groups"].items():
             targets = quant_config.get("targets")
             for target in targets:
-                layer_quant_details[target] = {}
-                layer_quant_details[target][
+                target_scheme_map[target] = {}
+                target_scheme_map[target][
                     "weights"] = QuantizationArgs.parse_obj(
                         quant_config.get("weights"))
                 try:
-                    layer_quant_details[target][
+                    target_scheme_map[target][
                         "input_activations"] = QuantizationArgs.parse_obj(
                             quant_config.get("input_activations"))
                 except Exception:
-                    layer_quant_details[target]["input_activations"] = None
+                    target_scheme_map[target]["input_activations"] = None

-        return cls(layer_quant_details=layer_quant_details,
+        return cls(target_scheme_map=target_scheme_map,
                    ignore=ignore,
                    quant_format=quant_format)

@@ -167,8 +169,9 @@ def _is_wNa16_group_channel(self, weight_quant: BaseModel,
         return (is_channel_group and input_quant_none and is_symmetric
                 and is_static)

-    def _get_schema(self, weight_quant: BaseModel,
-                    input_quant: BaseModel) -> "CompressedTensorsScheme":
+    def _get_scheme_from_parts(
+            self, weight_quant: BaseModel,
+            input_quant: BaseModel) -> "CompressedTensorsScheme":

         # Detect If Mixed Precision
         if self._is_wNa16_group_channel(weight_quant, input_quant):
@@ -205,26 +208,47 @@ def _get_schema(self, weight_quant: BaseModel,
         raise NotImplementedError(
             "No compressed-tensors compatible scheme was found.")

-    def get_scheme(self, layer: torch.nn.Module) -> "CompressedTensorsScheme":
+    def get_scheme(
+            self,
+            layer: torch.nn.Module,
+            layer_name: Optional[str] = None) -> "CompressedTensorsScheme":
+        """
+        compressed-tensors supports non uniform in the following way:
+
+        ignore: List of layer_names or nn.Module names to be ignored.
+        targets of config_groups: There can be N config_groups which each
+            have a quantization scheme. Each config_group has a list of targets
+            which can be a full layer_name, a regex for a layer_name, or
+            an nn.Module name.

-        layer_type_name = find_first_name_or_class_match(
-            name="",
-            module=layer,
-            targets=self.layer_quant_details.keys(),
-            check_contains=True)
+        We first check whether a layer is in the ignore group and use
+        CompressedTensorsUnquantized (i.e. fp16/bf16) scheme for the layer

-        if layer_type_name is None:
-            raise ValueError(f"Could not matching target for layer {layer}")
+        We then detect whether a layer_name is found in any target and
+        use the quantization scheme corresponding to the matched target
+        to select the CompressedTensorsScheme used for inference.
+        """
+
+        # Check if the layer is skipped for quantization.
+        # TODO (@robertgshaw2): support module names
+        if should_ignore_layer(layer_name, ignore=self.ignore):
+            return CompressedTensorsUnquantized()
+
+        # Find the "target" in the compressed-tensors config
+        # that our layer conforms to.
+        # TODO (@robertgshaw): add compressed-tensors as dep
+        # so we do not have to re-write these functions
+        matched_target = find_matched_target(
+            layer_name=layer_name,
+            module=layer,
+            targets=self.target_scheme_map.keys())

-        layer_quant_details: Dict[str, Any] = self.layer_quant_details.get(
-            layer_type_name, None)
-        if layer_quant_details is None:
-            raise ValueError(
-                f"Could not find quantization details for {layer}.")
+        # Find the quant_scheme
+        scheme = self.target_scheme_map[matched_target]

-        scheme = self._get_schema(
-            weight_quant=layer_quant_details["weights"],
-            input_quant=layer_quant_details["input_activations"])
+        return self._get_scheme_from_parts(
+            weight_quant=scheme["weights"],
+            input_quant=scheme["input_activations"])

         # Raise error if device does not support the scheme
         # (e.g. fp8 needs ada lovelace)
@@ -250,11 +274,11 @@ def create_weights(self, layer: torch.nn.Module,
         Use the CompressedTensorsScheme associated with each layer to create
         the necessary parameters for the layer. See LinearMethodBase for param
         details
-
         """
         weight_loader = extra_weight_attrs.get("weight_loader")
+        layer_name = extra_weight_attrs.get("prefix")

-        scheme = self.quantization_config.get_scheme(layer=layer)
+        scheme = self.quantization_config.get_scheme(layer, layer_name)
         scheme.create_weights(
             layer=layer,
             input_size=input_size,
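
To make the non-uniform path concrete, here is a hedged sketch of the kind of compressed-tensors quantization_config the new code resolves; the group names, targets, and numeric values are invented for illustration, but the shape follows the config_groups/ignore structure parsed in from_config above.

# Hypothetical quantization_config (values invented) illustrating
# non-uniform quantization with compressed-tensors.
example_quantization_config = {
    "format": "naive-quantized",
    # Layers listed here fall back to CompressedTensorsUnquantized (fp16/bf16).
    "ignore": ["lm_head"],
    "config_groups": {
        # 8-bit weights and activations for the attention projections.
        "group_0": {
            "targets": ["re:.*qkv_proj", "re:.*o_proj"],
            "weights": {"num_bits": 8, "type": "int",
                        "symmetric": True, "strategy": "tensor"},
            "input_activations": {"num_bits": 8, "type": "int",
                                  "symmetric": True, "strategy": "tensor"},
        },
        # Weight-only 8-bit (per-channel) for the MLP projections.
        "group_1": {
            "targets": ["re:.*gate_up_proj", "re:.*down_proj"],
            "weights": {"num_bits": 8, "type": "int",
                        "symmetric": True, "strategy": "channel"},
            "input_activations": None,
        },
    },
}

# from_config() flattens the groups into target_scheme_map, e.g.
# {"re:.*qkv_proj": {...}, "re:.*o_proj": {...}, "re:.*gate_up_proj": {...}},
# and get_scheme(layer, layer_name="model.layers.0.self_attn.qkv_proj")
# first consults `ignore`, then find_matched_target() to pick the scheme.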

vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py

Lines changed: 0 additions & 1 deletion
@@ -33,7 +33,6 @@ def create_weights(self, layer: torch.nn.Module,

         weight = Parameter(torch.empty(sum(output_partition_sizes),
                                        input_size_per_partition,
-                                       device="cuda",
                                        dtype=params_dtype),
                            requires_grad=False)
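
Dropping the hard-coded device="cuda" presumably lets the unquantized fallback weight be allocated on whatever device context the model loader is using rather than always on CUDA; a minimal sketch of that behavior under PyTorch's default-device mechanism (not taken from the commit):

# Sketch: with no explicit device, the Parameter follows the ambient default
# device (e.g. a meta-device context used for deferred initialization).
import torch
from torch.nn import Parameter

with torch.device("meta"):
    weight = Parameter(torch.empty(1024, 4096, dtype=torch.bfloat16),
                       requires_grad=False)
print(weight.device)  # meta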
