
Commit 95f8800

XWFAlone, mengwei805, and JC-ut0 committed
[1/N][UT][v1 MTP] add basic v1 mtp features
Co-authored-by: XWFAlone <xuewenfei2@huawei.com>
Co-authored-by: mengwei805 <mengwei25@huawei.com>
Co-authored-by: JC-ut0 <xuyexiong@huawei.com>
Signed-off-by: XWFAlone <xuewenfei2@huawei.com>
1 parent f6e5dec commit 95f8800

File tree

6 files changed: +480 -17 lines changed


.github/workflows/vllm_ascend_test_long_term.yaml

Lines changed: 2 additions & 1 deletion
@@ -93,6 +93,7 @@ jobs:
       - name: Run vllm-project/vllm-ascend long term test
         run: |
           # spec decode test
+          VLLM_USE_MODELSCOPE=True pytest -sv tests/long_term/spec_decode/e2e/test_v1_mtp_correctness.py
           VLLM_USE_MODELSCOPE=true pytest -sv tests/long_term/spec_decode/e2e/test_v1_spec_decode.py
           VLLM_USE_MODELSCOPE=True pytest -sv tests/long_term/spec_decode/e2e/test_mtp_correctness.py # it needs a clean process
-          pytest -sv tests/long_term/spec_decode --ignore=tests/long_term/spec_decode/e2e/test_mtp_correctness.py --ignore=tests/long_term/spec_decode/e2e/test_v1_spec_decode.py
+          pytest -sv tests/long_term/spec_decode --ignore=tests/long_term/spec_decode/e2e/test_mtp_correctness.py --ignore=tests/long_term/spec_decode/e2e/test_v1_spec_decode.py --ignore=tests/long_term/spec_decode/e2e/test_v1_mtp_correctness.py
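For a local run, the sketch below mirrors what the new CI step does; it assumes pytest is installed and the ModelScope hub is reachable, and simply shells out the same command the workflow runs.

```python
# Rough local equivalent of the CI step above (a sketch, not part of the
# workflow): run the new v1 MTP correctness test with ModelScope enabled.
import os
import subprocess

env = dict(os.environ, VLLM_USE_MODELSCOPE="True")
subprocess.run(
    ["pytest", "-sv",
     "tests/long_term/spec_decode/e2e/test_v1_mtp_correctness.py"],
    env=env,
    check=True,
)
```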
tests/long_term/spec_decode/e2e/test_v1_mtp_correctness.py

Lines changed: 94 additions & 0 deletions
@@ -0,0 +1,94 @@
from __future__ import annotations

import random
from typing import Any

import pytest
from vllm import LLM, SamplingParams


@pytest.fixture
def test_prompts():
    prompt_types = ["repeat", "sentence"]
    num_prompts = 10
    prompts = []

    random.seed(0)
    random_prompt_type_choices = random.choices(prompt_types, k=num_prompts)

    # Generate a mixed batch of prompts, some of which can be easily
    # predicted by n-gram matching and some which likely cannot.
    for kind in random_prompt_type_choices:
        word_choices = ["test", "temp", "hello", "where"]
        word = random.choice(word_choices)
        if kind == "repeat":
            prompt = f"""
            please repeat the word '{word}' 10 times.
            give no other output than the word at least ten times in a row,
            in lowercase with spaces between each word and without quotes.
            """
        elif kind == "sentence":
            prompt = f"""
            please give a ten-word sentence that
            uses the word {word} at least once.
            give no other output than that simple sentence without quotes.
            """
        else:
            raise ValueError(f"Unknown prompt type: {kind}")
        prompts.append([{"role": "user", "content": prompt}])

    return prompts


@pytest.fixture
def sampling_config():
    return SamplingParams(temperature=0, max_tokens=256, ignore_eos=False)


@pytest.fixture
def model_name():
    return "wemaster/deepseek_mtp_main_random_bf16"


def test_mtp_correctness(
        monkeypatch: pytest.MonkeyPatch,
        test_prompts: list[list[dict[str, Any]]],
        sampling_config: SamplingParams,
        model_name: str,
):
    '''
    Compare the outputs of the original LLM and the speculative LLM;
    they should be the same when using MTP speculative decoding.
    '''
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        ref_llm = LLM(model=model_name,
                      max_model_len=256,
                      enforce_eager=True)
        ref_outputs = ref_llm.chat(test_prompts, sampling_config)
        del ref_llm

        spec_llm = LLM(model=model_name,
                       trust_remote_code=True,
                       speculative_config={
                           "method": "deepseek_mtp",
                           "num_speculative_tokens": 1,
                       },
                       max_model_len=256,
                       enforce_eager=True)
        spec_outputs = spec_llm.chat(test_prompts, sampling_config)
        matches = 0
        misses = 0
        for ref_output, spec_output in zip(ref_outputs, spec_outputs):
            if ref_output.outputs[0].text == spec_output.outputs[0].text:
                matches += 1
            else:
                misses += 1
                print(f"ref_output: {ref_output.outputs[0].text}")
                print(f"spec_output: {spec_output.outputs[0].text}")

        # Heuristic: expect at least 66% of the prompts to match exactly
        # Upon failure, inspect the outputs to check for inaccuracy.
        assert matches > int(0.66 * len(ref_outputs))
        del spec_llm
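Outside the pytest harness, the same speculative configuration this test exercises can be used directly. A minimal sketch, assuming the same random-weight MTP checkpoint as the model_name fixture above and a working vllm-ascend install; the prompt and token budget here are arbitrary:

```python
# Minimal standalone sketch of the MTP speculative setup used by the test.
# Assumes the same random-weight checkpoint as the model_name fixture above.
import os

from vllm import LLM, SamplingParams

os.environ["VLLM_USE_V1"] = "1"  # the test sets this via monkeypatch

llm = LLM(model="wemaster/deepseek_mtp_main_random_bf16",
          trust_remote_code=True,
          speculative_config={
              "method": "deepseek_mtp",
              "num_speculative_tokens": 1,
          },
          max_model_len=256,
          enforce_eager=True)

outputs = llm.generate(["please repeat the word 'test' 10 times."],
                       SamplingParams(temperature=0, max_tokens=64))
print(outputs[0].outputs[0].text)
```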

vllm_ascend/attention/mla_v1.py

Lines changed: 30 additions & 6 deletions
@@ -16,13 +16,26 @@
 
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
-from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
 
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
     from vllm.v1.worker.gpu_input_batch import InputBatch
 
 
+@dataclass
+class CommonAttentionMetadata:
+    """
+    Attention metadata attributes that can be shared by layers in different KV
+    cache groups and thus may have different block tables.
+    """
+
+    query_start_loc: torch.Tensor
+    """(batch_size + 1,), the start location of each request in query Tensor"""
+    seq_lens: torch.Tensor
+    """(batch_size,), the length of each request including both computed tokens
+    and newly scheduled tokens"""
+
+
 class AscendMLABackend(AttentionBackend):
 
     accept_output_buffer: bool = True
@@ -57,6 +70,7 @@ class AscendMLAPrefillMetadata:
     seq_lens: list[int]
     context_lens: torch.Tensor
     input_positions: torch.Tensor
+    query_start_loc: torch.Tensor
     block_table: torch.Tensor
     max_query_len: int
     max_seq_lens: int
@@ -90,6 +104,9 @@ class AscendMLAMetadata:
 
     num_actual_tokens: int  # Number of tokens excluding padding.
     slot_mapping: torch.Tensor
+    query_start_loc: torch.Tensor
+    seq_lens: torch.Tensor
+    block_tables: torch.Tensor
 
     # New for MLA (compared to FlashAttention)
     # For handling prefill decode split
@@ -130,7 +147,7 @@ class AscendMLAMetadataBuilder:
 
     # _attn_mask_builder = None
     def __init__(self,
-                 runner: "NPUModelRunner",
+                 runner,
                  metadata_cls: Optional[AscendMLAMetadata] = None):
         self.metadata_cls: Optional[AscendMLAMetadata] = metadata_cls \
            if metadata_cls is not None else AscendMLAMetadata  # type: ignore
@@ -230,6 +247,7 @@ def build(self,
              num_reqs: int,
              num_actual_tokens: int,
              max_query_len: int,
+             common_attn_metadata: CommonAttentionMetadata,
              common_prefix_len: Optional[int] = None,
              graph_pad_size: int = -1) -> AscendMLAMetadata:
        assert self._num_decodes + self._num_prefills == num_reqs
@@ -239,10 +257,8 @@
        # it blocks on all previous kernels.
        device = self.runner.device
 
-       block_table = self.runner.input_batch.block_table[0].get_device_tensor(
-       )
-       block_table[:num_reqs, :self.runner.max_num_blocks_per_req] = (
-           block_table[:num_reqs])
+       block_table = (self.runner.input_batch.block_table[0].
+                      get_device_tensor()[:num_reqs])
        slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
            device, non_blocking=True)
        input_positions = self.runner.positions_cpu[:num_actual_tokens].to(
@@ -254,13 +270,17 @@
        seq_lens = seq_lens_cpu
        max_query_len = query_lens.max().item()
        max_seq_lens = seq_lens.max().item()
+       query_start_loc = None
 
        prefill_metadata = None
        if self._num_prefills > 0:
            reqs_start = self._num_decodes  # prefill_start
            tokens_start = self._num_decode_tokens
            max_query_len = query_lens[tokens_start:].max().item()
            max_seq_lens = seq_lens[tokens_start:].max().item()
+           query_start_loc = common_attn_metadata.query_start_loc
+           prefill_query_start_loc = query_start_loc[
+               reqs_start:] - query_start_loc[reqs_start]
 
            prefill_metadata = AscendMLAPrefillMetadata(
                attn_mask=self.runner.attn_mask,
@@ -271,6 +291,7 @@
                block_table=block_table[reqs_start:, ...],
                max_query_len=max_query_len,
                max_seq_lens=max_seq_lens,
+               query_start_loc=prefill_query_start_loc,
            )
 
        decode_metadata = None
@@ -327,6 +348,9 @@
            attn_state=self.runner.attn_state,
            prefill=prefill_metadata,
            decode=decode_metadata,
+           query_start_loc=query_start_loc,
+           block_tables=block_table,
+           seq_lens=seq_lens,
        )
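The prefill_query_start_loc slicing in build() re-bases the shared query_start_loc so that, once the decode requests are skipped, the first prefill request starts at offset zero. A small standalone sketch with made-up request counts and token lengths:

```python
# Illustrative sketch of the re-basing in AscendMLAMetadataBuilder.build();
# the request counts and token lengths below are invented for illustration.
import torch

# 2 decode requests (1 token each) followed by 2 prefill requests of
# 4 and 3 tokens: cumulative start offsets over the whole batch.
query_start_loc = torch.tensor([0, 1, 2, 6, 9])
reqs_start = 2  # decodes come first, prefills start here

prefill_query_start_loc = query_start_loc[
    reqs_start:] - query_start_loc[reqs_start]
print(prefill_query_start_loc)  # tensor([0, 4, 7])
```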

vllm_ascend/ops/attention.py

Lines changed: 13 additions & 2 deletions
@@ -222,6 +222,17 @@ def vanilla_chunked_prefill_mla(
        device="npu",
        dtype=value.dtype,
    )
+   num_query = torch.sum(q_mask).item()
+   num_add_query = num_query - query.size(0)
+   # MTP can make q_mask select more slots than query provides; pad query.
+   if num_add_query > 0:
+       add_query_size = query.size()
+       add_query_size = list(add_query_size)
+       add_query_size[0] = num_add_query
+       pad_tensor = torch.zeros(add_query_size,
+                                dtype=query.dtype,
+                                device=query.device)
+       query = torch.cat([query, pad_tensor], dim=0)
    pad_q[q_mask] = query
    pad_k[kv_c_mask] = key[kv_c_mask]
    pad_v[kv_c_mask] = value[kv_c_mask]
@@ -247,8 +258,8 @@
 
    attn_output = (attn_output[q_mask].view([-1, num_heads,
                                             v_head_dim]).to(output.dtype))
-   output = output.view_as(attn_output)
-   output.copy_(attn_output)
+   output = output.view([-1, num_heads, v_head_dim])
+   output.copy_(attn_output[:query.size(0) - num_add_query])
    return attn_output
 
0 commit comments
