import os
import time
from array import array
- from enum import IntEnum
+ from enum import Enum, IntEnum
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple,
                    Optional, Set, Tuple, Type, TypeVar, Union)

@@ -75,6 +75,12 @@
DUMMY_TOKEN_ID = -1


+ class PhaseType(Enum):
+     PREFILL = 'prefill'
+     PREFIX_PREFILL = 'prefix_prefill'
+     DECODE = 'decode'
+
+
def subtuple(obj: object,
             typename: str,
             to_copy: List[str],
@@ -213,20 +219,40 @@ def _compile_region(self, model, name, module):

    def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device,
                       dtype):
-         prefill_metadata = attn_metadata
-         if prefill_metadata is None or self.prefill_use_fusedsdpa:
+         if (attn_metadata is None
+                 or (self.prefill_use_fusedsdpa \
+                     and attn_metadata.block_list is None)
+                 or not attn_metadata.is_prompt):
            return attn_metadata

+         prefill_metadata = attn_metadata
+
        seq_lens_t = prefill_metadata.seq_lens_tensor
+         context_lens_t = prefill_metadata.context_lens_tensor
+         query_lens_t = seq_lens_t - context_lens_t
+
+         block_list = attn_metadata.block_list
+         max_context_len = (block_list.size(-1) //
+                            batch_size if block_list is not None else 0)
+         max_context_len = max_context_len * self.block_size
+         past_mask = torch.arange(0,
+                                  max_context_len,
+                                  dtype=torch.int32,
+                                  device=device)
+         past_mask = (past_mask.view(1, -1).expand(batch_size, -1).ge(
+             context_lens_t.view(-1, 1)).view(batch_size, 1, -1).expand(
+                 batch_size, seq_len, -1).view(batch_size, 1, seq_len, -1))
+
        len_mask = (torch.arange(0, seq_len, device=device,
                                 dtype=torch.int32).view(1, seq_len).ge(
-                                      seq_lens_t.unsqueeze(-1)).view(
+                                      query_lens_t.unsqueeze(-1)).view(
                                         batch_size, 1, 1, seq_len))
        causal_mask = torch.triu(torch.ones((batch_size, 1, seq_len, seq_len),
                                            device=device,
                                            dtype=torch.bool),
                                 diagonal=1)
        mask = causal_mask.logical_or(len_mask)
+         mask = torch.concat((past_mask, mask), dim=-1)
        attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_(
            mask, -math.inf))
        attn_metadata = prefill_metadata._replace(attn_bias=attn_bias)
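
A minimal standalone sketch of the bias construction above, with toy shapes and hardcoded context lengths in place of the real metadata (in the runner, max_context_len is derived from block_list.size(-1) // batch_size times the block size):

import math

import torch

batch_size, seq_len, max_context_len = 2, 4, 6
context_lens_t = torch.tensor([4, 2])        # tokens already cached per sequence
seq_lens_t = torch.tensor([7, 5])            # cached context + new query tokens
query_lens_t = seq_lens_t - context_lens_t   # new tokens handled by this prefill

# Mask out past positions beyond each sequence's cached context length.
past_mask = (torch.arange(max_context_len, dtype=torch.int32)
             .view(1, -1).expand(batch_size, -1)
             .ge(context_lens_t.view(-1, 1))
             .view(batch_size, 1, 1, -1)
             .expand(batch_size, 1, seq_len, -1))

# Usual causal mask plus a padding mask over the new (query) tokens.
len_mask = (torch.arange(seq_len, dtype=torch.int32)
            .view(1, seq_len).ge(query_lens_t.unsqueeze(-1))
            .view(batch_size, 1, 1, seq_len))
causal_mask = torch.triu(
    torch.ones((batch_size, 1, seq_len, seq_len), dtype=torch.bool),
    diagonal=1)

# The cached-context columns are prepended, so each query token can attend to
# its valid cached positions plus the causal prefix of the new tokens.
mask = torch.concat((past_mask, causal_mask.logical_or(len_mask)), dim=-1)
attn_bias = torch.zeros_like(mask, dtype=torch.float32).masked_fill_(
    mask, -math.inf)
print(attn_bias.shape)  # torch.Size([2, 1, 4, 10])
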
@@ -517,6 +543,11 @@ def __init__(
                                            False, self.max_model_len)
        self.graphed_buckets: Set[Any] = set()
        self._set_gc_threshold()
+         if self.vllm_config.cache_config.enable_prefix_caching:
+             os.environ.setdefault("VLLM_CONTIGUOUS_PA", "False")
+             assert os.environ.get(
+                 "VLLM_CONTIGUOUS_PA",
+                 "").lower() != "true", "Contiguous PA doesn't support APC"
        self.use_contiguous_pa = envs.VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH

        # For multi-step scheduling
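
Note on the guard above: os.environ.setdefault only disables contiguous PA when VLLM_CONTIGUOUS_PA is unset; if it was explicitly set to "true" while automatic prefix caching is enabled, the assertion fails rather than silently overriding the user's setting.
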
@@ -702,6 +733,10 @@ def _prepare_prompt(
                    computed_block_nums) > 0 and self.sliding_window is None:
                # Prefix is not supported with sliding_window
                context_len = len(computed_block_nums) * self.block_size
+                 if context_len == seq_len \
+                         and self.vllm_config.cache_config.enable_prefix_caching:
+                     # Fully cached prompt - compute only last token
+                     context_len = context_len - 1
                prompt_tokens = prompt_tokens[context_len:]
                prefix_block_tables.append(computed_block_nums)
            elif self.scheduler_config.chunked_prefill_enabled:
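
A tiny illustration of the fully cached branch above, with made-up block numbers and block size: when every prompt block is already cached, context_len would equal the prompt length and leave no tokens to run, so the context is shortened by one and the last prompt token is recomputed:

block_size = 4
prompt_tokens = list(range(8))            # the whole prompt, seq_len == 8
computed_block_nums = [10, 11]            # both prompt blocks are cached
context_len = len(computed_block_nums) * block_size   # 8 == seq_len
if context_len == len(prompt_tokens):
    context_len -= 1                      # fully cached: recompute only the last token
print(prompt_tokens[context_len:])        # [7] -> a single query token remains
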
@@ -779,12 +814,33 @@ def _prepare_prompt(
            if lora_id > 0:
                lora_requests.add(seq_group_metadata.lora_request)

-             lora_index_mapping += [lora_id] * (max_prompt_len - context_len)
+             lora_index_mapping += [lora_id] * max_prompt_len
            lora_prompt_mapping.extend(
                [lora_id] *
-                 (max_prompt_len - context_len
+                 (max_prompt_len
                 if seq_group_metadata.sampling_params.prompt_logprobs else 1))

+         if any(context_lens):
+             assert not self.scheduler_config.chunked_prefill_enabled
+             # prefix caching
+
+             max_num_block = max(len(bt) for bt in prefix_block_tables)
+             prefix_block_list = list(
+                 itertools.chain.from_iterable(
+                     bt if len(bt) == max_num_block else bt +
+                     ([_PAD_BLOCK_ID] * (max_num_block - len(bt)))
+                     for bt in prefix_block_tables))
+
+             pad_len = len(prefix_block_list)
+             prefix_block_list = pad_list(prefix_block_list, pad_len,
+                                          _PAD_BLOCK_ID)
+
+             prefix_block_list_tensor = torch.tensor(prefix_block_list,
+                                                     dtype=torch.long,
+                                                     device=self.device)
+         else:
+             prefix_block_list_tensor = None
+
        input_tokens = make_tensor_with_pad(input_tokens,
                                            max_len=max_prompt_len,
                                            pad=0,
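
A toy sketch of how the per-sequence block tables are flattened into the single padded list that becomes block_list (block numbers and the _PAD_BLOCK_ID value are made up); because every sequence contributes max_num_block entries, _set_attn_bias can later recover the per-sequence block count as block_list.size(-1) // batch_size:

import itertools

import torch

_PAD_BLOCK_ID = 0                                # placeholder padding block id
prefix_block_tables = [[3, 4, 5], [7, 8], [9]]   # cached blocks per sequence

max_num_block = max(len(bt) for bt in prefix_block_tables)
prefix_block_list = list(
    itertools.chain.from_iterable(
        bt + [_PAD_BLOCK_ID] * (max_num_block - len(bt))
        for bt in prefix_block_tables))

block_list = torch.tensor(prefix_block_list, dtype=torch.long)
# Each sequence contributes exactly max_num_block entries, so the flat tensor
# can be viewed as (batch_size, max_num_block).
print(block_list.view(len(prefix_block_tables), -1))
# tensor([[3, 4, 5],
#         [7, 8, 0],
#         [9, 0, 0]])
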
@@ -807,18 +863,23 @@ def _prepare_prompt(
                                       dtype=torch.long,
                                       device=self.device)

+         context_lens_tensor = torch.tensor(context_lens,
+                                            dtype=torch.long,
+                                            device=self.device)
+
        block_indices, block_offsets = precompute_indices_and_offsets(
            self.block_size, slot_mapping, True)
        attn_metadata = self.attn_backend.make_metadata(
            is_prompt=True,
-             block_list=None,
+             block_list=prefix_block_list_tensor,
            block_mapping=None,
            block_usage=None,
            block_indices=block_indices,
            block_offsets=block_offsets,
            block_groups=None,
            attn_bias=None,
            seq_lens_tensor=seq_lens_tensor,
+             context_lens_tensor=context_lens_tensor,
            num_prefills=real_num_seqs,
            num_prefill_tokens=sum_query_len,
            num_decode_tokens=0,
@@ -987,6 +1048,7 @@ def _prepare_decode(
            block_groups=block_groups,
            attn_bias=None,
            seq_lens_tensor=None,
+             context_lens_tensor=None,
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=num_decode_tokens,
@@ -1091,7 +1153,7 @@ def prepare_input_tensors(
            # FIXME: We need to adjust selected_token_indices to accommodate
            # for padding
            max_len = input_tokens.size(1)
-             paddings = [max_len - s for s in seq_lens]
+             paddings = [max_len - q for q in query_lens]
            paddings = [0] + paddings[:-1]
            paddings = list(itertools.accumulate(paddings))
            paddings_prompt_logprobs = []
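
A minimal example of the padding change above: with prefix caching, each padded input row holds only the new (query) tokens, so the cumulative offsets used to shift selected_token_indices must come from query_lens rather than seq_lens (values below are made up):

import itertools

max_len = 8                     # padded prompt length, input_tokens.size(1)
query_lens = [5, 3, 8]          # new (uncached) tokens per sequence
paddings = [max_len - q for q in query_lens]
paddings = [0] + paddings[:-1]
paddings = list(itertools.accumulate(paddings))
print(paddings)                 # [0, 3, 8] -> cumulative shift per sequence
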
@@ -1187,9 +1249,17 @@ def trim_attn_metadata(self, metadata: AttentionMetadata) -> object:
        # input_hash(123) != input_hash(321)
        # input_hash("abc") != input_hash("cba")
        attention_metadata = subtuple(metadata, 'TrimmedAttentionMetadata', [
-             'attn_bias', 'seq_lens_tensor', 'block_list', 'block_mapping',
-             'block_usage', 'slot_mapping', 'is_prompt', 'block_indices',
-             'block_offsets', 'block_groups'
+             'attn_bias',
+             'seq_lens_tensor',
+             'context_lens_tensor',
+             'block_list',
+             'block_mapping',
+             'block_usage',
+             'slot_mapping',
+             'is_prompt',
+             'block_indices',
+             'block_offsets',
+             'block_groups',
        ])
        return attention_metadata

@@ -1733,14 +1803,44 @@ def finish_measurements(self):
        from neural_compressor.torch.quantization import finalize_calibration
        finalize_calibration(self.model.model)

-     def _check_config(self, batch_size, seq_len, is_prompt, warmup_mode):
-         cfg = (batch_size, seq_len, is_prompt)
+     def _num_blocks(self, attn_metadata):
+         if attn_metadata.block_list is None:
+             return 0
+         return attn_metadata.block_list.numel()
+
+     def _phase(self, attn_metadata):
+         phase_type: PhaseType
+         is_prompt = attn_metadata.is_prompt
+         is_prefix_prefill = is_prompt and attn_metadata.block_list is not None
+         if is_prompt and is_prefix_prefill:
+             phase_type = PhaseType.PREFIX_PREFILL
+         elif is_prompt and not is_prefix_prefill:
+             phase_type = PhaseType.PREFILL
+         elif not is_prompt:
+             phase_type = PhaseType.DECODE
+         else:
+             raise ValueError("Unrecognized pass type, likely due to malformed "
+                              "attention metadata")
+         return phase_type
+
+     def _check_config(self, batch_size, seq_len, attn_metadata, warmup_mode):
+         is_prefix_caching = self.vllm_config.cache_config.enable_prefix_caching
+         cfg: Optional[tuple] = None
+         assert cfg is None, "Configs changed between 2D and 3D"
+         if is_prefix_caching:
+             phase = self._phase(attn_metadata)
+             num_blocks = self._num_blocks(attn_metadata)
+             cfg = (batch_size, seq_len, num_blocks, phase)
+         else:
+             phase = 'prompt' if attn_metadata.is_prompt else 'decode'
+             cfg = (batch_size, seq_len, phase)
        seen = cfg in self.seen_configs
        self.seen_configs.add(cfg)
        if not seen and not warmup_mode:
-             phase = 'prompt' if is_prompt else 'decode'
-             logger.warning("Configuration: (%s, %s, %s) was not warmed-up!",
-                            phase, batch_size, seq_len)
+             logger.warning("Configuration: %s was not warmed-up!",
+                            (phase.value, batch_size, seq_len,
+                             num_blocks) if is_prefix_caching else
+                            (phase, batch_size, seq_len))

    def create_lora_mask(self, input_tokens: torch.Tensor, lora_ids: List[int],
                         is_prompt: bool):
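
A condensed sketch of the warm-up configuration key introduced above (the bucket_key helper is made up for illustration): with prefix caching enabled, the key also carries the phase and the number of referenced blocks, so a prefix prefill is tracked separately from a plain prefill of the same batch size and sequence length:

from enum import Enum

class PhaseType(Enum):
    PREFILL = 'prefill'
    PREFIX_PREFILL = 'prefix_prefill'
    DECODE = 'decode'

def bucket_key(batch_size, seq_len, is_prompt, num_blocks, prefix_caching):
    # Without prefix caching, keep the original 3-tuple key.
    if not prefix_caching:
        return (batch_size, seq_len, 'prompt' if is_prompt else 'decode')
    # With prefix caching, distinguish prefill / prefix-prefill / decode and
    # record how many blocks the batch references.
    if is_prompt:
        phase = PhaseType.PREFIX_PREFILL if num_blocks else PhaseType.PREFILL
    else:
        phase = PhaseType.DECODE
    return (batch_size, seq_len, num_blocks, phase)

key = bucket_key(2, 128, True, 0, True)
print(key[:3], key[3].value)    # (2, 128, 0) prefill
key = bucket_key(2, 128, True, 16, True)
print(key[:3], key[3].value)    # (2, 128, 16) prefix_prefill
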
@@ -1912,7 +2012,7 @@ def execute_model(
        batch_size = input_tokens.size(0)
        seq_len = self._seq_len(attn_metadata)
        use_graphs = self._use_graphs(batch_size, seq_len, is_prompt)
-         self._check_config(batch_size, seq_len, is_prompt, warmup_mode)
+         self._check_config(batch_size, seq_len, attn_metadata, warmup_mode)

        lora_mask: torch.Tensor = None
        lora_logits_mask: torch.Tensor = None