Commit f3b577c

adapt mtp with graph mode in v1
Signed-off-by: whx-sjtu <2952154980@qq.com>
1 parent 3442fbd commit f3b577c

File tree

3 files changed: +55 −9 lines changed

  vllm_ascend/attention/mla_v1.py
  vllm_ascend/worker/model_runner_v1.py
  vllm_ascend/worker/mtp_proposer_v1.py

vllm_ascend/attention/mla_v1.py

Lines changed: 44 additions & 6 deletions

@@ -85,6 +85,7 @@ class AscendMLADecodeMetadata:
     seq_lens: torch.Tensor
     max_seq_lens: int
     seq_lens_list: list[int]
+    attn_mask: torch.Tensor


 @dataclass
@@ -170,11 +171,12 @@ def reorder_batch(self, input_batch: "InputBatch",

         for i, req_id in enumerate(input_batch.req_ids):
             num_tokens = scheduler_output.num_scheduled_tokens[req_id]
+            num_spec_tokens = len(scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
             # for now treat 1 scheduled token as "decode" even if its not,
             # we should update this to something like < 8 in the future but
             # currently the TritonMLA._forward_decode only supports
             # num_tokens = 1
-            if num_tokens == 1:
+            if num_tokens - num_spec_tokens == 1:
                 decodes.append(i)
                 num_decode_tokens += num_tokens
             else:
@@ -335,7 +337,8 @@ def build(self,
                 block_table=block_table,
                 seq_lens=seq_lens,
                 seq_lens_list=seq_lens.tolist(),
-                max_seq_lens=max_seq_lens)
+                max_seq_lens=max_seq_lens,
+                attn_mask=self.runner.spec_attn_mask)

         return self.metadata_cls(  # type: ignore
             num_actual_tokens=num_actual_tokens,
@@ -424,6 +427,17 @@ def __init__(

         self.enable_graph_mode = False
         additional_config = get_current_vllm_config().additional_config
+        speculative_config = get_current_vllm_config().speculative_config
+        self.fia_sparse_mode = 0
+        self.use_spec_decode = False
+        # We need to set the sparse_mode of fused_infer_attention op to 3
+        # in spec decoding scenario in order to pass in attention mask.
+        if speculative_config is not None:
+            self.fia_sparse_mode = 3
+            self.use_spec_decode = True
+            self.spec_token_num = speculative_config.num_speculative_tokens
+            assert self.spec_token_num > 0
+
         if additional_config:
             self.enable_graph_mode = additional_config.get(
                 "enable_graph_mode", False)
@@ -628,9 +642,32 @@ def _forward_decode(
             dtype=q.dtype,
             device=q.device)
         if self.running_in_graph:
-            # TorchAir's shape is [bs, num_heads_per_rank, seq_len, dim]
-            q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
-            q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
+            # TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim]
+            if self.use_spec_decode:
+                assert num_tokens % self.spec_token_num == 0
+                q_nope = (
+                    q_nope.view(
+                        num_tokens // self.spec_token_num,
+                        self.spec_token_num,
+                        self.num_heads,
+                        -1,
+                    )
+                    .transpose(1, 2)
+                    .contiguous()
+                )
+                q_pe = (
+                    q_pe.view(
+                        num_tokens // self.spec_token_num,
+                        self.spec_token_num,
+                        self.num_heads,
+                        -1,
+                    )
+                    .transpose(1, 2)
+                    .contiguous()
+                )
+            else:
+                q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
+                q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
             # shape of knope/k_pe for npu graph mode should be:
             # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
             block_size = kv_c_and_k_pe_cache[0].shape[1]
@@ -648,7 +685,8 @@
                 num_heads=self.num_heads,
                 num_key_value_heads=self.num_kv_heads,
                 input_layout="BNSD",
-                atten_mask=attn_metadata.attn_mask,
+                atten_mask=attn_metadata.decode.attn_mask,  # type:ignore
+                sparse_mode=self.fia_sparse_mode,
                 scale=self.scale,
                 antiquant_mode=0,
                 antiquant_scale=None,
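
Note: the speculative-decode branch above reshapes the flat decode queries into the "BNSD" layout that the fused_infer_attention op expects in graph mode. The following is a minimal, hedged sketch of that view/transpose with made-up sizes; it is an illustration of the reshape, not the vllm-ascend code itself.

# Illustration only: example sizes, request-major token layout assumed,
# and num_tokens an exact multiple of spec_token_num (as asserted above).
import torch

num_heads, head_dim = 16, 192      # example values
spec_token_num = 2                 # speculative tokens per request (example)
batch_size = 4
num_tokens = batch_size * spec_token_num

q = torch.randn(num_tokens, num_heads, head_dim)   # flat decode queries [T, N, D]

# [T, N, D] -> [B, S, N, D] -> [B, N, S, D], matching input_layout="BNSD"
q_bnsd = (
    q.view(num_tokens // spec_token_num, spec_token_num, num_heads, head_dim)
    .transpose(1, 2)
    .contiguous()
)
assert q_bnsd.shape == (batch_size, num_heads, spec_token_num, head_dim)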

vllm_ascend/worker/model_runner_v1.py

Lines changed: 9 additions & 2 deletions

@@ -195,8 +195,12 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):

         # Set up speculative decoding.
         self.use_spec_decode = False
+        self.spec_attn_mask = None
         if self.speculative_config:
             self.use_spec_decode = True
+            # TODO: Need to find out the right value of spec_attn_mask to make sure
+            # that accuracy is right.
+            self.spec_attn_mask = torch.zeros(2048, 2048, dtype=torch.bool).to("npu")
             if get_pp_group().is_last_rank:
                 if self.speculative_config.method == "ngram":
                     self.drafter = NgramProposer(self.vllm_config)
@@ -534,10 +538,13 @@ def _process_reqs(
         # Get the number of scheduled tokens for each request.
         # TODO: The Python loop can be slow. Optimize.
         num_scheduled_tokens = np.empty(num_reqs, dtype=np.int32)
+        num_valid_tokens = np.empty(num_reqs, dtype=np.int32)
         max_num_scheduled_tokens = 0
         for i, req_id in enumerate(self.input_batch.req_ids):
             num_tokens = scheduler_output.num_scheduled_tokens[req_id]
             num_scheduled_tokens[i] = num_tokens
+            num_valid_tokens[i] = num_tokens - \
+                len(scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
             max_num_scheduled_tokens = max(max_num_scheduled_tokens,
                                            num_tokens)

@@ -584,7 +591,7 @@ def _process_reqs(
         if np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens):
             attn_state = AscendAttentionState.PrefillNoCache
         # We assume it is the decode stage, where prefill occurs but only one token is not hit in cache.
-        elif np.all(num_scheduled_tokens == 1):
+        elif np.all(num_valid_tokens == 1):
             attn_state = AscendAttentionState.DecodeOnly
         # splitfuse
         elif not self.use_v0_scheduler or self.chunked_prefill_enabled:
@@ -618,7 +625,7 @@
             query_start_loc=query_start_loc, seq_lens=seq_lens)
         # Add graph_pad_size here
         if self.enable_torchair_graph_mode:
-            graph_pad_size = self.scheduler_config.max_num_seqs - len(seq_lens)
+            graph_pad_size = self.scheduler_config.max_num_seqs - sum(num_scheduled_tokens)
             extra_builder_kwargs['graph_pad_size'] = graph_pad_size

         if self.vllm_config.model_config.use_mla:
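
Note: with MTP speculative decoding, each decode request schedules its speculative tokens on top of the single "real" token, so the old np.all(num_scheduled_tokens == 1) test no longer recognizes a pure-decode batch. The sketch below illustrates the new num_valid_tokens check with example values only; it is not the runner code.

# Illustration only: 3 decode requests, each scheduling 1 real token
# plus 2 speculative (MTP) tokens.
import numpy as np

num_scheduled_tokens = np.array([3, 3, 3], dtype=np.int32)
num_spec_tokens = np.array([2, 2, 2], dtype=np.int32)

num_valid_tokens = num_scheduled_tokens - num_spec_tokens

print(bool(np.all(num_scheduled_tokens == 1)))  # False: old check misses DecodeOnly
print(bool(np.all(num_valid_tokens == 1)))      # True: new check still selects DecodeOnly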

vllm_ascend/worker/mtp_proposer_v1.py

Lines changed: 2 additions & 1 deletion

@@ -4,7 +4,7 @@
                          set_current_vllm_config)
 from vllm.forward_context import set_forward_context
 from vllm.model_executor.model_loader import get_model_loader
-from vllm.model_executor.model_loader.utils import set_default_torch_dtype
+from vllm.model_executor.model_loader.utils import set_default_torch_dtype, process_weights_after_loading
 from vllm.v1.sample.metadata import SamplingMetadata

 from vllm_ascend.attention.mla_v1 import CommonAttentionMetadata
@@ -199,6 +199,7 @@ def load_model(self) -> None:
             loader.get_all_weights(
                 self.vllm_config.speculative_config.draft_model_config,
                 self.model))
+        process_weights_after_loading(self.model, draft_model_config, target_device)


 # TODO Using torch instead of triton may result in poor performance
