
Commit 1570203

[Spec Decode] (1/2) Remove batch expansion (#8839)
1 parent 22f5851 commit 1570203

29 files changed: +531 −99 lines changed

.buildkite/test-pipeline.yaml

Lines changed: 1 addition & 1 deletion
@@ -208,7 +208,7 @@ steps:
   - tests/spec_decode
   commands:
     - pytest -v -s spec_decode/e2e/test_multistep_correctness.py
-    - pytest -v -s spec_decode --ignore=spec_decode/e2e/test_multistep_correctness.py
+    - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s spec_decode --ignore=spec_decode/e2e/test_multistep_correctness.py
 
 - label: LoRA Test %N # 15min each
   mirror_hardwares: [amd]

tests/samplers/test_sampler.py

Lines changed: 1 addition & 1 deletion
@@ -434,7 +434,7 @@ def run_test_case(*, expected_penalization: List[bool],
     sampling_metadata = SamplingMetadata.prepare(
         seq_group_metadata_list,
         seq_lens=seq_lens if seq_lens else None,
-        query_lens=seq_lens if seq_lens else None,
+        query_lens=seq_lens if seq_lens else [1] * batch_size,
         device=device,
         pin_memory=is_pin_memory_available())
     # the logits tensor is modified in-place by the sampler

tests/spec_decode/e2e/test_integration.py

Lines changed: 44 additions & 0 deletions
@@ -102,3 +102,47 @@ def test_speculative_model_quantization_config(vllm_runner, common_llm_kwargs,
                                   max_output_len=32,
                                   seed=seed,
                                   temperature=0.0)
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        "model_name": MAIN_MODEL,
+
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+
+        # Required for spec decode.
+        "use_v2_block_manager": True,
+        "speculative_model": "JackFram/llama-68m",
+        "num_speculative_tokens": 3,
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs",
+                         [{
+                             "speculative_disable_mqa_scorer": True,
+                         }])
+@pytest.mark.parametrize("batch_size", [1, 5])
+@pytest.mark.parametrize(
+    "output_len",
+    [
+        # Use smaller output len for fast test.
+        32,
+    ])
+@pytest.mark.parametrize("seed", [1])
+def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
+                    baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
+                    output_len: int, seed: int):
+    """Verify that speculative decoding generates the same output
+    with the batch expansion scorer and the MQA scorer.
+    """
+    run_equality_correctness_test(vllm_runner,
+                                  common_llm_kwargs,
+                                  per_test_common_llm_kwargs,
+                                  baseline_llm_kwargs,
+                                  test_llm_kwargs,
+                                  batch_size,
+                                  max_output_len=output_len,
+                                  seed=seed,
+                                  temperature=0.0)
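
For orientation, a minimal sketch (not part of this commit) of what the parametrized kwargs above amount to when passed straight to `vllm.LLM`. The argument names mirror the test kwargs; in particular, `speculative_disable_mqa_scorer` is assumed to be accepted as an engine argument at this commit and to force the batch-expansion scorer instead of the MQA scorer. The baseline configuration (empty `baseline_llm_kwargs`) runs with the engine's default scorer, and the test expects identical greedy outputs from both runs.

```python
# Hedged sketch: constructor arguments mirror the test kwargs above and are
# assumed to be forwarded to the engine arguments at this commit.
from vllm import LLM, SamplingParams

llm = LLM(
    model="JackFram/llama-68m",
    speculative_model="JackFram/llama-68m",
    num_speculative_tokens=3,
    use_v2_block_manager=True,
    enforce_eager=True,
    # True forces the batch-expansion scorer; omit it to let the engine
    # use the MQA scorer where supported.
    speculative_disable_mqa_scorer=True,
)

outputs = llm.generate(["The future of AI is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
```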

tests/spec_decode/e2e/test_medusa_correctness.py

Lines changed: 49 additions & 0 deletions
@@ -350,6 +350,55 @@ def test_medusa_disable_queue(vllm_runner, common_llm_kwargs,
                                   temperature=0.0)
 
 
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+
+        # Required for spec decode.
+        "use_v2_block_manager": True,
+
+        # Precision
+        "dtype": PRECISION,
+
+        # Main model
+        "model_name": MAIN_MODEL,
+        "speculative_model": SPEC_MODEL,
+        "num_speculative_tokens": MAX_SPEC_TOKENS,
+        "speculative_disable_by_batch_size": 4
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs",
+                         [{
+                             "speculative_disable_mqa_scorer": True,
+                         }])
+@pytest.mark.parametrize("batch_size", [1, 5])
+@pytest.mark.parametrize(
+    "output_len",
+    [
+        # Use smaller output len for fast test.
+        32,
+    ])
+@pytest.mark.parametrize("seed", [1])
+def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
+                    baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
+                    output_len: int, seed: int):
+    """Verify that speculative decoding generates the same output
+    with the batch expansion scorer and the MQA scorer.
+    """
+    run_equality_correctness_test(vllm_runner,
+                                  common_llm_kwargs,
+                                  per_test_common_llm_kwargs,
+                                  baseline_llm_kwargs,
+                                  test_llm_kwargs,
+                                  batch_size,
+                                  max_output_len=output_len,
+                                  seed=seed,
+                                  temperature=0.0)
+
+
 if __name__ == "__main__":
     import pytest
     pytest.main([__file__])

tests/spec_decode/e2e/test_mlp_correctness.py

Lines changed: 43 additions & 0 deletions
@@ -460,3 +460,46 @@ def test_mlp_disable_queue(vllm_runner, common_llm_kwargs,
                                   max_output_len=output_len,
                                   seed=seed,
                                   temperature=0.0)
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        "model_name": MAIN_MODEL,
+
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+
+        # Required for spec decode.
+        "use_v2_block_manager": True,
+        "speculative_model": SPEC_MODEL,
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs",
+                         [{
+                             "speculative_disable_mqa_scorer": True,
+                         }])
+@pytest.mark.parametrize("batch_size", [1, 5])
+@pytest.mark.parametrize(
+    "output_len",
+    [
+        # Use smaller output len for fast test.
+        32,
+    ])
+@pytest.mark.parametrize("seed", [1])
+def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
+                    baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
+                    output_len: int, seed: int):
+    """Verify that speculative decoding generates the same output
+    with the batch expansion scorer and the MQA scorer.
+    """
+    run_equality_correctness_test(vllm_runner,
+                                  common_llm_kwargs,
+                                  per_test_common_llm_kwargs,
+                                  baseline_llm_kwargs,
+                                  test_llm_kwargs,
+                                  batch_size,
+                                  max_output_len=output_len,
+                                  seed=seed,
+                                  temperature=0.0)

tests/spec_decode/e2e/test_ngram_correctness.py

Lines changed: 46 additions & 0 deletions
@@ -292,3 +292,49 @@ def test_ngram_disable_queue(vllm_runner, common_llm_kwargs,
                                   max_output_len=output_len,
                                   seed=seed,
                                   temperature=0.0)
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        "model_name": "JackFram/llama-68m",
+
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+
+        # Required for spec decode.
+        "use_v2_block_manager": True,
+        "speculative_model": "[ngram]",
+        "num_speculative_tokens": 5,
+        "ngram_prompt_lookup_max": 3,
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs",
+                         [{
+                             "speculative_disable_mqa_scorer": True,
+                         }])
+@pytest.mark.parametrize("batch_size", [1, 5])
+@pytest.mark.parametrize(
+    "output_len",
+    [
+        # Use smaller output len for fast test.
+        32,
+    ])
+@pytest.mark.parametrize("seed", [1])
+def test_ngram_scorer(vllm_runner, common_llm_kwargs,
+                      per_test_common_llm_kwargs, baseline_llm_kwargs,
+                      test_llm_kwargs, batch_size: int, output_len: int,
+                      seed: int):
+    """Verify that ngram speculative decoding generates the same output
+    with the batch expansion scorer and the MQA scorer.
+    """
+    run_equality_correctness_test(vllm_runner,
+                                  common_llm_kwargs,
+                                  per_test_common_llm_kwargs,
+                                  baseline_llm_kwargs,
+                                  test_llm_kwargs,
+                                  batch_size,
+                                  max_output_len=output_len,
+                                  seed=seed,
+                                  temperature=0.0)

tests/spec_decode/test_multi_step_worker.py

Lines changed: 0 additions & 1 deletion
@@ -173,7 +173,6 @@ def test_same_output_for_multi_step():
         block_size,
         num_gpu_blocks,
         seed,
-        model_runner_cls=TP1DraftModelRunner,
     )
 
     worker = create_worker(

tests/spec_decode/test_scorer.py

Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
+import pytest
+import torch
+
+from vllm.sequence import ExecuteModelRequest
+from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
+from vllm.spec_decode.interfaces import SpeculativeProposals, SpeculativeScores
+from vllm.spec_decode.mqa_scorer import MQAScorer
+from vllm.worker.worker import Worker
+
+from .utils import create_batch, create_worker
+
+
+def create_proposal(batch_size: int, propose_len: int, vocab_size: int,
+                    device: str) -> SpeculativeProposals:
+    proposal_probs = torch.rand((batch_size, propose_len, vocab_size),
+                                device=device)
+    proposal_token_ids = torch.argmax(proposal_probs, dim=-1)
+    proposal_lens = torch.tensor([propose_len] * batch_size, device=device)
+    return SpeculativeProposals(proposal_token_ids, proposal_probs,
+                                proposal_lens)
+
+
+def assert_score_equal(score1: SpeculativeScores,
+                       score2: SpeculativeScores) -> None:
+    assert torch.allclose(score1.probs, score2.probs)
+    assert torch.allclose(score1.logprobs, score2.logprobs)
+    assert torch.equal(score1.token_ids, score2.token_ids)
+
+
+@pytest.mark.parametrize('model_name', ['facebook/opt-125m'])
+@pytest.mark.parametrize('batch_size', [1, 2, 4, 8, 16])
+@pytest.mark.parametrize('propose_len', [1, 3, 5])
+@pytest.mark.parametrize('device', ['cuda'])
+def test_scoroer(model_name: str, batch_size: int, propose_len: int,
+                 device: str) -> None:
+    """
+    Verify that the batch expansion and MQA scorers return the same scores.
+    """
+    seed = 0
+    block_size = 32
+    num_gpu_blocks = 2048 // block_size
+    scorer_worker = create_worker(Worker, model_name, block_size,
+                                  num_gpu_blocks, seed)
+    scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor = True
+    scorer_worker.model_runner.model.sampler.\
+        should_modify_greedy_probs_inplace = True
+
+    vocab_size = scorer_worker.vocab_size
+    proposals = create_proposal(batch_size, propose_len, vocab_size, device)
+    seq_group_metadatalist, _, _ = create_batch(batch_size,
+                                                propose_len,
+                                                block_size=block_size,
+                                                num_gpu_blocks=num_gpu_blocks)
+    requests = ExecuteModelRequest(seq_group_metadatalist,
+                                   num_lookahead_slots=propose_len)
+
+    batch_expansion_scorer = BatchExpansionTop1Scorer(scorer_worker, device,
+                                                      vocab_size)
+    batch_expansion_score = batch_expansion_scorer.score_proposals(
+        requests, proposals)
+
+    mqa_scorer = MQAScorer(scorer_worker, device, vocab_size)
+    mqa_score = mqa_scorer.score_proposals(requests, proposals)
+
+    assert_score_equal(batch_expansion_score, mqa_score)
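
The new unit test relies on both scorers sharing a construction signature and a `score_proposals()` entry point. A hedged sketch of that contract, using only the constructors and calls that appear in the test above (the helper name below is illustrative and not part of vLLM):

```python
# Hedged sketch of the scorer contract exercised by test_scorer.py.
from typing import Tuple

from vllm.sequence import ExecuteModelRequest
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                         SpeculativeScores)
from vllm.spec_decode.mqa_scorer import MQAScorer


def score_with_both_scorers(
        scorer_worker, device: str, vocab_size: int,
        requests: ExecuteModelRequest, proposals: SpeculativeProposals
) -> Tuple[SpeculativeScores, SpeculativeScores]:
    """Run the same proposals through both scorer implementations."""
    batch_expansion_scorer = BatchExpansionTop1Scorer(scorer_worker, device,
                                                      vocab_size)
    mqa_scorer = MQAScorer(scorer_worker, device, vocab_size)
    # Each call returns a SpeculativeScores holding per-token probs,
    # logprobs, and token_ids, which the test compares element-wise.
    return (batch_expansion_scorer.score_proposals(requests, proposals),
            mqa_scorer.score_proposals(requests, proposals))
```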

tests/spec_decode/test_spec_decode_worker.py

Lines changed: 5 additions & 4 deletions
@@ -63,10 +63,10 @@ def test_correctly_calls_draft_model(k: int, batch_size: int,
 @pytest.mark.parametrize("acceptance_sampler_method",
                          ["rejection_sampler", "typical_acceptance_sampler"])
 @torch.inference_mode()
-def test_correctly_calls_target_model(k: int, batch_size: int,
-                                      acceptance_sampler_method: str):
+def test_batch_expansion_correctly_calls_target_model(
+        k: int, batch_size: int, acceptance_sampler_method: str):
     """Verify SpecDecodeWorker calls the target model with correct
-    inputs. Everything else is mocked out.
+    inputs with batch expansion. Everything else is mocked out.
     """
     draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False)
     target_worker = mock_worker(use_spec=False)
@@ -82,7 +82,8 @@ def test_correctly_calls_target_model(k: int, batch_size: int,
         target_worker,
         mock_spec_decode_sampler(acceptance_sampler_method),
         disable_logprobs=False,
-        metrics_collector=metrics_collector)
+        metrics_collector=metrics_collector,
+        disable_mqa_scorer=True)
     worker.init_device()
 
     vocab_size = 32_000

tests/spec_decode/utils.py

Lines changed: 16 additions & 13 deletions
@@ -131,19 +131,22 @@ def create_seq_group_metadata_from_prompts(
         for i, final_len in enumerate(final_prompt_lens)
     }
 
-    return [
-        SequenceGroupMetadata(
-            request_id=str(i),
-            is_prompt=len(cont_token_ids) == 0,
-            seq_data={
-                i: SequenceData.from_seqs(prompt_token_ids[:],
-                                          cont_token_ids[:]),
-            },
-            sampling_params=SamplingParams(temperature=0.0, ),
-            block_tables={i: block_allocations[i][:]},
-        ) for i, (prompt_token_ids,
-                  cont_token_ids) in enumerate(zip(prompts, continuations))
-    ]
+    seq_grou_metadata_list = []
+    for i, (prompt_token_ids,
+            cont_token_ids) in enumerate(zip(prompts, continuations)):
+        data = SequenceData.from_seqs(prompt_token_ids, cont_token_ids)
+        data.update_num_computed_tokens(
+            len(prompt_token_ids) + len(cont_token_ids) - 1)
+        seq_data = {i: data}
+        seq_grou_metadata_list.append(
+            SequenceGroupMetadata(
+                request_id=str(i),
+                is_prompt=len(cont_token_ids) == 0,
+                seq_data=seq_data,
+                sampling_params=SamplingParams(temperature=0.0),
+                block_tables={i: block_allocations[i][:]},
+            ))
+    return seq_grou_metadata_list
 
 
 def assert_logprobs_dict_allclose(
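
The rewritten helper now marks every token except the last one as already computed, so each generated sequence group looks like a single-token decode step to the scorer. A small hedged sketch of that bookkeeping, assuming `SequenceData.from_seqs`, `update_num_computed_tokens`, `get_len`, and `get_num_computed_tokens` behave as in `vllm.sequence` at this commit:

```python
# Hedged sketch of the computed-token bookkeeping done by the updated helper.
from vllm.sequence import SequenceData

prompt_token_ids = [1, 2, 3, 4]
cont_token_ids = [5, 6]

data = SequenceData.from_seqs(prompt_token_ids, cont_token_ids)
# Mark everything but the final token as computed, mirroring the diff above.
data.update_num_computed_tokens(
    len(prompt_token_ids) + len(cont_token_ids) - 1)

assert data.get_len() == 6
# Only the last token remains to be processed in the next (decode) step.
assert data.get_num_computed_tokens() == 5
```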

vllm/attention/backends/blocksparse_attn.py

Lines changed: 6 additions & 0 deletions
@@ -186,6 +186,12 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
     # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
     use_cuda_graph: bool
 
+    # Number of query tokens for each request in the batch.
+    # Currently, we require that all requests have the same number of query
+    # tokens during the decoding phase. When speculative decoding is enabled,
+    # decode_query_len might be greater than 1. In all other cases, it is 1.
+    decode_query_len: Optional[int] = None
+
     _cached_prefill_metadata: Optional[
         "BlocksparseFlashAttentionMetadata"] = None
     _cached_decode_metadata: Optional[
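
The new field records a uniform per-request query length for the decode phase. Under MQA scoring, each decode request is assumed to score its k proposed tokens plus one bonus token in a single pass, so the value would be k + 1; without speculative decoding it stays at 1. A hedged, illustrative sketch of that relationship (not part of vLLM's API):

```python
# Illustrative only: how decode_query_len is assumed to relate to the number
# of speculative tokens per request.
from typing import Optional


def infer_decode_query_len(num_speculative_tokens: Optional[int]) -> int:
    """Per-request decode query length under the stated assumption."""
    if num_speculative_tokens is None:
        return 1  # plain autoregressive decode: one query token per request
    return num_speculative_tokens + 1  # k proposal tokens + 1 bonus token


assert infer_decode_query_len(None) == 1
assert infer_decode_query_len(3) == 4
```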
