Commit ea259b0

fxmarty-amd and kewang2 authored and committed
[Bugfix] Fix quark fp8 format loading on AMD GPUs (vllm-project#12612)
Signed-off-by: Felix Marty <felmarty@amd.com>
Signed-off-by: kewang2 <kewang2@amd.com>
Co-authored-by: kewang2 <kewang2@amd.com>
Signed-off-by: 汪志鹏 <wangzhipeng628@gmail.com>
1 parent f50aeba commit ea259b0

File tree: 2 files changed (+38 −9 lines changed)

tests/quantization/test_quark.py
vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py

tests/quantization/test_quark.py

Lines changed: 26 additions & 0 deletions
@@ -5,6 +5,7 @@
 """
 
 import pytest
+import torch
 
 from vllm.model_executor.layers.quantization.quark.quark import (  # noqa: E501
     QuarkLinearMethod, QuarkW8A8Fp8, QuarkW8A8Int8)
@@ -63,3 +64,28 @@ def check_model(model):
 
         output = llm.generate_greedy("Hello my name is", max_tokens=20)
         assert output
+
+
+def test_quark_fp8_parity(vllm_runner):
+    quark_model_id = "amd-quark/llama-tiny-fp8-quark-quant-method"
+    fp8_model_id = "amd-quark/llama-tiny-fp8-quant-method"
+
+    llm_kwargs = {
+        "tensor_parallel_size": 1,
+        "enforce_eager": True,
+        "gpu_memory_utilization": 0.1
+    }
+    with (vllm_runner(quark_model_id, **llm_kwargs) as
+          quark_handle, vllm_runner(fp8_model_id, **llm_kwargs) as fp8_handle):
+        quark_model = (quark_handle.model.llm_engine.model_executor.
+                       driver_worker.model_runner.model)
+        quark_state_dict = quark_model.state_dict()
+
+        fp8_model = (fp8_handle.model.llm_engine.model_executor.driver_worker.
+                     model_runner.model)
+        fp8_state_dict = fp8_model.state_dict()
+
+    assert fp8_state_dict.keys() == quark_state_dict.keys()
+
+    for key in fp8_state_dict:
+        assert torch.equal(fp8_state_dict[key], quark_state_dict[key])
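For reference, torch.equal (used by the new parity test) returns True only when two tensors have the same shape and identical elements, so the test passes only if the quark-format and plain-FP8 checkpoints load into bit-identical weights and scales. A minimal sketch of that comparison pattern with made-up state dicts (the real test compares the two models' state_dict() outputs):

import torch

# Hypothetical stand-ins for the two loaded models' state dicts; the real
# test uses quark_model.state_dict() and fp8_model.state_dict().
state_dict_a = {
    "layer.weight": torch.ones(2, 2),
    "layer.weight_scale": torch.tensor(0.5),
}
state_dict_b = {k: v.clone() for k, v in state_dict_a.items()}

# Same key sets, then strict per-tensor equality: torch.equal fails on any
# difference in shape or in a single element.
assert state_dict_a.keys() == state_dict_b.keys()
for key in state_dict_a:
    assert torch.equal(state_dict_a[key], state_dict_b[key])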

vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py

Lines changed: 12 additions & 9 deletions
@@ -34,21 +34,24 @@ def process_weights_after_loading(self, layer) -> None:
         # tensor scales (thus N scales being passed to the kernel),
         # requantize so we can always run per tensor
         if self.qscheme == "per_tensor":
-            max_w_scale, weight = requantize_with_max_scale(
-                weight=layer.weight,
-                weight_scale=layer.weight_scale,
-                logical_widths=layer.logical_widths,
-            )
-
-            if current_platform.is_fp8_fnuz():
+            if current_platform.is_rocm():
                 input_scale = getattr(layer, 'input_scale', None)
                 weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
-                    weight=weight,
-                    weight_scale=max_w_scale,
+                    weight=layer.weight,
+                    weight_scale=layer.weight_scale,
                     input_scale=input_scale)
                 if input_scale is not None:
                     layer.input_scale = Parameter(input_scale,
                                                   requires_grad=False)
+            else:
+                max_w_scale = layer.weight_scale
+                weight = layer.weight
+
+            max_w_scale, weight = requantize_with_max_scale(
+                weight=weight,
+                weight_scale=max_w_scale,
+                logical_widths=layer.logical_widths,
+            )
 
             layer.weight = Parameter(weight.t(), requires_grad=False)
             layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
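The reordering above makes the ROCm path convert the checkpoint's e4m3fn tensors to e4m3fnuz before the shared requantize_with_max_scale call, instead of requantizing first; the non-ROCm branch simply passes the original weight and scale through to the same call. A rough sketch of the idea behind that normalization (assumed behavior for illustration, not the body of vLLM's normalize_e4m3fn_to_e4m3fnuz): the same bit pattern is worth half as much in e4m3fnuz as in e4m3fn, so the scale is doubled to keep dequantized values unchanged, and the one bit pattern that changes meaning (0x80, which is -0.0 in e4m3fn but NaN in e4m3fnuz) is cleared to zero:

import torch

def e4m3fn_to_e4m3fnuz_sketch(weight: torch.Tensor, scale: torch.Tensor):
    # Assumed illustration of the conversion, not vLLM's implementation:
    # reinterpret the raw bytes as e4m3fnuz, zero out the 0x80 pattern
    # (-0.0 in e4m3fn, NaN in e4m3fnuz), and double the scale so that
    # weight * scale dequantizes to the same values as before.
    assert weight.dtype == torch.float8_e4m3fn
    bits = weight.view(torch.int8).clone()
    bits[bits == -128] = 0
    return bits.view(torch.float8_e4m3fnuz), scale * 2.0

# Example: dequantized values are preserved for a per-tensor scale.
w_fn = torch.tensor([1.0, -0.5, 2.0]).to(torch.float8_e4m3fn)
s_fn = torch.tensor(0.25)
w_fnuz, s_fnuz = e4m3fn_to_e4m3fnuz_sketch(w_fn, s_fn)
assert torch.allclose(w_fn.float() * s_fn, w_fnuz.float() * s_fnuz)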
