Commit fa99f89

[Core] Support prefix caching and chunked prefill in v0/v1 (#782)
### What this PR does / why we need it?

Support prefix caching and chunked prefill in v0/v1.

---------

Signed-off-by: rjg-lyh <1318825571@qq.com>
1 parent 324f819 commit fa99f89

6 files changed: +156 −32 lines

vllm_ascend/attention/attention.py

Lines changed: 103 additions & 10 deletions
@@ -260,6 +260,8 @@ class AscendMetadata(AttentionMetadata):
     # requests only.
     max_decode_seq_len: int

+    chunked_prefill_enabled: bool
+
     # (batch_size, max_blocks_per_seq).
     # Block addresses per sequence. (Seq id -> list of physical block)
     block_tables: Optional[torch.Tensor]
@@ -271,6 +273,9 @@ class AscendMetadata(AttentionMetadata):
     # the computed tokens + new tokens None if it is a decoding.
     seq_lens: Optional[List[int]] = None

+    # The query lengths of the input sequences
+    query_lens: Optional[List[int]] = None
+
     # Maximum query length in the batch. None for decoding.
     max_query_len: Optional[int] = None

@@ -290,8 +295,15 @@ class AscendMetadata(AttentionMetadata):
     # Number of tokens input to encoder
     num_encoder_tokens: Optional[int] = None

+    # Mask for normal situation
     attn_mask: Optional[torch.Tensor] = None

+    # Mask for prefix caching
+    compress_mask: Optional[torch.Tensor] = None
+
+    # Mask for chunked prefill
+    chunk_mask: Optional[torch.Tensor] = None
+
     # Cross-attention memory-mapping data structures: slot mapping
     # and block tables
     cross_slot_mapping: Optional[torch.Tensor] = None
@@ -315,6 +327,8 @@ def prefill_metadata(self) -> Optional["AscendMetadata"]:
                         self.slot_mapping[:self.num_prefill_tokens])
         seq_lens = (None if self.seq_lens is None else
                     self.seq_lens[:self.num_prefills])
+        query_lens = (None if self.query_lens is None else
+                      self.query_lens[:self.num_prefills])
         block_tables = (None if self.block_tables is None else
                         self.block_tables[:self.num_prefills])

@@ -329,9 +343,11 @@ def prefill_metadata(self) -> Optional["AscendMetadata"]:
             slot_mapping=slot_mapping,
             seq_lens=seq_lens,
             seq_lens_tensor=seq_lens_tensor,
+            query_lens=query_lens,
             max_query_len=self.max_query_len,
             max_prefill_seq_len=self.max_prefill_seq_len,
             max_decode_seq_len=0,
+            chunked_prefill_enabled=self.chunked_prefill_enabled,
             block_tables=block_tables,
             # Begin encoder & cross attn fields below...
             encoder_seq_lens=self.encoder_seq_lens,
@@ -359,6 +375,8 @@ def decode_metadata(self) -> Optional["AscendMetadata"]:
                         self.slot_mapping[self.num_prefill_tokens:])
         seq_lens = (None if self.seq_lens is None else
                     self.seq_lens[self.num_prefills:])
+        query_lens = (None if self.query_lens is None else
+                      self.query_lens[self.num_prefills:])
         block_tables = (None if self.block_tables is None else
                         self.block_tables[self.num_prefills:])
         seq_lens_tensor = (None if self.seq_lens_tensor is None else
@@ -371,9 +389,11 @@ def decode_metadata(self) -> Optional["AscendMetadata"]:
             slot_mapping=slot_mapping,
             seq_lens=seq_lens,
             seq_lens_tensor=seq_lens_tensor,
+            query_lens=query_lens,
             max_query_len=self.max_query_len,
             max_prefill_seq_len=0,
             max_decode_seq_len=self.max_decode_seq_len,
+            chunked_prefill_enabled=self.chunked_prefill_enabled,
             block_tables=block_tables,
             # Begin encoder & cross attn fields below...
             encoder_seq_lens=self.encoder_seq_lens,
@@ -482,6 +502,8 @@ def __init__(self, input_builder: "ModelInputForNPUBuilder"):
         self.block_size = input_builder.block_size

         self.attn_mask = None
+        self.compress_mask = None
+        self.chunk_mask = None
         if AscendMetadataBuilder._attn_mask_builder is None:
             AscendMetadataBuilder._attn_mask_builder = AttentionMaskBuilder.initialize_from_len(
                 128, self.input_builder.runner.model_config.dtype)
@@ -590,11 +612,13 @@ def build(
             self.input_builder.chunked_prefill_enabled)

         device = self.runner.device
+        dtype = self.runner.model_config.dtype
         use_npu_graph = graph_pad_size != -1

         max_query_len = max(query_lens)
         max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
         max_decode_seq_len = max(self.curr_seq_lens, default=0)
+        max_seq_len = max(max_prefill_seq_len, max_decode_seq_len)
         num_decode_tokens = self.num_decode_tokens

         if self.num_prefills == 0 and use_npu_graph:
@@ -612,12 +636,29 @@ def build(
             )

         if self.num_prefills > 0:
-            self.attn_mask = AscendMetadataBuilder._attn_mask_builder.get_attn_mask(  # type: ignore
-                max_prefill_seq_len,
-                self.input_builder.runner.model_config.dtype,
-                self.input_builder.runner.device)
+            if block_tables is None or block_tables.numel() == 0:
+                # normal mask
+                self.attn_mask = AscendMetadataBuilder._attn_mask_builder.get_attn_mask(  # type: ignore
+                    max_prefill_seq_len, dtype, device)
+            elif self.num_decode_tokens == 0 and not self.input_builder.chunked_prefill_enabled:
+                # compress mask for prefix cache
+                self.compress_mask = AscendMetadataBuilder._attn_mask_builder.get_attn_mask(  # type: ignore
+                    128, dtype, device)
+            else:
+                # chunk_mask for chunk prefill
+                attn_mask = AscendMetadataBuilder._attn_mask_builder.get_attn_mask(  # type: ignore
+                    max_seq_len, dtype, device)
+                if attn_mask.numel() > 1 and attn_mask[0][1] > 0:
+                    attn_mask *= -10000
+                chunk_mask_list = []
+                for i, seq_len in enumerate(seq_lens):
+                    context_len = self.context_lens[i]
+                    chunk_mask_list.append(attn_mask[context_len:seq_len])
+                self.chunk_mask = torch.cat(chunk_mask_list, 0)
         else:
             self.attn_mask = None
+            self.compress_mask = None
+            self.chunk_mask = None

         assert max_query_len > 0, "query_lens: {}".format(query_lens)

@@ -641,11 +682,15 @@ def build(
             multi_modal_placeholder_index_maps=placeholder_index_maps,
             enable_kv_scales_calculation=True,
             seq_lens_tensor=seq_lens_tensor,
+            query_lens=query_lens,
             max_query_len=max_query_len,
             max_prefill_seq_len=max_prefill_seq_len,
             max_decode_seq_len=max_decode_seq_len,
             block_tables=block_tables,
             attn_mask=self.attn_mask,
+            compress_mask=self.compress_mask,
+            chunk_mask=self.chunk_mask,
+            chunked_prefill_enabled=self.input_builder.chunked_prefill_enabled,
         )


@@ -681,6 +726,7 @@ def __init__(
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
         self.seq_len_cpu_tensor = None
+        self.query_len_cpu_tensor = None
         self.key_cache = None
         self.value_cache = None

@@ -769,7 +815,7 @@ def forward(
                                              slot_indices=slots)

         if attn_metadata.num_prefills > 0:
-
+            # Prefix cache disabled and chunk prefill disabled or no prefix cache hit
             if (attn_metadata.block_tables is None
                     or attn_metadata.block_tables.numel() == 0):
                 if attn_type == AttentionType.ENCODER_ONLY:
@@ -816,13 +862,60 @@ def forward(
                         num_heads=self.num_heads,
                         num_kv_heads=self.num_kv_heads,
                         out=output)
+            # Prefix cache only and cache hit
+            elif attn_metadata.num_decode_tokens == 0 and not attn_metadata.chunked_prefill_enabled:
+                assert kv_cache is not None
+                assert attn_metadata.prefill_metadata is not None
+                self.seq_lens_tensor_cpu = torch.from_numpy(
+                    np.array(
+                        attn_metadata.prefill_metadata.seq_lens).astype(
+                            np.int32))
+                self.query_lens_tensor_cpu = torch.from_numpy(
+                    np.array(
+                        attn_metadata.prefill_metadata.query_lens).astype(
+                            np.int32))
+                block_tables = attn_metadata.prefill_metadata.block_tables
+                assert attn_metadata.compress_mask is not None
+                compress_mask = attn_metadata.compress_mask
+                torch_npu._npu_flash_attention_qlens(
+                    query=query,
+                    key_cache=self.key_cache,
+                    value_cache=self.value_cache,
+                    block_table=block_tables,
+                    mask=compress_mask,
+                    seq_len=self.query_lens_tensor_cpu,
+                    context_lens=self.seq_lens_tensor_cpu,
+                    num_kv_heads=self.num_kv_heads,
+                    num_heads=self.num_heads,
+                    scale_value=self.scale,
+                    out=output)
+            # Splitfuse
             else:
-                # TODO: Will support prefix cache and chunked prefill soon.
-                raise RuntimeError(
-                    "Prefix cache and chunked prefill are currently not supported."
-                )
-        elif attn_metadata.decode_metadata:
+                assert kv_cache is not None
+                self.seq_lens_tensor_cpu = torch.from_numpy(
+                    np.array(attn_metadata.seq_lens).astype(np.int32))
+                self.query_lens_tensor_cpu = torch.from_numpy(
+                    np.array(attn_metadata.query_lens).astype(np.int32))
+                block_tables = attn_metadata.block_tables
+                assert attn_metadata.chunk_mask is not None
+                chunk_mask = attn_metadata.chunk_mask
+                torch_npu._npu_paged_attention_splitfuse(
+                    query=query,
+                    key_cache=self.key_cache,
+                    value_cache=self.value_cache,
+                    block_table=block_tables,
+                    context_lens=self.seq_lens_tensor_cpu,
+                    mask=chunk_mask,
+                    seq_len=self.query_lens_tensor_cpu,
+                    num_kv_heads=self.num_kv_heads,
+                    num_heads=self.num_heads,
+                    scale_value=self.scale,
+                    out=output)
+        # Decode only
+        else:
             assert self.key_cache is not None
+            assert self.value_cache is not None
+            assert attn_metadata.decode_metadata is not None
             self.seq_lens_tensor_cpu = torch.from_numpy(
                 np.array(attn_metadata.decode_metadata.seq_lens).astype(
                     np.int32))
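
The key new logic above is the per-request chunk mask: one causal mask of size max_seq_len is materialized, and for each sequence only the rows belonging to the tokens computed in this step (context_len up to seq_len) are kept and concatenated. The sketch below is a standalone illustration of that slicing, not the AttentionMaskBuilder API; the additive -10000 fill value and float16 dtype are assumptions.

import torch

def build_chunk_mask(seq_lens, context_lens, max_seq_len, dtype=torch.float16):
    """Standalone sketch of the chunk-mask slicing used for chunked prefill.

    A causal additive mask of shape (max_seq_len, max_seq_len) is built once;
    for each sequence only the rows of its not-yet-computed tokens
    (context_len .. seq_len) are kept, so every new query token can still
    attend to the whole cached prefix.
    """
    # Future positions (j > i) get a large negative bias; the rest stay 0.
    causal = torch.triu(
        torch.full((max_seq_len, max_seq_len), -10000.0, dtype=dtype),
        diagonal=1)
    chunk_mask_list = []
    for seq_len, context_len in zip(seq_lens, context_lens):
        # Rows for the tokens computed in this scheduling step only.
        chunk_mask_list.append(causal[context_len:seq_len])
    return torch.cat(chunk_mask_list, dim=0)

# One request with 4 cached tokens and 3 new ones, plus one fresh 2-token request.
mask = build_chunk_mask(seq_lens=[7, 2], context_lens=[4, 0], max_seq_len=8)
print(mask.shape)  # torch.Size([5, 8])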

vllm_ascend/attention/attention_v1.py

Lines changed: 22 additions & 6 deletions
@@ -96,9 +96,10 @@ def copy_blocks(


 class AscendAttentionState(Enum):
-    PrefillOnly = 0
-    DecodeOnly = 1
-    ChunkedPrefill = 2
+    PrefillNoCache = 0
+    PrefillCacheHit = 1
+    DecodeOnly = 2
+    ChunkedPrefill = 3


 @dataclass
@@ -264,7 +265,7 @@ def forward(
             # TODO: Add attr (num_prefills, prefill_metadata, decode_metadata) to AscendMetadata
             pass
         # V0-Style scheduler situation.
-        elif attn_metadata.attn_state == AscendAttentionState.PrefillOnly:
+        elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
             assert attn_metadata is not None
             assert attn_metadata.attn_mask is not None
             mask = attn_metadata.attn_mask
@@ -277,16 +278,31 @@ def forward(
                 num_heads=self.num_heads,
                 num_kv_heads=self.num_kv_heads,
                 out=output)
+        elif attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
+            assert attn_metadata is not None
+            assert attn_metadata.attn_mask is not None
+            compress_mask = attn_metadata.attn_mask
+            torch_npu._npu_flash_attention_qlens(
+                query=query,
+                key_cache=self.key_cache,
+                value_cache=self.value_cache,
+                block_table=attn_metadata.block_tables,
+                mask=compress_mask,
+                seq_len=attn_metadata.query_lens,
+                context_lens=attn_metadata.seq_lens,
+                num_kv_heads=self.num_kv_heads,
+                num_heads=self.num_heads,
+                scale_value=self.scale,
+                out=output)
         elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
-            block_tables = attn_metadata.block_tables
             torch_npu._npu_paged_attention(
                 query=query,
                 key_cache=self.key_cache,
                 value_cache=self.value_cache,
                 num_kv_heads=self.num_kv_heads,
                 num_heads=self.num_heads,
                 scale_value=self.scale,
-                block_table=block_tables,
+                block_table=attn_metadata.block_tables,
                 context_lens=attn_metadata.seq_lens,
                 out=output)
         # Normal V1 situation.
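
With the renamed states, the v1 backend dispatches one kernel per scenario: dense flash attention for PrefillNoCache, the qlens (compressed-mask) kernel for PrefillCacheHit, paged attention for DecodeOnly, and the splitfuse path otherwise. The helper below is a hypothetical sketch of how such a state could be derived from the batch shape; it is not the vllm-ascend scheduler code, and its inputs (query_lens, num_computed_tokens) are assumptions.

from enum import Enum
from typing import List

class AscendAttentionState(Enum):
    PrefillNoCache = 0
    PrefillCacheHit = 1
    DecodeOnly = 2
    ChunkedPrefill = 3

def classify_batch(query_lens: List[int],
                   num_computed_tokens: List[int]) -> AscendAttentionState:
    """Hypothetical helper: pick an attention state from the batch shape."""
    if all(q == 1 for q in query_lens):
        # Every request contributes exactly one token -> pure decode.
        return AscendAttentionState.DecodeOnly
    if all(c == 0 for c in num_computed_tokens):
        # All prompts start from scratch -> dense causal prefill.
        return AscendAttentionState.PrefillNoCache
    if all(q > 1 for q in query_lens):
        # Prompt-only batch that reuses cached prefix blocks.
        return AscendAttentionState.PrefillCacheHit
    # Mixed prefills and decodes in one step -> chunked prefill / splitfuse.
    return AscendAttentionState.ChunkedPrefill

print(classify_batch([1, 1], [7, 12]))     # AscendAttentionState.DecodeOnly
print(classify_batch([5, 9], [0, 0]))      # AscendAttentionState.PrefillNoCache
print(classify_batch([3, 6], [128, 256]))  # AscendAttentionState.PrefillCacheHit
print(classify_batch([1, 6], [7, 128]))    # AscendAttentionState.ChunkedPrefill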

vllm_ascend/attention/mla_v1.py

Lines changed: 3 additions & 3 deletions
@@ -417,7 +417,7 @@ def _forward_prefill(

         num_tokens = query.size(0)
         attn_output = None
-        # Here is only 2 possibility of input, ChunkedPrefill or PrefillOnly
+        # Here is only 2 possibility of input, ChunkedPrefill or PrefillNoCache
         if attn_metadata.attn_state == AscendAttentionState.ChunkedPrefill:
             attn_output = torch.empty(num_tokens,
                                       self.num_heads * self.v_head_dim,
@@ -440,7 +440,7 @@ def _forward_prefill(
                 scale=self.scale,
                 alibi_slopes=None,
                 causal=True)
-        elif attn_metadata.attn_state == AscendAttentionState.PrefillOnly:
+        elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
             attn_output = torch.empty(num_tokens,
                                       self.num_heads,
                                       self.padding_head_dim,
@@ -479,7 +479,7 @@ def _forward_prefill(
                 self.padding_head_dim)[:, :, :self.v_head_dim]
         else:
             raise RuntimeError(
-                "Unexpected path reached, AscendMLAImpl should only have PrefillOnly and ChunkedPrefill scenario in forward prefill, please file a bug to vllm-ascend !"
+                "Unexpected path reached, AscendMLAImpl should only have PrefillNoCache and ChunkedPrefill scenario in forward prefill, please file a bug to vllm-ascend !"
             )
         attn_output = attn_output.reshape(
             [num_tokens, self.num_heads * self.v_head_dim])

vllm_ascend/platform.py

Lines changed: 3 additions & 3 deletions
@@ -175,11 +175,11 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         if cache_config:
             if cache_config.block_size is None:
                 cache_config.block_size = 128
-            if envs.VLLM_USE_V1 and cache_config.enable_prefix_caching:
+            if cache_config.enable_prefix_caching and cache_config.block_size != 128:
                 logger.warning(
-                    "Prefix caching is not supported for V1 now, disable prefix caching"
+                    "If prefix caching is enabled, block size must be set to 128."
                 )
-                cache_config.enable_prefix_caching = False
+                cache_config.block_size = 128

         if envs.VLLM_USE_V1:
             # Activate custom ops for v1.
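
The config hook no longer turns prefix caching off on V1; it forces the KV-cache block size to 128, which the prefix-cache kernels expect. A minimal standalone sketch of that adjustment follows, using a plain dataclass in place of vllm's CacheConfig.

import logging
from dataclasses import dataclass
from typing import Optional

logger = logging.getLogger(__name__)

@dataclass
class CacheConfigSketch:
    """Stand-in for vllm's CacheConfig, reduced to the two fields touched here."""
    block_size: Optional[int] = None
    enable_prefix_caching: bool = False

def check_and_update(cache_config: CacheConfigSketch) -> None:
    # Default block size on Ascend.
    if cache_config.block_size is None:
        cache_config.block_size = 128
    # Prefix caching stays enabled; the block size is pinned to 128 instead.
    if cache_config.enable_prefix_caching and cache_config.block_size != 128:
        logger.warning(
            "If prefix caching is enabled, block size must be set to 128.")
        cache_config.block_size = 128

cfg = CacheConfigSketch(block_size=64, enable_prefix_caching=True)
check_and_update(cfg)
print(cfg.block_size)  # 128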

vllm_ascend/worker/model_runner.py

Lines changed: 12 additions & 4 deletions
@@ -693,15 +693,23 @@ def _compute_for_prefix_cache_hit(
         # this may be larger than the sequence length if chunked
         # prefill is enabled.
         prefix_cache_len = len(computed_block_nums) * self.block_size
-        seq_group_metadata.seq_data[inter_data.seq_ids[
-            seq_idx]].update_num_cached_tokens(prefix_cache_len)

-        # The number of so far computed prompt tokens in this sequence.
-        context_len = inter_data.context_lens[seq_idx]
         # The total number of prompt tokens in this sequence.
         # When chunked prefill is enabled, this is the token number of
         # computed chunks + current chunk.
         seq_len = inter_data.seq_lens[seq_idx]
+
+        # When full hit, compute the last block rather than the last token,
+        # due to the requirements of prefix operator.
+        if seq_len <= prefix_cache_len:
+            prefix_cache_len -= self.block_size
+
+        seq_group_metadata.seq_data[inter_data.seq_ids[
+            seq_idx]].update_num_cached_tokens(prefix_cache_len)
+
+        # The number of so far computed prompt tokens in this sequence.
+        context_len = inter_data.context_lens[seq_idx]
+
         if prefix_cache_len <= context_len:
             # We already passed the cache hit region,
             # so do normal computation.
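
The v0 runner change handles the full-hit case: when every prompt token is covered by cached blocks, the last block is recomputed because the prefix attention operator needs a non-empty query. A small standalone sketch of that length adjustment (the helper name is hypothetical; the real logic lives inline in _compute_for_prefix_cache_hit):

def effective_prefix_cache_len(num_computed_blocks: int, block_size: int,
                               seq_len: int) -> int:
    """How many prompt tokens to treat as cached for one sequence.

    On a full hit (the whole prompt is covered by cached blocks), back off by
    one block so the last block is recomputed and the prefix attention
    operator gets a non-empty query.
    """
    prefix_cache_len = num_computed_blocks * block_size
    if seq_len <= prefix_cache_len:
        prefix_cache_len -= block_size
    return prefix_cache_len

# Partial hit: 2 cached blocks of 128 tokens against a 300-token prompt.
print(effective_prefix_cache_len(2, 128, 300))  # 256
# Full hit: a 256-token prompt entirely cached -> recompute the last block.
print(effective_prefix_cache_len(2, 128, 256))  # 128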
