Commit 300da09

[Kernel] Fullgraph and opcheck tests (#8479)
1 parent 1c04644 commit 300da09

26 files changed, +744 -116 lines

.buildkite/test-pipeline.yaml

Lines changed: 17 additions & 2 deletions
@@ -70,7 +70,7 @@ steps:
   - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py
   - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py
   - VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py
-
+
 - label: Core Test # 10min
   mirror_hardwares: [amd]
   fast_check: true
@@ -210,6 +210,21 @@ steps:
   command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py
   parallelism: 4
 
+- label: "PyTorch Fullgraph Smoke Test"
+  fast_check: true
+  source_file_dependencies:
+  - vllm/
+  - tests/compile
+  commands:
+  - pytest -v -s compile/test_full_graph_smoke.py
+
+- label: "PyTorch Fullgraph Test"
+  source_file_dependencies:
+  - vllm/
+  - tests/compile
+  commands:
+  - pytest -v -s compile/test_full_graph.py
+
 - label: Kernels Test %N # 30min each
   mirror_hardwares: [amd]
   source_file_dependencies:
@@ -355,7 +370,7 @@ steps:
   - tests/distributed/
   - vllm/compilation
   commands:
-  - pytest -v -s ./compile/test_full_graph.py
+  - pytest -v -s ./compile/test_full_graph_multi_gpu.py
   - pytest -v -s ./compile/test_wrapper.py
   - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep -q 'Same node test passed'
   - TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m distributed_2_gpus

csrc/mamba/mamba_ssm/selective_scan_fwd.cu

Lines changed: 1 addition & 1 deletion
@@ -586,7 +586,7 @@ selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
   DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] {
     selective_scan_fwd_cuda<input_t, weight_t>(params, stream);
   });
-  std::vector<at::Tensor> result = {out, x.value()};
+  std::vector<at::Tensor> result = {out};
   if (has_z) { result.push_back(out_z); }
   return result;
 }

csrc/torch_bindings.cpp

Lines changed: 2 additions & 2 deletions
@@ -275,7 +275,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "Tensor! A, Tensor! B, Tensor! C,"
       "Tensor? D_, Tensor? z_, Tensor? delta_bias_,"
       "bool delta_softplus,"
-      "Tensor? index_, Tensor(a! -> *)? x) -> Tensor(a)[]");
+      "Tensor? index_, Tensor!? x) -> Tensor[]");
   ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
 
   ops.def(
@@ -292,7 +292,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "Tensor? bias_,"
      "Tensor? seq_idx_,"
      "Tensor? initial_states_,"
-      "Tensor? final_states_out_,"
+      "Tensor!? final_states_out_,"
       "bool silu_activation) -> Tensor");
   ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
 #endif
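
Note: the schema tweaks above replace the wildcard alias annotation (Tensor(a! -> *)? x, returning Tensor(a)[]) with a plain mutable optional argument (Tensor!? x, returning Tensor[]), and mark final_states_out_ as mutable. The opcheck helper imported by the new kernel tests from tests/kernels/utils.py is not shown in this diff; presumably it wraps torch.library.opcheck, which checks that a registered op's declared schema (including its mutation annotations) matches the op's observed behavior. A minimal, self-contained sketch of that mechanism on a hypothetical demo::relu_out op (not part of vLLM; assumes PyTorch >= 2.4):

import torch
from torch.library import custom_op, opcheck

# Hypothetical op that writes its result into a caller-provided buffer -- the
# same "mutable argument" pattern the Tensor! / Tensor!? annotations describe.
@custom_op("demo::relu_out", mutates_args=("out",))
def relu_out(x: torch.Tensor, out: torch.Tensor) -> None:
    out.copy_(torch.relu(x))

@relu_out.register_fake
def _(x, out):
    # Nothing to return; the op only mutates "out".
    return None

# opcheck cross-checks the registered schema, fake-tensor behavior and
# autograd registration against what the op actually does at runtime.
opcheck(relu_out, (torch.randn(8), torch.empty(8)))
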

tests/compile/test_full_graph.py

Lines changed: 8 additions & 37 deletions
@@ -1,42 +1,13 @@
-import os
-
 import pytest
 
-from vllm.utils import cuda_device_count_stateless
-
-from ..utils import fork_new_process_for_each_test
-
-
-@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"])
-@pytest.mark.parametrize("tp_size", [1, 2])
-@fork_new_process_for_each_test
-def test_full_graph(model, tp_size):
-
-    # Skip the test if there are not enough CUDA devices.
-    if cuda_device_count_stateless() < tp_size:
-        pytest.skip("Not enough CUDA devices for the test.")
-
-    # make sure these models can be captured in full graph mode
-    if "VLLM_TEST_DYNAMO_GRAPH_CAPTURE" not in os.environ:
-        os.environ["VLLM_TEST_DYNAMO_GRAPH_CAPTURE"] = "1"
+from vllm.compilation.backends import vllm_backend
 
-    from vllm import LLM, SamplingParams
-    prompts = [
-        "Hello, my name is",
-        "The president of the United States is",
-        "The capital of France is",
-        "The future of AI is",
-    ]
-    sampling_params = SamplingParams(temperature=0)
-    llm = LLM(model=model,
-              enforce_eager=True,
-              tensor_parallel_size=tp_size,
-              disable_custom_all_reduce=True)
+from .utils import TEST_MODELS, check_full_graph_support
 
-    outputs = llm.generate(prompts, sampling_params)
 
-    # Print the outputs.
-    for output in outputs:
-        prompt = output.prompt
-        generated_text = output.outputs[0].text
-        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+@pytest.mark.parametrize("model_info", TEST_MODELS)
+@pytest.mark.parametrize("backend", ["eager", vllm_backend])
+def test_full_graph(model_info, backend):
+    model = model_info[0]
+    model_kwargs = model_info[1]
+    check_full_graph_support(model, model_kwargs, backend, tp_size=1)

tests/compile/test_full_graph_multi_gpu.py

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
+import pytest
+
+from vllm.compilation.backends import vllm_backend
+from vllm.utils import cuda_device_count_stateless
+
+from ..utils import fork_new_process_for_each_test
+from .utils import TEST_MODELS_SMOKE, check_full_graph_support
+
+
+@pytest.mark.parametrize("model_info", TEST_MODELS_SMOKE)
+@pytest.mark.parametrize("tp_size", [2])
+@pytest.mark.parametrize("backend", ["eager", vllm_backend])
+@fork_new_process_for_each_test
+def test_full_graph_multi_gpu(model_info, tp_size, backend):
+    model = model_info[0]
+    model_kwargs = model_info[1]
+
+    # Skip the test if there are not enough CUDA devices.
+    if cuda_device_count_stateless() < tp_size:
+        pytest.skip("Not enough CUDA devices for the test.")
+
+    check_full_graph_support(model, model_kwargs, backend, tp_size=tp_size)

tests/compile/test_full_graph_smoke.py

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+import pytest
+
+from vllm.compilation.backends import vllm_backend
+
+from .utils import TEST_MODELS_SMOKE, check_full_graph_support
+
+
+@pytest.mark.parametrize("model_info", TEST_MODELS_SMOKE)
+@pytest.mark.parametrize("backend", ["eager", vllm_backend])
+def test_full_graph(model_info, backend):
+    model = model_info[0]
+    model_kwargs = model_info[1]
+    check_full_graph_support(model, model_kwargs, backend, tp_size=1)

tests/compile/utils.py

Lines changed: 104 additions & 0 deletions
@@ -0,0 +1,104 @@
+import os
+
+import torch
+
+from tests.quantization.utils import is_quant_method_supported
+from vllm import LLM, SamplingParams
+from vllm.plugins import set_torch_compile_backend
+from vllm.utils import is_hip
+
+TEST_MODELS_SMOKE = [
+    ("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples", {
+        "quantization": "compressed-tensors"
+    }),
+    ("meta-llama/Meta-Llama-3-8B", {}),
+]
+
+TEST_MODELS = [
+    ("facebook/opt-125m", {}),
+    ("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", {
+        "dtype": torch.float16,
+        "quantization": "compressed-tensors"
+    }),
+    ("neuralmagic/Meta-Llama-3-8B-Instruct-FP8", {
+        "dtype": torch.float16,
+        "quantization": "fp8"
+    }),
+    ("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples", {
+        "quantization": "compressed-tensors"
+    }),
+    ("meta-llama/Meta-Llama-3-8B", {}),
+]
+
+# TODO: enable in pytorch 2.5
+if False and is_quant_method_supported("aqlm"):  # noqa: SIM223
+    TEST_MODELS.append(("ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf", {
+        "quantization": "aqlm"
+    }))
+
+# TODO: enable in pytorch 2.5
+if False and is_quant_method_supported("gguf"):  # noqa: SIM223
+    TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF", {
+        "quantization": "gguf"
+    }))
+
+if is_quant_method_supported("gptq"):
+    TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", {
+        "quantization": "gptq"
+    }))
+
+if is_quant_method_supported("gptq_marlin"):
+    TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", {
+        "quantization": "gptq_marlin"
+    }))
+
+if is_quant_method_supported("gptq_marlin_24"):
+    TEST_MODELS.append(("alexm-nm/tinyllama-24-marlin24-4bit-g128", {
+        "quantization": "gptq_marlin_24"
+    }))
+
+if is_quant_method_supported("marlin"):
+    TEST_MODELS.append(("robertgshaw2/TinyLlama-1.1B-Chat-v1.0-g128-marlin", {
+        "quantization": "marlin"
+    }))
+
+if not is_hip() and is_quant_method_supported("awq"):
+    TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", {
+        "quantization": "AWQ"
+    }))
+
+
+def check_full_graph_support(model, model_kwargs, backend, tp_size=1):
+    # make sure these models can be captured in full graph mode
+    if "VLLM_TEST_DYNAMO_GRAPH_CAPTURE" not in os.environ:
+        os.environ["VLLM_TEST_DYNAMO_GRAPH_CAPTURE"] = "1"
+        os.environ["VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "1"
+
+    # Inductor doesn't support fp8/gptq_marlin_24 yet.
+    quantization = model_kwargs.get("quantization")
+    if (quantization == "fp8" or quantization == "gptq_marlin"
+            or quantization == "gptq_marlin_24") and backend != "eager":
+        return
+
+    set_torch_compile_backend(backend)
+
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+    sampling_params = SamplingParams(temperature=0)
+    llm = LLM(model=model,
+              enforce_eager=True,
+              tensor_parallel_size=tp_size,
+              disable_custom_all_reduce=True,
+              **model_kwargs)
+
+    outputs = llm.generate(prompts, sampling_params)
+
+    # Print the outputs.
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

tests/conftest.py

Lines changed: 6 additions & 0 deletions
@@ -169,6 +169,12 @@ def cleanup_fixture(should_do_global_cleanup_after_test: bool):
     cleanup()
 
 
+@pytest.fixture(autouse=True)
+def dynamo_reset():
+    yield
+    torch._dynamo.reset()
+
+
 @pytest.fixture
 def example_prompts() -> List[str]:
     prompts = []
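
Note: the autouse dynamo_reset fixture above clears torch.compile/Dynamo caches after every test, so captured graphs and guards from one test cannot leak into the next. A quick standalone illustration of the call it relies on (plain PyTorch, independent of vLLM):

import torch

@torch.compile
def add_one(x):
    return x + 1

add_one(torch.randn(4))  # first call triggers compilation and caches the graph
torch._dynamo.reset()    # drops cached graphs and guards
add_one(torch.randn(4))  # compiles again from scratch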

tests/kernels/test_aqlm.py

Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
+import torch
+
+from tests.kernels.utils import opcheck
+from vllm import _custom_ops as ops  # noqa: F401
+
+
+def test_aqlm_dequant_opcheck():
+    codes = torch.randint(-32768,
+                          32767, (22016, 512, 1),
+                          device='cuda',
+                          dtype=torch.int16)
+    codebooks = torch.rand((2, 65536, 1, 8),
+                           device='cuda',
+                           dtype=torch.float16)
+    codebook_partition_sizes = [11008, 11008]
+
+    opcheck(torch.ops._C.aqlm_dequant,
+            (codes, codebooks, codebook_partition_sizes))
+
+
+def test_aqlm_gemm_opcheck():
+    input = torch.rand((4, 4096), device='cuda', dtype=torch.float16)
+    codes = torch.randint(-32768,
+                          32767, (12288, 512, 1),
+                          device='cuda',
+                          dtype=torch.int16)
+    codebooks = torch.rand((3, 65536, 1, 8),
+                           device='cuda',
+                           dtype=torch.float16)
+    scales = torch.rand((12288, 1, 1, 1), device='cuda', dtype=torch.float16)
+    codebook_partition_sizes = [4096, 4096, 4096]
+    bias = None
+
+    opcheck(torch.ops._C.aqlm_gemm,
+            (input, codes, codebooks, scales, codebook_partition_sizes, None))
+    opcheck(torch.ops._C.aqlm_gemm,
+            (input, codes, codebooks, scales, codebook_partition_sizes, bias))

tests/kernels/test_attention.py

Lines changed: 6 additions & 3 deletions
@@ -205,7 +205,8 @@ def test_paged_attention(
                 (output, query, key_cache, value_cache, num_kv_heads, scale,
                  block_tables, seq_lens, block_size, max_seq_len, alibi_slopes,
                  kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
-                cond=(head_size == HEAD_SIZES[0]))
+                cond=(head_size == HEAD_SIZES[0]
+                      and block_size == BLOCK_SIZES[0]))
 
     elif version in ("v2", "rocm"):
         num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
@@ -246,7 +247,8 @@ def test_paged_attention(
                      key_cache, value_cache, num_kv_heads, scale, block_tables,
                      seq_lens, block_size, max_seq_len, alibi_slopes,
                      kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
-                    cond=(head_size == HEAD_SIZES[0]))
+                    cond=(head_size == HEAD_SIZES[0]
+                          and block_size == BLOCK_SIZES[0]))
 
         else:
             ops.paged_attention_rocm(
@@ -274,7 +276,8 @@ def test_paged_attention(
                      key_cache, value_cache, num_kv_heads, scale, block_tables,
                      seq_lens, block_size, max_seq_len, alibi_slopes,
                      kv_cache_dtype, k_scale, v_scale),
-                    cond=(head_size == HEAD_SIZES[0]))
+                    cond=(head_size == HEAD_SIZES[0]
+                          and block_size == BLOCK_SIZES[0]))
 
     else:
         raise AssertionError(f"Unknown version: {version}")

tests/kernels/test_awq.py

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
+import os
+
+import torch
+
+from tests.kernels.utils import opcheck
+from vllm import _custom_ops as ops  # noqa: F401
+
+
+def test_awq_dequantize_opcheck():
+    os.environ["VLLM_USE_TRITON_AWQ"] = "0"
+    qweight = torch.randint(-2000000000,
+                            2000000000, (8192, 256),
+                            device='cuda',
+                            dtype=torch.int32)
+    scales = torch.rand((64, 2048), device='cuda', dtype=torch.float16)
+    zeros = torch.empty((64, 256), device='cuda', dtype=torch.int32)
+    split_k_iters = 0
+    thx = 0
+    thy = 0
+    opcheck(torch.ops._C.awq_dequantize,
+            (qweight, scales, zeros, split_k_iters, thx, thy))
+
+
+def test_awq_gemm_opcheck():
+    os.environ["VLLM_USE_TRITON_AWQ"] = "0"
+    input = torch.rand((2, 8192), device='cuda', dtype=torch.float16)
+    qweight = torch.randint(-2000000000,
+                            2000000000, (8192, 256),
+                            device='cuda',
+                            dtype=torch.int32)
+    scales = torch.randint(-2000000000,
+                           2000000000, (64, 256),
+                           device='cuda',
+                           dtype=torch.int32)
+    qzeros = torch.empty((64, 2048), device='cuda', dtype=torch.float16)
+    split_k_iters = 8
+    opcheck(torch.ops._C.awq_gemm,
+            (input, qweight, qzeros, scales, split_k_iters))
