@@ -260,6 +260,8 @@ class AscendMetadata(AttentionMetadata):
260
260
# requests only.
261
261
max_decode_seq_len : int
262
262
263
+ chunked_prefill_enabled : bool
264
+
263
265
# (batch_size, max_blocks_per_seq).
264
266
# Block addresses per sequence. (Seq id -> list of physical block)
265
267
block_tables : Optional [torch .Tensor ]
@@ -271,6 +273,9 @@ class AscendMetadata(AttentionMetadata):
271
273
# the computed tokens + new tokens None if it is a decoding.
272
274
seq_lens : Optional [List [int ]] = None
273
275
276
+ # The query lengths of the input sequences
277
+ query_lens : Optional [List [int ]] = None
278
+
274
279
# Maximum query length in the batch. None for decoding.
275
280
max_query_len : Optional [int ] = None
276
281
@@ -290,8 +295,15 @@ class AscendMetadata(AttentionMetadata):
290
295
# Number of tokens input to encoder
291
296
num_encoder_tokens : Optional [int ] = None
292
297
298
+ # Mask for normal situation
293
299
attn_mask : Optional [torch .Tensor ] = None
294
300
301
+ # Mask for prefix caching
302
+ compress_mask : Optional [torch .Tensor ] = None
303
+
304
+ # Mask for chunked prefill
305
+ chunk_mask : Optional [torch .Tensor ] = None
306
+
295
307
# Cross-attention memory-mapping data structures: slot mapping
296
308
# and block tables
297
309
cross_slot_mapping : Optional [torch .Tensor ] = None
@@ -315,6 +327,8 @@ def prefill_metadata(self) -> Optional["AscendMetadata"]:
315
327
self .slot_mapping [:self .num_prefill_tokens ])
316
328
seq_lens = (None if self .seq_lens is None else
317
329
self .seq_lens [:self .num_prefills ])
330
+ query_lens = (None if self .query_lens is None else
331
+ self .query_lens [:self .num_prefills ])
318
332
block_tables = (None if self .block_tables is None else
319
333
self .block_tables [:self .num_prefills ])
320
334
@@ -329,9 +343,11 @@ def prefill_metadata(self) -> Optional["AscendMetadata"]:
329
343
slot_mapping = slot_mapping ,
330
344
seq_lens = seq_lens ,
331
345
seq_lens_tensor = seq_lens_tensor ,
346
+ query_lens = query_lens ,
332
347
max_query_len = self .max_query_len ,
333
348
max_prefill_seq_len = self .max_prefill_seq_len ,
334
349
max_decode_seq_len = 0 ,
350
+ chunked_prefill_enabled = self .chunked_prefill_enabled ,
335
351
block_tables = block_tables ,
336
352
# Begin encoder & cross attn fields below...
337
353
encoder_seq_lens = self .encoder_seq_lens ,
@@ -359,6 +375,8 @@ def decode_metadata(self) -> Optional["AscendMetadata"]:
359
375
self .slot_mapping [self .num_prefill_tokens :])
360
376
seq_lens = (None if self .seq_lens is None else
361
377
self .seq_lens [self .num_prefills :])
378
+ query_lens = (None if self .query_lens is None else
379
+ self .query_lens [self .num_prefills :])
362
380
block_tables = (None if self .block_tables is None else
363
381
self .block_tables [self .num_prefills :])
364
382
seq_lens_tensor = (None if self .seq_lens_tensor is None else
@@ -371,9 +389,11 @@ def decode_metadata(self) -> Optional["AscendMetadata"]:
371
389
slot_mapping = slot_mapping ,
372
390
seq_lens = seq_lens ,
373
391
seq_lens_tensor = seq_lens_tensor ,
392
+ query_lens = query_lens ,
374
393
max_query_len = self .max_query_len ,
375
394
max_prefill_seq_len = 0 ,
376
395
max_decode_seq_len = self .max_decode_seq_len ,
396
+ chunked_prefill_enabled = self .chunked_prefill_enabled ,
377
397
block_tables = block_tables ,
378
398
# Begin encoder & cross attn fields below...
379
399
encoder_seq_lens = self .encoder_seq_lens ,
@@ -482,6 +502,8 @@ def __init__(self, input_builder: "ModelInputForNPUBuilder"):
482
502
self .block_size = input_builder .block_size
483
503
484
504
self .attn_mask = None
505
+ self .compress_mask = None
506
+ self .chunk_mask = None
485
507
if AscendMetadataBuilder ._attn_mask_builder is None :
486
508
AscendMetadataBuilder ._attn_mask_builder = AttentionMaskBuilder .initialize_from_len (
487
509
128 , self .input_builder .runner .model_config .dtype )
@@ -590,11 +612,13 @@ def build(
590
612
self .input_builder .chunked_prefill_enabled )
591
613
592
614
device = self .runner .device
615
+ dtype = self .runner .model_config .dtype
593
616
use_npu_graph = graph_pad_size != - 1
594
617
595
618
max_query_len = max (query_lens )
596
619
max_prefill_seq_len = max (self .prefill_seq_lens , default = 0 )
597
620
max_decode_seq_len = max (self .curr_seq_lens , default = 0 )
621
+ max_seq_len = max (max_prefill_seq_len , max_decode_seq_len )
598
622
num_decode_tokens = self .num_decode_tokens
599
623
600
624
if self .num_prefills == 0 and use_npu_graph :
@@ -612,12 +636,29 @@ def build(
612
636
)
613
637
614
638
if self .num_prefills > 0 :
615
- self .attn_mask = AscendMetadataBuilder ._attn_mask_builder .get_attn_mask ( # type: ignore
616
- max_prefill_seq_len ,
617
- self .input_builder .runner .model_config .dtype ,
618
- self .input_builder .runner .device )
639
+ if block_tables is None or block_tables .numel () == 0 :
640
+ # normal mask
641
+ self .attn_mask = AscendMetadataBuilder ._attn_mask_builder .get_attn_mask ( # type: ignore
642
+ max_prefill_seq_len , dtype , device )
643
+ elif self .num_decode_tokens == 0 and not self .input_builder .chunked_prefill_enabled :
644
+ # compress mask for prefix cache
645
+ self .compress_mask = AscendMetadataBuilder ._attn_mask_builder .get_attn_mask ( # type: ignore
646
+ 128 , dtype , device )
647
+ else :
648
+ # chunk_mask for chunk prefill
649
+ attn_mask = AscendMetadataBuilder ._attn_mask_builder .get_attn_mask ( # type: ignore
650
+ max_seq_len , dtype , device )
651
+ if attn_mask .numel () > 1 and attn_mask [0 ][1 ] > 0 :
652
+ attn_mask *= - 10000
653
+ chunk_mask_list = []
654
+ for i , seq_len in enumerate (seq_lens ):
655
+ context_len = self .context_lens [i ]
656
+ chunk_mask_list .append (attn_mask [context_len :seq_len ])
657
+ self .chunk_mask = torch .cat (chunk_mask_list , 0 )
619
658
else :
620
659
self .attn_mask = None
660
+ self .compress_mask = None
661
+ self .chunk_mask = None
621
662
622
663
assert max_query_len > 0 , "query_lens: {}" .format (query_lens )
623
664
@@ -641,11 +682,15 @@ def build(
641
682
multi_modal_placeholder_index_maps = placeholder_index_maps ,
642
683
enable_kv_scales_calculation = True ,
643
684
seq_lens_tensor = seq_lens_tensor ,
685
+ query_lens = query_lens ,
644
686
max_query_len = max_query_len ,
645
687
max_prefill_seq_len = max_prefill_seq_len ,
646
688
max_decode_seq_len = max_decode_seq_len ,
647
689
block_tables = block_tables ,
648
690
attn_mask = self .attn_mask ,
691
+ compress_mask = self .compress_mask ,
692
+ chunk_mask = self .chunk_mask ,
693
+ chunked_prefill_enabled = self .input_builder .chunked_prefill_enabled ,
649
694
)
650
695
651
696
@@ -681,6 +726,7 @@ def __init__(
681
726
assert self .num_heads % self .num_kv_heads == 0
682
727
self .num_queries_per_kv = self .num_heads // self .num_kv_heads
683
728
self .seq_len_cpu_tensor = None
729
+ self .query_len_cpu_tensor = None
684
730
self .key_cache = None
685
731
self .value_cache = None
686
732
@@ -769,7 +815,7 @@ def forward(
769
815
slot_indices = slots )
770
816
771
817
if attn_metadata .num_prefills > 0 :
772
-
818
+ # Prefix cache disabled and chunk prefill disabled or no prefix cache hit
773
819
if (attn_metadata .block_tables is None
774
820
or attn_metadata .block_tables .numel () == 0 ):
775
821
if attn_type == AttentionType .ENCODER_ONLY :
@@ -816,13 +862,60 @@ def forward(
816
862
num_heads = self .num_heads ,
817
863
num_kv_heads = self .num_kv_heads ,
818
864
out = output )
865
+ # Prefix cache only and cache hit
866
+ elif attn_metadata .num_decode_tokens == 0 and not attn_metadata .chunked_prefill_enabled :
867
+ assert kv_cache is not None
868
+ assert attn_metadata .prefill_metadata is not None
869
+ self .seq_lens_tensor_cpu = torch .from_numpy (
870
+ np .array (
871
+ attn_metadata .prefill_metadata .seq_lens ).astype (
872
+ np .int32 ))
873
+ self .query_lens_tensor_cpu = torch .from_numpy (
874
+ np .array (
875
+ attn_metadata .prefill_metadata .query_lens ).astype (
876
+ np .int32 ))
877
+ block_tables = attn_metadata .prefill_metadata .block_tables
878
+ assert attn_metadata .compress_mask is not None
879
+ compress_mask = attn_metadata .compress_mask
880
+ torch_npu ._npu_flash_attention_qlens (
881
+ query = query ,
882
+ key_cache = self .key_cache ,
883
+ value_cache = self .value_cache ,
884
+ block_table = block_tables ,
885
+ mask = compress_mask ,
886
+ seq_len = self .query_lens_tensor_cpu ,
887
+ context_lens = self .seq_lens_tensor_cpu ,
888
+ num_kv_heads = self .num_kv_heads ,
889
+ num_heads = self .num_heads ,
890
+ scale_value = self .scale ,
891
+ out = output )
892
+ # Splitfuse
819
893
else :
820
- # TODO: Will support prefix cache and chunked prefill soon.
821
- raise RuntimeError (
822
- "Prefix cache and chunked prefill are currently not supported."
823
- )
824
- elif attn_metadata .decode_metadata :
894
+ assert kv_cache is not None
895
+ self .seq_lens_tensor_cpu = torch .from_numpy (
896
+ np .array (attn_metadata .seq_lens ).astype (np .int32 ))
897
+ self .query_lens_tensor_cpu = torch .from_numpy (
898
+ np .array (attn_metadata .query_lens ).astype (np .int32 ))
899
+ block_tables = attn_metadata .block_tables
900
+ assert attn_metadata .chunk_mask is not None
901
+ chunk_mask = attn_metadata .chunk_mask
902
+ torch_npu ._npu_paged_attention_splitfuse (
903
+ query = query ,
904
+ key_cache = self .key_cache ,
905
+ value_cache = self .value_cache ,
906
+ block_table = block_tables ,
907
+ context_lens = self .seq_lens_tensor_cpu ,
908
+ mask = chunk_mask ,
909
+ seq_len = self .query_lens_tensor_cpu ,
910
+ num_kv_heads = self .num_kv_heads ,
911
+ num_heads = self .num_heads ,
912
+ scale_value = self .scale ,
913
+ out = output )
914
+ # Decode only
915
+ else :
825
916
assert self .key_cache is not None
917
+ assert self .value_cache is not None
918
+ assert attn_metadata .decode_metadata is not None
826
919
self .seq_lens_tensor_cpu = torch .from_numpy (
827
920
np .array (attn_metadata .decode_metadata .seq_lens ).astype (
828
921
np .int32 ))
0 commit comments