
Commit f6518b2

[ROCm] Skip tests for quantizations incompatible with ROCm (#17905)
Signed-off-by: Hissu Hyvarinen <hissu.hyvarinen@amd.com>
1 parent d67085c commit f6518b2

4 files changed: 20 additions and 3 deletions

tests/models/quantization/test_aqlm.py

Lines changed: 4 additions & 1 deletion
@@ -2,6 +2,7 @@
 import pytest
 
 from tests.quantization.utils import is_quant_method_supported
+from vllm.platforms import current_platform
 
 # These ground truth generations were generated using `transformers==4.38.1
 # aqlm==1.1.0 torch==2.2.0`
@@ -34,7 +35,9 @@
 ]
 
 
-@pytest.mark.skipif(not is_quant_method_supported("aqlm"),
+@pytest.mark.skipif(not is_quant_method_supported("aqlm")
+                    or current_platform.is_rocm()
+                    or not current_platform.is_cuda(),
                     reason="AQLM is not supported on this GPU type.")
 @pytest.mark.parametrize("model", ["ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf"])
 @pytest.mark.parametrize("dtype", ["half"])

tests/models/quantization/test_fp8.py

Lines changed: 8 additions & 0 deletions
@@ -55,6 +55,14 @@ def test_models(
     Only checks log probs match to cover the discrepancy in
     numerical sensitive kernels.
     """
+
+    if backend == "FLASHINFER" and current_platform.is_rocm():
+        pytest.skip("Flashinfer does not support ROCm/HIP.")
+
+    if kv_cache_dtype == "fp8_e5m2" and current_platform.is_rocm():
+        pytest.skip(
+            f"{kv_cache_dtype} is currently not supported on ROCm/HIP.")
+
     with monkeypatch.context() as m:
         m.setenv("TOKENIZERS_PARALLELISM", 'true')
         m.setenv(STR_BACKEND_ENV_VAR, backend)
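Unlike the other three files, test_fp8.py skips at runtime inside the test body, so only the unsupported parameter combinations are dropped on ROCm while the rest still run. A minimal sketch of that pattern, assuming a simplified parametrization (only "FLASHINFER" and "fp8_e5m2" come from the diff; the other values are illustrative):

# Minimal sketch of the runtime guard added to test_fp8.py; parametrize
# values other than "FLASHINFER" and "fp8_e5m2" are assumed for illustration.
import pytest

from vllm.platforms import current_platform


@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER"])
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8_e5m2"])
def test_fp8_guard_sketch(backend: str, kv_cache_dtype: str):
    # These checks run per parameter combination, so a ROCm run still
    # exercises the combinations the platform does support.
    if backend == "FLASHINFER" and current_platform.is_rocm():
        pytest.skip("Flashinfer does not support ROCm/HIP.")

    if kv_cache_dtype == "fp8_e5m2" and current_platform.is_rocm():
        pytest.skip(
            f"{kv_cache_dtype} is currently not supported on ROCm/HIP.")

    # ... the real test builds vLLM and HF runners here and compares logprobs.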

tests/models/quantization/test_gptq_marlin.py

Lines changed: 4 additions & 1 deletion
@@ -14,6 +14,7 @@
 
 from tests.quantization.utils import is_quant_method_supported
 from vllm.model_executor.layers.rotary_embedding import _ROPE_DICT
+from vllm.platforms import current_platform
 
 from ..utils import check_logprobs_close
 
@@ -34,7 +35,9 @@
 
 
 @pytest.mark.flaky(reruns=3)
-@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
+@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin")
+                    or current_platform.is_rocm()
+                    or not current_platform.is_cuda(),
                     reason="gptq_marlin is not supported on this GPU type.")
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["half", "bfloat16"])

tests/models/quantization/test_gptq_marlin_24.py

Lines changed: 4 additions & 1 deletion
@@ -10,6 +10,7 @@
 import pytest
 
 from tests.quantization.utils import is_quant_method_supported
+from vllm.platforms import current_platform
 
 from ..utils import check_logprobs_close
 
@@ -38,7 +39,9 @@ class ModelPair:
 
 
 @pytest.mark.flaky(reruns=2)
-@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin_24"),
+@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin_24")
+                    or current_platform.is_rocm()
+                    or not current_platform.is_cuda(),
                     reason="Marlin24 is not supported on this GPU type.")
 @pytest.mark.parametrize("model_pair", model_pairs)
 @pytest.mark.parametrize("dtype", ["half"])

0 commit comments