Commit b3195bc

[AMD][ROCm]Quantization methods on ROCm; Fix _scaled_mm call (#8380)

Authored by gshtras
Co-authored-by: Alexei-V-Ivanov-AMD <156011006+Alexei-V-Ivanov-AMD@users.noreply.github.com>
Co-authored-by: Michael Goin <michael@neuralmagic.com>

1 parent e18749f commit b3195bc

File tree: 4 files changed, +71 -27 lines

  vllm/config.py
  vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
  vllm/model_executor/layers/quantization/fbgemm_fp8.py
  vllm/model_executor/layers/quantization/utils/w8a8_utils.py

vllm/config.py

Lines changed: 4 additions & 1 deletion

@@ -255,7 +255,10 @@ def _parse_quant_hf_config(self):
 
     def _verify_quantization(self) -> None:
         supported_quantization = [*QUANTIZATION_METHODS]
-        rocm_supported_quantization = ["awq", "gptq", "fp8"]
+        rocm_supported_quantization = [
+            "awq", "gptq", "fp8", "compressed_tensors", "compressed-tensors",
+            "fbgemm_fp8"
+        ]
         optimized_quantization_methods = [
             "fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin",
             "awq_marlin", "fbgemm_fp8", "compressed_tensors",

vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py

Lines changed: 26 additions & 3 deletions

@@ -8,10 +8,12 @@
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
     QuantizationStrategy)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    apply_fp8_linear, cutlass_fp8_supported, requantize_with_max_scale)
+    apply_fp8_linear, cutlass_fp8_supported, normalize_e4m3fn_to_e4m3fnuz,
+    requantize_with_max_scale)
 from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
                                            ModelWeightParameter,
                                            PerTensorScaleParameter)
+from vllm.utils import is_hip
 
 __all__ = ["CompressedTensorsW8A8Fp8"]
 
@@ -39,16 +41,37 @@ def process_weights_after_loading(self, layer) -> None:
                 logical_widths=layer.logical_widths,
             )
 
+            if is_hip():
+                weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
+                    weight=weight,
+                    weight_scale=max_w_scale,
+                    input_scale=layer.input_scale)
+                if input_scale is not None:
+                    layer.input_scale = Parameter(input_scale,
+                                                  requires_grad=False)
+
             layer.weight = Parameter(weight.t(), requires_grad=False)
             layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
 
         # If channelwise, scales are already lined up, so just transpose.
         elif self.strategy == QuantizationStrategy.CHANNEL:
             weight = layer.weight
+
+            if is_hip():
+                weight, weight_scale, input_scale = \
+                    normalize_e4m3fn_to_e4m3fnuz(
+                        weight=weight,
+                        weight_scale=layer.weight_scale,
+                        input_scale=layer.input_scale)
+                if input_scale is not None:
+                    layer.input_scale = Parameter(input_scale,
+                                                  requires_grad=False)
+            else:
+                weight_scale = layer.weight_scale.data
+
             layer.weight = Parameter(weight.t(), requires_grad=False)
             # required by torch.compile to be torch.nn.Parameter
-            layer.weight_scale = Parameter(layer.weight_scale.data,
-                                           requires_grad=False)
+            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
 
         else:
             raise ValueError(f"Unknown quantization strategy {self.strategy}")

vllm/model_executor/layers/quantization/fbgemm_fp8.py

Lines changed: 13 additions & 2 deletions

@@ -15,10 +15,11 @@
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     is_layer_skipped)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    apply_fp8_linear)
+    apply_fp8_linear, normalize_e4m3fn_to_e4m3fnuz)
 from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
                                            ModelWeightParameter)
 from vllm.platforms import current_platform
+from vllm.utils import is_hip
 
 logger = init_logger(__name__)
 
@@ -125,8 +126,18 @@ def process_weights_after_loading(self, layer: Module) -> None:
         layer.weight = Parameter(layer.weight.data, requires_grad=False)
 
         weight = layer.weight
-        layer.weight = Parameter(weight.t(), requires_grad=False)
 
+        if is_hip():
+            weight, weight_scale, input_scale = \
+                normalize_e4m3fn_to_e4m3fnuz(
+                    weight=weight,
+                    weight_scale=layer.weight_scale,
+                    input_scale=None)
+            if input_scale is not None:
+                layer.input_scale = Parameter(input_scale, requires_grad=False)
+            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
+
+        layer.weight = Parameter(weight.t(), requires_grad=False)
         if self.quant_config.use_marlin:
             prepare_fp8_layer_for_marlin(layer)
             # Activations not quantized for marlin.
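
Both this file and the compressed-tensors scheme gate the conversion on is_hip(), because ROCm (MI300) hardware consumes the e4m3fnuz flavor of FP8 while CUDA devices use e4m3fn. A tiny illustrative snippet of the platform-dependent dtype choice this implies (the FP8_DTYPE constant is ours, not vLLM's):

import torch

from vllm.utils import is_hip

# Hypothetical constant mirroring the is_hip() gates in the hunks above:
# ROCm expects float8_e4m3fnuz, CUDA expects float8_e4m3fn.
FP8_DTYPE = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn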

vllm/model_executor/layers/quantization/utils/w8a8_utils.py

Lines changed: 28 additions & 21 deletions

@@ -6,11 +6,9 @@
 from vllm.platforms import current_platform
 from vllm.utils import is_hip
 
-# scaled_mm in pytorch on rocm has a bug that requires always
-# providing scaling factor for result. This value is created
-# as global value to avoid multiple tensor allocations, and
-# can be removed once pytorch fixes the bug.
-TORCH_SCALED_MM_SCALE_RESULT = torch.ones(1).cuda() if is_hip() else None
+# Input scaling factors are no longer optional in _scaled_mm starting
+# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
+TORCH_DEVICE_IDENTITY = torch.ones(1).cuda() if is_hip() else None
 
 
 def cutlass_fp8_supported() -> bool:
@@ -131,19 +129,17 @@ def apply_fp8_linear(
 
     if per_tensor_weights and per_tensor_activations:
         # Fused GEMM_DQ
-        output = torch._scaled_mm(
-            qinput,
-            weight,
-            out_dtype=input.dtype,
-            scale_a=x_scale,
-            scale_b=weight_scale,
-            scale_result=TORCH_SCALED_MM_SCALE_RESULT,
-            bias=bias)
-        # Since in torch 2.5, scaled_mm only returns single value
-        # This should be removed when vllm-nvidia also moves to 2.5
-        if is_hip():
-            return torch.narrow(output, 0, 0, input.shape[0])
-        return torch.narrow(output[0], 0, 0, input.shape[0])
+        output = torch._scaled_mm(qinput,
+                                  weight,
+                                  out_dtype=input.dtype,
+                                  scale_a=x_scale,
+                                  scale_b=weight_scale,
+                                  bias=bias)
+        # A fix for discrepancy in scaled_mm which returns tuple
+        # for torch < 2.5 and a single value in torch >= 2.5
+        if type(output) is tuple and len(output) == 2:
+            return torch.narrow(output[0], 0, 0, input.shape[0])
+        return torch.narrow(output, 0, 0, input.shape[0])
 
     else:
         # Fallback for channelwise case, where we use unfused DQ
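
Note that the new code detects the PyTorch behavior structurally, by inspecting the return value, rather than parsing torch.__version__, so the same path works on both sides of the 2.5 boundary. A small self-contained sketch of that pattern, with a hypothetical helper name:

def _unwrap_scaled_mm_output(output):
    # torch._scaled_mm returned an (output, amax) tuple before torch 2.5
    # and returns a single tensor from 2.5 on; normalize both to one tensor.
    if isinstance(output, tuple) and len(output) == 2:
        return output[0]
    return output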
@@ -161,12 +157,23 @@
         # For the scaled_mm fallback case, we break this down, since it
         # does not support s_w being a vector.
 
+        # Making sure the dummy tensor is on the same device as the weight
+        global TORCH_DEVICE_IDENTITY
+        if TORCH_DEVICE_IDENTITY.device != weight.device:
+            TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)
+
         # GEMM
         # This computes C = (X * W).
         # Output in fp32 to allow subsequent ops to happen in-place
-        output, _ = torch._scaled_mm(qinput,
-                                     weight,
-                                     out_dtype=torch.float32)
+        output = torch._scaled_mm(qinput,
+                                  weight,
+                                  scale_a=TORCH_DEVICE_IDENTITY,
+                                  scale_b=TORCH_DEVICE_IDENTITY,
+                                  out_dtype=torch.float32)
+        # A fix for discrepancy in scaled_mm which returns tuple
+        # for torch < 2.5 and a single value in torch >= 2.5
+        if type(output) is tuple and len(output) == 2:
+            output = output[0]
         # Unpad (undo num_token_padding)
         output = torch.narrow(output, 0, 0, input.shape[0])
         x_scale = torch.narrow(x_scale, 0, 0, input.shape[0])
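
TORCH_DEVICE_IDENTITY exists because, as the new comment notes, _scaled_mm stopped treating the input scaling factors as optional around PyTorch 2.5; passing a 1.0 scale reproduces the old unscaled fp32 GEMM, and the real per-channel dequantization is applied afterwards. A minimal self-contained sketch of the allocate-once, move-on-demand pattern used for that dummy scale (names are illustrative):

import torch

_IDENTITY_SCALE = None  # allocated lazily, mirrors TORCH_DEVICE_IDENTITY


def get_identity_scale(device: torch.device) -> torch.Tensor:
    # Reuse a single 1.0 scale tensor and move it only when the weight
    # lives on a different device, avoiding a new allocation per GEMM call.
    global _IDENTITY_SCALE
    if _IDENTITY_SCALE is None or _IDENTITY_SCALE.device != device:
        _IDENTITY_SCALE = torch.ones(1, dtype=torch.float32, device=device)
    return _IDENTITY_SCALE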
