File tree 4 files changed +20
-3
lines changed
tests/models/quantization
4 files changed +20
-3
lines changed Original file line number Diff line number Diff line change 2
2
import pytest
3
3
4
4
from tests .quantization .utils import is_quant_method_supported
5
+ from vllm .platforms import current_platform
5
6
6
7
# These ground truth generations were generated using `transformers==4.38.1
7
8
# aqlm==1.1.0 torch==2.2.0`
34
35
]
35
36
36
37
37
- @pytest .mark .skipif (not is_quant_method_supported ("aqlm" ),
38
+ @pytest .mark .skipif (not is_quant_method_supported ("aqlm" )
39
+ or current_platform .is_rocm ()
40
+ or not current_platform .is_cuda (),
38
41
reason = "AQLM is not supported on this GPU type." )
39
42
@pytest .mark .parametrize ("model" , ["ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf" ])
40
43
@pytest .mark .parametrize ("dtype" , ["half" ])
Original file line number Diff line number Diff line change @@ -55,6 +55,14 @@ def test_models(
55
55
Only checks log probs match to cover the discrepancy in
56
56
numerical sensitive kernels.
57
57
"""
58
+
59
+ if backend == "FLASHINFER" and current_platform .is_rocm ():
60
+ pytest .skip ("Flashinfer does not support ROCm/HIP." )
61
+
62
+ if kv_cache_dtype == "fp8_e5m2" and current_platform .is_rocm ():
63
+ pytest .skip (
64
+ f"{ kv_cache_dtype } is currently not supported on ROCm/HIP." )
65
+
58
66
with monkeypatch .context () as m :
59
67
m .setenv ("TOKENIZERS_PARALLELISM" , 'true' )
60
68
m .setenv (STR_BACKEND_ENV_VAR , backend )
Original file line number Diff line number Diff line change 14
14
15
15
from tests .quantization .utils import is_quant_method_supported
16
16
from vllm .model_executor .layers .rotary_embedding import _ROPE_DICT
17
+ from vllm .platforms import current_platform
17
18
18
19
from ..utils import check_logprobs_close
19
20
34
35
35
36
36
37
@pytest .mark .flaky (reruns = 3 )
37
- @pytest .mark .skipif (not is_quant_method_supported ("gptq_marlin" ),
38
+ @pytest .mark .skipif (not is_quant_method_supported ("gptq_marlin" )
39
+ or current_platform .is_rocm ()
40
+ or not current_platform .is_cuda (),
38
41
reason = "gptq_marlin is not supported on this GPU type." )
39
42
@pytest .mark .parametrize ("model" , MODELS )
40
43
@pytest .mark .parametrize ("dtype" , ["half" , "bfloat16" ])
Original file line number Diff line number Diff line change 10
10
import pytest
11
11
12
12
from tests .quantization .utils import is_quant_method_supported
13
+ from vllm .platforms import current_platform
13
14
14
15
from ..utils import check_logprobs_close
15
16
@@ -38,7 +39,9 @@ class ModelPair:
38
39
39
40
40
41
@pytest .mark .flaky (reruns = 2 )
41
- @pytest .mark .skipif (not is_quant_method_supported ("gptq_marlin_24" ),
42
+ @pytest .mark .skipif (not is_quant_method_supported ("gptq_marlin_24" )
43
+ or current_platform .is_rocm ()
44
+ or not current_platform .is_cuda (),
42
45
reason = "Marlin24 is not supported on this GPU type." )
43
46
@pytest .mark .parametrize ("model_pair" , model_pairs )
44
47
@pytest .mark .parametrize ("dtype" , ["half" ])
You can’t perform that action at this time.
0 commit comments