"""Attention layer with FlashAttention."""
from dataclasses import dataclass
- from typing import Any, Dict, List, Optional, Tuple, Type
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type

import torch
from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache

from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
-                                               AttentionMetadata, AttentionType)
+                                               AttentionMetadata,
+                                               AttentionMetadataBuilder,
+                                               AttentionType)
+ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
+                                            compute_slot_mapping_start_idx,
+                                            is_block_tables_empty)
+ from vllm.sequence import SequenceGroupMetadata
+ from vllm.utils import make_tensor_with_pad
+
+ if TYPE_CHECKING:
+     from vllm.worker.model_runner import (GPUModelRunnerBase,
+                                           ModelInputForGPUBuilder)


class FlashAttentionBackend(AttentionBackend):
@@ -28,6 +39,10 @@ def get_impl_cls() -> Type["FlashAttentionImpl"]:
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        return FlashAttentionMetadata

+     @staticmethod
+     def get_builder_cls() -> Type["FlashAttentionMetadataBuilder"]:
+         return FlashAttentionMetadataBuilder
+
    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
@@ -184,6 +199,170 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]:
        return self._cached_decode_metadata


+ class FlashAttentionMetadataBuilder(
+         AttentionMetadataBuilder[FlashAttentionMetadata]):
+
+     def __init__(self, input_builder: "ModelInputForGPUBuilder"):
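+         # Per-batch builder state, accumulated across add_seq_group() calls
+         # and turned into device tensors in build().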
+         self.slot_mapping: List[int] = []
+         self.prefill_seq_lens: List[int] = []
+         self.context_lens: List[int] = []
+         self.block_tables: List[List[int]] = []
+         self.curr_seq_lens: List[int] = []
+         self.num_prefills = 0
+         self.num_prefill_tokens = 0
+         self.num_decode_tokens = 0
+
+         self.sliding_window = input_builder.sliding_window
+         self.block_size = input_builder.block_size
+         self.use_v2_block_manager = (
+             input_builder.scheduler_config.use_v2_block_manager)
+
+     def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
+                       token_lens: List[int], seq_lens: List[int],
+                       curr_seq_lens: List[int], query_lens: List[int],
+                       context_lens: List[int],
+                       curr_sliding_window_blocks: List[int],
+                       prefix_cache_hit: bool, chunked_prefill_enabled: bool):
+         """Add a sequence group to the metadata. Specifically update/append
+         1. context length.
+         2. block table.
+         3. slot mapping.
+         """
+         is_prompt = seq_group_metadata.is_prompt
+         block_tables = seq_group_metadata.block_tables
+
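+         # Iterate over the sequences in this group; the per-sequence lists
+         # passed in are aligned with seq_group_metadata.seq_data.keys().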
+         for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
+              curr_sliding_window_block) in zip(
+                  seq_group_metadata.seq_data.keys(), token_lens, seq_lens,
+                  curr_seq_lens, query_lens, context_lens,
+                  curr_sliding_window_blocks):
+             self.context_lens.append(context_len)
+
+             if is_prompt:
+                 self.num_prefills += 1
+                 self.num_prefill_tokens += token_len
+                 self.prefill_seq_lens.append(seq_len)
+             else:
+                 assert query_len == 1, (
+                     "seq_len: {}, context_len: {}, query_len: {}".format(
+                         seq_len, context_len, query_len))
+                 self.num_decode_tokens += query_len
+                 self.curr_seq_lens.append(curr_seq_len)
+
+             # Compute block table.
+             # TODO(sang): Combine chunked prefill and prefix caching by
+             # only allowing multiple of block_size chunk size.
+             # NOTE: This only works for oooooooxxx style attention.
+             block_table = []
+             if prefix_cache_hit:
+                 # NOTE(woosuk): For flash-attn, the block table should
+                 # include the entries for the incoming prefill tokens.
+                 block_table = block_tables[seq_id]
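+             # For decode or chunked prefill, keep only the tail of the block
+             # table that falls inside the current sliding window
+             # (curr_sliding_window_block == 0 keeps the whole table).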
+             elif ((chunked_prefill_enabled or not is_prompt)
+                   and block_tables is not None):
+                 block_table = block_tables[seq_id][-curr_sliding_window_block:]
+             self.block_tables.append(block_table)
+
+             # Compute slot mapping.
+             is_profile_run = is_block_tables_empty(block_tables)
+             start_idx = compute_slot_mapping_start_idx(
+                 is_prompt, query_len, context_len, self.sliding_window,
+                 self.use_v2_block_manager)
+             compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
+                                  seq_len, context_len, start_idx,
+                                  self.block_size,
+                                  seq_group_metadata.block_tables)
+
+     def build(self, runner: "GPUModelRunnerBase", seq_lens, query_lens,
+               cuda_graph_pad_size: int, batch_size: int):
+         """Build attention metadata with on-device tensors."""
+         device = runner.device
+         use_captured_graph = cuda_graph_pad_size != -1
+
+         logits_soft_cap = getattr(runner.model_config.hf_config,
+                                   "attn_logit_softcapping", None)
+         if logits_soft_cap is not None:
+             raise ValueError(
+                 "Please use Flashinfer backend for models with logits_soft_cap"
+                 " (i.e., Gemma-2). Otherwise, the output might be wrong."
+                 " Set Flashinfer backend by "
+                 "export VLLM_ATTENTION_BACKEND=FLASHINFER.")
+
+         max_query_len = max(query_lens)
+         max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
+         max_decode_seq_len = max(self.curr_seq_lens, default=0)
+         num_decode_tokens = self.num_decode_tokens
+
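+         # CUDA graphs are captured with a fixed batch size, so pad the decode
+         # entries out to that size; padded tokens get PAD_SLOT_ID in the slot
+         # mapping.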
+         if use_captured_graph:
+             self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
+             self.block_tables.extend([] * cuda_graph_pad_size)
+             num_decode_tokens = batch_size + cuda_graph_pad_size
+
+             # The shape of graph_block_tables is
+             # [max batch size, max context len // block size].
+             input_block_tables = runner.graph_block_tables[:batch_size]
+             for i, block_table in enumerate(self.block_tables):
+                 if block_table:
+                     input_block_tables[i, :len(block_table)] = block_table
+             block_tables = torch.tensor(input_block_tables, device=device)
+         else:
+             max_block_table_len = max(
+                 len(block_table) for block_table in self.block_tables)
+             block_tables = make_tensor_with_pad(
+                 self.block_tables,
+                 max_len=max_block_table_len,
+                 pad=0,
+                 dtype=torch.int,
+                 device=device,
+             )
+         assert max_query_len > 0, ("query_lens: {}".format(query_lens))
+
+         context_lens_tensor = torch.tensor(self.context_lens,
+                                            dtype=torch.int,
+                                            device=device)
+         seq_lens_tensor = torch.tensor(seq_lens,
+                                        dtype=torch.int,
+                                        device=device)
+         query_lens_tensor = torch.tensor(query_lens,
+                                          dtype=torch.long,
+                                          device=device)
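+         # Cumulative start offsets (prefix sums) over query and sequence
+         # lengths, as expected by the varlen flash-attention kernels.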
+         query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
+                                       dtype=torch.int32,
+                                       device=device)
+         seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
+                                     dtype=torch.int32,
+                                     device=device)
+         torch.cumsum(seq_lens_tensor,
+                      dim=0,
+                      dtype=seq_start_loc.dtype,
+                      out=seq_start_loc[1:])
+         torch.cumsum(query_lens_tensor,
+                      dim=0,
+                      dtype=query_start_loc.dtype,
+                      out=query_start_loc[1:])
+
+         slot_mapping_tensor = torch.tensor(self.slot_mapping,
+                                            dtype=torch.long,
+                                            device=device)
+
+         return FlashAttentionMetadata(
+             num_prefills=self.num_prefills,
+             slot_mapping=slot_mapping_tensor,
+             num_prefill_tokens=self.num_prefill_tokens,
+             num_decode_tokens=num_decode_tokens,
+             seq_lens=seq_lens,
+             seq_lens_tensor=seq_lens_tensor,
+             max_query_len=max_query_len,
+             max_prefill_seq_len=max_prefill_seq_len,
+             max_decode_seq_len=max_decode_seq_len,
+             query_start_loc=query_start_loc,
+             seq_start_loc=seq_start_loc,
+             context_lens_tensor=context_lens_tensor,
+             block_tables=block_tables,
+             use_cuda_graph=use_captured_graph,
+         )
+
+
class FlashAttentionImpl(AttentionImpl):
    """
    If the input tensors contain prompt tokens, the layout is as follows: