Commit 172d1cd

[Kernel] AQ AZP 4/4: Integrate asymmetric quantization to linear method (#7271)

1 parent a9b15c6 commit 172d1cd

7 files changed: +124 -21 lines changed
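In outline, the scheme this change wires into the linear method: weights stay symmetric int8 (W ~ s_w * W_q) while activations may be asymmetric int8 (x ~ s_x * (x_q - zp)), so the matmul factors as

    x @ W ~ s_x * s_w * (x_q @ W_q - zp * azp_adj),   with azp_adj[j] = sum_k W_q[k, j]

i.e. the plain int8 GEMM plus a correction built from the activation zero point and the per-output-column sums of the quantized weights. The column sums are precomputed as azp_adj in compressed_tensors_w8a8_int8.py below; the zero point is a per-tensor constant for static quantization, or computed per token by ops.scaled_int8_quant for dynamic quantization.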
.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-INT8-compressed-tensors-asym.yaml

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
+# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Asym-Per-Token-Test -b "auto" -l 250 -f 5 -t 1
+model_name: "nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Asym-Per-Token-Test"
+tasks:
+- name: "gsm8k"
+  metrics:
+  - name: "exact_match,strict-match"
+    value: 0.764
+  - name: "exact_match,flexible-extract"
+    value: 0.764
+limit: 250
+num_fewshot: 5

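For orientation, a minimal sketch of how a config in this format is consumed by test_lm_eval_correctness.py (shown further down); the config path is an assumption based on the entry added to models-small.txt below, and launching the actual evaluation is elided.

```python
# Minimal sketch: load one lm-eval-harness config yaml and walk its
# tasks/metrics the same way test_lm_eval_correctness.py does.
import yaml

config_path = (".buildkite/lm-eval-harness/configs/"
               "Meta-Llama-3-8B-Instruct-INT8-compressed-tensors-asym.yaml")
with open(config_path) as f:
    eval_config = yaml.safe_load(f)

print(eval_config["model_name"], eval_config["limit"], eval_config["num_fewshot"])
for task in eval_config["tasks"]:
    for metric in task["metrics"]:
        print(f'{task["name"]} | {metric["name"]}: expected={metric["value"]}')
```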
.buildkite/lm-eval-harness/configs/models-small.txt

Lines changed: 1 addition & 0 deletions
@@ -1,6 +1,7 @@
 Meta-Llama-3-8B-Instruct.yaml
 Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml
 Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml
+Meta-Llama-3-8B-Instruct-INT8-compressed-tensors-asym.yaml
 Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml
 Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml
 Minitron-4B-Base-FP8.yaml

.buildkite/lm-eval-harness/test_lm_eval_correctness.py

Lines changed: 6 additions & 1 deletion
@@ -49,10 +49,15 @@ def test_lm_eval_correctness():
     results = launch_lm_eval(eval_config)
 
     # Confirm scores match ground truth.
+    success = True
     for task in eval_config["tasks"]:
         for metric in task["metrics"]:
             ground_truth = metric["value"]
             measured_value = results["results"][task["name"]][metric["name"]]
             print(f'{task["name"]} | {metric["name"]}: '
                   f'ground_truth={ground_truth} | measured={measured_value}')
-            assert numpy.isclose(ground_truth, measured_value, rtol=RTOL)
+            success = success and numpy.isclose(
+                ground_truth, measured_value, rtol=RTOL)
+
+    # Assert at the end, print all scores even on failure for debugging.
+    assert success

tests/quantization/test_compressed_tensors.py

Lines changed: 27 additions & 9 deletions
@@ -2,6 +2,7 @@
 
 Run `pytest tests/quantization/test_compressed_tensors.py`.
 """
+from typing import Optional
 
 import pytest
 import torch
@@ -14,14 +15,16 @@
     QuantizationType)
 
 
-@pytest.mark.parametrize("model_args", [
-    ("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", "tensor",
-     QuantizationType.INT, 2560),
-    ("nm-testing/tinyllama-oneshot-w8-channel-a8-tensor", "channel",
-     QuantizationType.INT, 2560),
-])
+@pytest.mark.parametrize(
+    "model_args",
+    [("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", "tensor",
+      QuantizationType.INT, 2560, True),
+     ("nm-testing/tinyllama-oneshot-w8-channel-a8-tensor", "channel",
+      QuantizationType.INT, 2560, True),
+     ("nm-testing/asym-w8w8-int8-static-per-tensor-tiny-llama", "tensor",
+      QuantizationType.INT, 2560, False)])
 def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
-    model_path, strategy, quant_type, shape_0 = model_args
+    model_path, strategy, quant_type, shape_0, is_symmetric = model_args
     with vllm_runner(model_path, enforce_eager=True) as llm:
         model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
         layer = model.model.layers[0]
@@ -31,6 +34,18 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
         gate_up_proj = layer.mlp.gate_up_proj
         down_proj = layer.mlp.down_proj
 
+        # assert zp for symmetric and asymmetric cases
+        def zp_valid(zp: Optional[torch.Tensor]):
+            if is_symmetric:
+                return zp is None
+
+            return zp is not None and zp.dtype is torch.int32
+
+        assert zp_valid(qkv_proj.input_zero_point)
+        assert zp_valid(o_proj.input_zero_point)
+        assert zp_valid(gate_up_proj.input_zero_point)
+        assert zp_valid(down_proj.input_zero_point)
+
         assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
         assert isinstance(o_proj.quant_method, CompressedTensorsLinearMethod)
         assert isinstance(gate_up_proj.quant_method,
@@ -69,9 +84,12 @@ def test_compressed_tensors_no_enforce_eager(vllm_runner):
 
 @pytest.mark.parametrize("model_args", [
     ("nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2", "tensor"),
+    ("nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2-asym", "tensor"),
     ("nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2", "channel"),
+    ("nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2-asym",
+     "channel"),
 ])
-def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner, model_args):
+def test_compressed_tensors_w8a8_dynamic_per_token(vllm_runner, model_args):
     model_path, strategy = model_args
     with vllm_runner(model_path, dtype=torch.float16) as llm:
         model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
@@ -160,4 +178,4 @@ def test_compressed_tensors_kv_cache(vllm_runner):
     model_path = "nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme"
     with vllm_runner(model_path, kv_cache_dtype="fp8") as llm:
         output = llm.generate_greedy("Hello world!", max_tokens=20)
-        assert output
+        assert output

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py

Lines changed: 10 additions & 6 deletions
@@ -138,10 +138,11 @@ def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
             or weight_quant.strategy == QuantizationStrategy.CHANNEL.value)
         is_tensor = (weight_strategy and input_quant.strategy
                      == QuantizationStrategy.TENSOR.value)
-        is_symmetric = weight_quant.symmetric and input_quant.symmetric
         is_static = not weight_quant.dynamic and not input_quant.dynamic
 
-        return is_8_bits and is_tensor and is_symmetric and is_static
+        # Both symmetric and asymmetric input quantization supported.
+        # Only symmetric weight quantization supported.
+        return is_8_bits and is_tensor and weight_quant.symmetric and is_static
 
     def _is_dynamic_token_w8a8(self, weight_quant: BaseModel,
                                input_quant: BaseModel) -> bool:
@@ -151,10 +152,11 @@ def _is_dynamic_token_w8a8(self, weight_quant: BaseModel,
             or weight_quant.strategy == QuantizationStrategy.CHANNEL.value)
         is_token = (weight_strategy and input_quant.strategy
                     == QuantizationStrategy.TOKEN.value)
-        is_symmetric = weight_quant.symmetric and input_quant.symmetric
         is_dynamic = not weight_quant.dynamic and input_quant.dynamic
 
-        return is_8_bits and is_token and is_symmetric and is_dynamic
+        # Both symmetric and asymmetric input quantization supported.
+        # Only symmetric weight quantization supported.
+        return is_8_bits and is_token and weight_quant.symmetric and is_dynamic
 
     def _is_fp8_w8a8(self, weight_quant: BaseModel,
                      input_quant: BaseModel) -> bool:
@@ -265,12 +267,14 @@ def _get_scheme_from_parts(
         if self._is_static_tensor_w8a8(weight_quant, input_quant):
             return CompressedTensorsW8A8Int8(
                 strategy=weight_quant.strategy,
-                is_static_input_scheme=True)
+                is_static_input_scheme=True,
+                input_symmetric=input_quant.symmetric)
 
         if self._is_dynamic_token_w8a8(weight_quant, input_quant):
             return CompressedTensorsW8A8Int8(
                 strategy=weight_quant.strategy,
-                is_static_input_scheme=False)
+                is_static_input_scheme=False,
+                input_symmetric=input_quant.symmetric)
 
         raise NotImplementedError(
            "No compressed-tensors compatible scheme was found.")

vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py

Lines changed: 52 additions & 3 deletions
@@ -3,6 +3,7 @@
 import torch
 from torch.nn import Parameter
 
+from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
@@ -14,12 +15,16 @@
     ModelWeightParameter,
     PerTensorScaleParameter)
 
+logger = init_logger(__name__)
+
 
 class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
 
-    def __init__(self, strategy: str, is_static_input_scheme: bool):
+    def __init__(self, strategy: str, is_static_input_scheme: bool,
+                 input_symmetric: bool):
         self.strategy = strategy
         self.is_static_input_scheme = is_static_input_scheme
+        self.input_symmetric = input_symmetric
 
     @classmethod
     def get_min_capability(cls) -> int:
@@ -46,10 +51,43 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
                                            requires_grad=False)
         # INPUT SCALE
         if self.is_static_input_scheme:
-            layer.input_scale = Parameter(layer.input_scale.max(),
-                                          requires_grad=False)
+            if self.input_symmetric:
+                layer.input_scale = Parameter(layer.input_scale.max(),
+                                              requires_grad=False)
+                layer.input_zero_point = None
+            else:
+                # reconstruct the ranges
+                int8_traits = torch.iinfo(torch.int8)
+                azps = layer.input_zero_point.to(dtype=torch.int32)
+                range_max = (layer.input_scale *
+                             (int8_traits.max - azps)).max()
+                range_min = (layer.input_scale *
+                             (int8_traits.min - azps)).min()
+
+                scale = (range_max - range_min) / (int8_traits.max -
+                                                   int8_traits.min)
+                layer.input_scale = Parameter(scale, requires_grad=False)
+
+                # AZP loaded as int8 but used as int32
+                azp = (int8_traits.min -
+                       range_min / scale).to(dtype=torch.int32)
+                layer.input_zero_point = Parameter(azp, requires_grad=False)
+
         else:
             layer.input_scale = None
+            layer.input_zero_point = None
+
+        # azp_adj is the AZP adjustment term, used to account for weights.
+        # It does not depend on scales or azp, so it is the same for
+        # static and dynamic quantization.
+        # For more details, see csrc/quantization/cutlass_w8a8/Epilogues.md
+        # https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md
+        if not self.input_symmetric:
+            layer.azp_adj = layer.weight.sum(dim=0,
+                                             keepdim=True,
+                                             dtype=torch.int32)
+        else:
+            layer.azp_adj = None
 
     def create_weights(self, layer: torch.nn.Module,
                        output_partition_sizes: List[int],
@@ -90,11 +128,22 @@ def create_weights(self, layer: torch.nn.Module,
                                               weight_loader=weight_loader)
             layer.register_parameter("input_scale", input_scale)
 
+            if not self.input_symmetric:
+                # Note: compressed-tensors stores the zp using the same dtype
+                # as the weights
+                # AZP loaded as int8 but used as int32
+                input_zero_point = BasevLLMParameter(
+                    data=torch.empty(1, dtype=torch.int8),
+                    weight_loader=weight_loader)
+                layer.register_parameter("input_zero_point", input_zero_point)
+
     def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
                       bias: Optional[torch.Tensor]) -> torch.Tensor:
 
         return apply_int8_linear(input=x,
                                  weight=layer.weight,
                                  weight_scale=layer.weight_scale,
                                  input_scale=layer.input_scale,
+                                 input_zero_point=layer.input_zero_point,
+                                 azp_adj=layer.azp_adj,
                                  bias=bias)
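The static asymmetric branch in process_weights_after_loading collapses the per-partition input scales and zero points (e.g. the separate q/k/v shards of a fused layer) into one (scale, zero_point) pair covering the union of the partitions' float ranges. A small self-contained sketch of that arithmetic, with made-up example values:

```python
import torch

# Hypothetical static input-quantization params loaded for three partitions.
input_scale = torch.tensor([0.020, 0.030, 0.025])
input_zero_point = torch.tensor([10, -5, 0], dtype=torch.int8)

int8_traits = torch.iinfo(torch.int8)
azps = input_zero_point.to(dtype=torch.int32)

# Float range each partition can represent, then the union across partitions.
range_max = (input_scale * (int8_traits.max - azps)).max()
range_min = (input_scale * (int8_traits.min - azps)).min()

# One scale/zero-point pair spanning the combined range.
scale = (range_max - range_min) / (int8_traits.max - int8_traits.min)
azp = (int8_traits.min - range_min / scale).to(dtype=torch.int32)
print(scale.item(), azp.item())
```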

vllm/model_executor/layers/quantization/utils/w8a8_utils.py

Lines changed: 17 additions & 2 deletions
@@ -191,13 +191,28 @@ def apply_int8_linear(
     weight: torch.Tensor,
     weight_scale: torch.Tensor,
     input_scale: Optional[torch.Tensor] = None,
+    input_zero_point: Optional[torch.Tensor] = None,
+    azp_adj: Optional[torch.Tensor] = None,
     bias: Optional[torch.Tensor] = None,
 ):
     # ops.scaled_int8_quant supports both dynamic and static quant.
     # * dynamic, layer.input_scale is None and x_scale computed from x.
     # * static, layer.input_scale is scalar and x_scale is input_scale.
-    x_q, x_scale, _ = ops.scaled_int8_quant(input, input_scale)
-
+    symmetric = azp_adj is None
+    x_q, x_scale, x_zp = ops.scaled_int8_quant(input,
+                                               input_scale,
+                                               input_zero_point,
+                                               symmetric=symmetric)
+
+    if x_zp is not None:
+        return ops.cutlass_scaled_mm_azp(x_q,
+                                         weight,
+                                         scale_a=x_scale,
+                                         scale_b=weight_scale,
+                                         out_dtype=input.dtype,
+                                         azp_adj=azp_adj,
+                                         azp=x_zp,
+                                         bias=bias)
     return ops.cutlass_scaled_mm(x_q,
                                  weight,
                                  scale_a=x_scale,
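For intuition about what the asymmetric branch computes, here is a plain-PyTorch reference of the quantize-then-GEMM-with-AZP-correction path. It mirrors the math behind ops.scaled_int8_quant followed by ops.cutlass_scaled_mm_azp, but it is only an illustrative sketch: the function name and tensor layout are assumptions, and rounding/clamping details of the real kernels may differ.

```python
from typing import Optional
import torch

def ref_asym_int8_linear(x: torch.Tensor,
                         weight_q: torch.Tensor,      # int8, [in_features, out_features]
                         weight_scale: torch.Tensor,  # per-tensor or per-column float scale
                         x_scale: torch.Tensor,       # per-tensor input scale
                         x_zp: torch.Tensor,          # int32 activation zero point
                         bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Static asymmetric quantization of the activations: x ~ x_scale * (x_q - x_zp).
    int8_traits = torch.iinfo(torch.int8)
    x_q = torch.clamp(torch.round(x / x_scale) + x_zp,
                      int8_traits.min, int8_traits.max).to(torch.int8)

    # azp_adj: per-output-column sums of the quantized weights, as precomputed
    # in process_weights_after_loading above.
    azp_adj = weight_q.to(torch.int32).sum(dim=0, keepdim=True)

    # (x_q - zp) @ W_q == x_q @ W_q - zp * azp_adj
    acc = x_q.to(torch.int32) @ weight_q.to(torch.int32) - x_zp * azp_adj
    out = acc.to(torch.float32) * x_scale * weight_scale
    if bias is not None:
        out = out + bias
    return out.to(x.dtype)
```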
