@@ -187,17 +187,19 @@ def get_num_common_prefix_blocks(self, request_id: str,
187
187
raise NotImplementedError
188
188
189
189
@abstractmethod
190
- def find_longest_cache_hit (
191
- self , block_hashes : list [ BlockHashType ] ) -> list [KVCacheBlock ]:
190
+ def find_longest_cache_hit (self , block_hashes : list [ BlockHashType ],
191
+ max_length : int ) -> list [KVCacheBlock ]:
192
192
"""
193
- Get the longest cache hit prefix of the blocks. If no cache hit is
194
- found, return an empty list. if eagle is enabled, drop the last matched
195
- block to force recompute the last block to get the required hidden
196
- states for eagle drafting head. Need to be customized for each attention
197
- type.
193
+ Get the longest cache hit prefix of the blocks that is not longer than
194
+ `max_length`. If no cache hit is found, return an empty list.
195
+ If eagle is enabled, drop the last matched block to force recompute the
196
+ last block to get the required hidden states for eagle drafting head.
197
+ Need to be customized for each attention type.
198
198
199
199
Args:
200
200
block_hashes: The block hashes of the request.
201
+ max_length: The maximum length of the cache hit prefix.
202
+
201
203
Returns:
202
204
A list of cached blocks with skipped blocks replaced by null block.
203
205
For example, sliding window manager should return a list like
@@ -226,10 +228,12 @@ def remove_skipped_blocks(self, request_id: str,
226
228
227
229
class FullAttentionManager (SingleTypeKVCacheManager ):
228
230
229
- def find_longest_cache_hit (
230
- self , block_hashes : list [ BlockHashType ] ) -> list [KVCacheBlock ]:
231
+ def find_longest_cache_hit (self , block_hashes : list [ BlockHashType ],
232
+ max_length : int ) -> list [KVCacheBlock ]:
231
233
computed_blocks : list [KVCacheBlock ] = []
232
- for block_hash in block_hashes :
234
+ max_num_blocks = max_length // self .block_size
235
+ for i in range (max_num_blocks ):
236
+ block_hash = block_hashes [i ]
233
237
# block_hashes is a chain of block hashes. If a block hash is not
234
238
# in the cached_block_hash_to_id, the following block hashes are
235
239
# not computed yet for sure.
@@ -276,19 +280,20 @@ def __init__(self, kv_cache_spec: SlidingWindowSpec, block_pool: BlockPool,
276
280
self .sliding_window_contiguous_blocks += 1
277
281
self ._null_block = block_pool .null_block
278
282
279
- def find_longest_cache_hit (
280
- self , block_hashes : list [ BlockHashType ] ) -> list [KVCacheBlock ]:
283
+ def find_longest_cache_hit (self , block_hashes : list [ BlockHashType ],
284
+ max_length : int ) -> list [KVCacheBlock ]:
281
285
# TODO: reduce i by sliding_window_contiguous_blocks when cache miss, to
282
- # optimize the time complexity from O(len(block_hashes) ) to
283
- # O(len(block_hashes) / sliding_window_contiguous_blocks +
286
+ # optimize the time complexity from O(max_num_blocks ) to
287
+ # O(max_num_blocks / sliding_window_contiguous_blocks +
284
288
# sliding_window_contiguous_blocks),
285
289
# which is good for low cache hit rate scenarios.
286
- computed_blocks = [self ._null_block ] * len (block_hashes )
290
+ max_num_blocks = max_length // self .block_size
291
+ computed_blocks = [self ._null_block ] * max_num_blocks
287
292
num_contiguous_blocks = 0
288
293
289
294
match_found = False
290
295
# Search from right to left and early stop when a match is found.
291
- for i in range (len ( block_hashes ) - 1 , - 1 , - 1 ):
296
+ for i in range (max_num_blocks - 1 , - 1 , - 1 ):
292
297
if cached_block := self .block_pool .get_cached_block (
293
298
block_hashes [i ]):
294
299
computed_blocks [i ] = cached_block
0 commit comments