16
16
17
17
from vllm_ascend .attention .attention_v1 import AscendAttentionState
18
18
from vllm_ascend .ops .attention import vanilla_chunked_prefill_mla
19
- from vllm_ascend .worker .model_runner_v1 import NPUModelRunner
20
19
21
20
if TYPE_CHECKING :
22
21
from vllm .v1 .core .sched .output import SchedulerOutput
23
22
from vllm .v1 .worker .gpu_input_batch import InputBatch
24
23
25
24
25
+ @dataclass
26
+ class CommonAttentionMetadata :
27
+ """
28
+ Attention metadata attributes that can be shared by layers in different KV
29
+ cache groups and thus having different block table.
30
+ """
31
+
32
+ query_start_loc : torch .Tensor
33
+ """(batch_size + 1,), the start location of each request in query Tensor"""
34
+ seq_lens : torch .Tensor
35
+ """(batch_size,), the length of each request including both computed tokens
36
+ and newly scheduled tokens"""
37
+
38
+
26
39
class AscendMLABackend (AttentionBackend ):
27
40
28
41
accept_output_buffer : bool = True
@@ -57,6 +70,7 @@ class AscendMLAPrefillMetadata:
57
70
seq_lens : list [int ]
58
71
context_lens : torch .Tensor
59
72
input_positions : torch .Tensor
73
+ query_start_loc : torch .Tensor
60
74
block_table : torch .Tensor
61
75
max_query_len : int
62
76
max_seq_lens : int
@@ -90,6 +104,9 @@ class AscendMLAMetadata:
90
104
91
105
num_actual_tokens : int # Number of tokens excluding padding.
92
106
slot_mapping : torch .Tensor
107
+ query_start_loc : torch .Tensor
108
+ seq_lens : torch .Tensor
109
+ block_tables : torch .Tensor
93
110
94
111
# New for MLA (compared to FlashAttention)
95
112
# For handling prefill decode split
@@ -130,7 +147,7 @@ class AscendMLAMetadataBuilder:
130
147
131
148
# _attn_mask_builder = None
132
149
def __init__ (self ,
133
- runner : "NPUModelRunner" ,
150
+ runner ,
134
151
metadata_cls : Optional [AscendMLAMetadata ] = None ):
135
152
self .metadata_cls : Optional [AscendMLAMetadata ] = metadata_cls \
136
153
if metadata_cls is not None else AscendMLAMetadata # type: ignore
@@ -230,6 +247,7 @@ def build(self,
230
247
num_reqs : int ,
231
248
num_actual_tokens : int ,
232
249
max_query_len : int ,
250
+ common_attn_metadata : CommonAttentionMetadata ,
233
251
common_prefix_len : Optional [int ] = None ,
234
252
graph_pad_size : int = - 1 ) -> AscendMLAMetadata :
235
253
assert self ._num_decodes + self ._num_prefills == num_reqs
@@ -239,10 +257,8 @@ def build(self,
239
257
# it blocks on all previous kernels.
240
258
device = self .runner .device
241
259
242
- block_table = self .runner .input_batch .block_table [0 ].get_device_tensor (
243
- )
244
- block_table [:num_reqs , :self .runner .max_num_blocks_per_req ] = (
245
- block_table [:num_reqs ])
260
+ block_table = (self .runner .input_batch .block_table [0 ].
261
+ get_device_tensor ()[:num_reqs ])
246
262
slot_mapping = self .runner .slot_mapping_cpu [:num_actual_tokens ].to (
247
263
device , non_blocking = True )
248
264
input_positions = self .runner .positions_cpu [:num_actual_tokens ].to (
@@ -254,13 +270,17 @@ def build(self,
254
270
seq_lens = seq_lens_cpu
255
271
max_query_len = query_lens .max ().item ()
256
272
max_seq_lens = seq_lens .max ().item ()
273
+ query_start_loc = None
257
274
258
275
prefill_metadata = None
259
276
if self ._num_prefills > 0 :
260
277
reqs_start = self ._num_decodes # prefill_start
261
278
tokens_start = self ._num_decode_tokens
262
279
max_query_len = query_lens [tokens_start :].max ().item ()
263
280
max_seq_lens = seq_lens [tokens_start :].max ().item ()
281
+ query_start_loc = common_attn_metadata .query_start_loc
282
+ prefill_query_start_loc = query_start_loc [
283
+ reqs_start :] - query_start_loc [reqs_start ]
264
284
265
285
prefill_metadata = AscendMLAPrefillMetadata (
266
286
attn_mask = self .runner .attn_mask ,
@@ -271,6 +291,7 @@ def build(self,
271
291
block_table = block_table [reqs_start :, ...],
272
292
max_query_len = max_query_len ,
273
293
max_seq_lens = max_seq_lens ,
294
+ query_start_loc = prefill_query_start_loc ,
274
295
)
275
296
276
297
decode_metadata = None
@@ -327,6 +348,9 @@ def build(self,
327
348
attn_state = self .runner .attn_state ,
328
349
prefill = prefill_metadata ,
329
350
decode = decode_metadata ,
351
+ query_start_loc = query_start_loc ,
352
+ block_tables = block_table ,
353
+ seq_lens = seq_lens ,
330
354
)
331
355
332
356
0 commit comments