@@ -85,6 +85,7 @@ class AscendMLADecodeMetadata:
     seq_lens: torch.Tensor
     max_seq_lens: int
     seq_lens_list: list[int]
+    attn_mask: torch.Tensor


 @dataclass
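The new `attn_mask` field carries the speculative-decoding attention mask that the builder threads in from the runner (see the `build` hunk below). How `self.runner.spec_attn_mask` is constructed is not shown in this diff; a minimal sketch of a plausible causal mask, with the helper name and shape as assumptions, might look like:

```python
import torch

# Hypothetical helper: an upper-triangular boolean mask where True marks the
# positions a query token must not attend to. Shape and dtype are assumptions;
# the runner-side construction is outside this PR's diff.
def make_spec_attn_mask(max_seq_len: int) -> torch.Tensor:
    return torch.triu(
        torch.ones(max_seq_len, max_seq_len, dtype=torch.bool), diagonal=1)

print(make_spec_attn_mask(4))
```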
@@ -170,11 +171,12 @@ def reorder_batch(self, input_batch: "InputBatch",

         for i, req_id in enumerate(input_batch.req_ids):
             num_tokens = scheduler_output.num_scheduled_tokens[req_id]
+            num_spec_tokens = len(scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
             # for now treat 1 scheduled token as "decode" even if its not,
             # we should update this to something like < 8 in the future but
             # currently the TritonMLA._forward_decode only supports
             # num_tokens = 1
-            if num_tokens == 1:
+            if num_tokens - num_spec_tokens == 1:
                 decodes.append(i)
                 num_decode_tokens += num_tokens
             else:
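The intent of the changed condition: with speculative decoding, a decode step schedules one verified token plus the draft tokens, so subtracting `num_spec_tokens` restores the old "exactly one token means decode" test. A toy illustration with made-up request IDs and counts:

```python
# req1 schedules 1 target token + 3 draft tokens: still a decode step.
# req2 schedules 6 tokens with no drafts: treated as a prefill.
num_scheduled_tokens = {"req0": 1, "req1": 4, "req2": 6}
scheduled_spec_decode_tokens = {"req1": [101, 102, 103]}

for req_id, num_tokens in num_scheduled_tokens.items():
    num_spec_tokens = len(scheduled_spec_decode_tokens.get(req_id, []))
    kind = "decode" if num_tokens - num_spec_tokens == 1 else "prefill"
    print(req_id, kind)  # req0 decode, req1 decode, req2 prefill
```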
@@ -335,7 +337,8 @@ def build(self,
                 block_table=block_table,
                 seq_lens=seq_lens,
                 seq_lens_list=seq_lens.tolist(),
-                max_seq_lens=max_seq_lens)
+                max_seq_lens=max_seq_lens,
+                attn_mask=self.runner.spec_attn_mask)

         return self.metadata_cls(  # type: ignore
             num_actual_tokens=num_actual_tokens,
@@ -424,6 +427,17 @@ def __init__(

         self.enable_graph_mode = False
         additional_config = get_current_vllm_config().additional_config
+        speculative_config = get_current_vllm_config().speculative_config
+        self.fia_sparse_mode = 0
+        self.use_spec_decode = False
+        # We need to set the sparse_mode of the fused_infer_attention op to 3
+        # in the spec decoding scenario in order to pass in the attention mask.
+        if speculative_config is not None:
+            self.fia_sparse_mode = 3
+            self.use_spec_decode = True
+            self.spec_token_num = speculative_config.num_speculative_tokens
+            assert self.spec_token_num > 0
+
         if additional_config:
             self.enable_graph_mode = additional_config.get(
                 "enable_graph_mode", False)
@@ -628,9 +642,32 @@ def _forward_decode(
             dtype=q.dtype,
             device=q.device)
         if self.running_in_graph:
-            # TorchAir's shape is [bs, num_heads_per_rank, seq_len, dim]
-            q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
-            q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
+            # TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim]
+            if self.use_spec_decode:
+                assert num_tokens % self.spec_token_num == 0
+                q_nope = (
+                    q_nope.view(
+                        num_tokens // self.spec_token_num,
+                        self.spec_token_num,
+                        self.num_heads,
+                        -1,
+                    )
+                    .transpose(1, 2)
+                    .contiguous()
+                )
+                q_pe = (
+                    q_pe.view(
+                        num_tokens // self.spec_token_num,
+                        self.spec_token_num,
+                        self.num_heads,
+                        -1,
+                    )
+                    .transpose(1, 2)
+                    .contiguous()
+                )
+            else:
+                q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
+                q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
             # shape of knope/k_pe for npu graph mode should be:
             # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
             block_size = kv_c_and_k_pe_cache[0].shape[1]
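The reshape can be checked in plain torch: `num_tokens` flattened query tokens become TorchAir's BNSD layout with `q_seq_len == spec_token_num`. Dimensions below are arbitrary:

```python
import torch

num_heads, head_dim, spec_token_num = 8, 64, 4
num_tokens = 3 * spec_token_num  # 3 requests x 4 scheduled tokens each

q_nope = torch.randn(num_tokens, num_heads, head_dim)
assert num_tokens % spec_token_num == 0
# [num_tokens, N, D] -> [B, S, N, D] -> transpose to BNSD [B, N, S, D]
bnsd = (q_nope.view(num_tokens // spec_token_num, spec_token_num,
                    num_heads, -1).transpose(1, 2).contiguous())
print(bnsd.shape)  # torch.Size([3, 8, 4, 64])
```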
@@ -648,7 +685,8 @@ def _forward_decode(
                 num_heads=self.num_heads,
                 num_key_value_heads=self.num_kv_heads,
                 input_layout="BNSD",
-                atten_mask=attn_metadata.attn_mask,
+                atten_mask=attn_metadata.decode.attn_mask,  # type:ignore
+                sparse_mode=self.fia_sparse_mode,
                 scale=self.scale,
                 antiquant_mode=0,
                 antiquant_scale=None,
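Per the `__init__` comment, `sparse_mode=3` is what makes the fused op honor the supplied mask in the spec-decode path. A plain-torch reference for what masked BNSD decode attention computes; the fused NPU op is assumed to match this up to numerics, and the right-aligned mask layout (draft tokens occupy the last `q_len` key positions) is an assumption:

```python
import torch

bs, heads, q_len, kv_len, dim = 2, 8, 4, 16, 64
q = torch.randn(bs, heads, q_len, dim)
k = torch.randn(bs, heads, kv_len, dim)
v = torch.randn(bs, heads, kv_len, dim)

# Boolean mask, True = masked out: draft token i may attend to all past
# context plus the drafts before it, but not to later drafts.
mask = torch.zeros(q_len, kv_len, dtype=torch.bool)
mask[:, kv_len - q_len:] = torch.triu(
    torch.ones(q_len, q_len, dtype=torch.bool), diagonal=1)

scores = (q @ k.transpose(-1, -2)) / dim**0.5
scores = scores.masked_fill(mask, float("-inf"))
out = scores.softmax(-1) @ v
print(out.shape)  # torch.Size([2, 8, 4, 64])
```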