Skip to content

Commit f0d610a

Browse files
[v1][KVCacheManager] Avoid full cache hit by controlling max_length (#17999)
Signed-off-by: Chen Zhang <zhangch99@outlook.com> Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
1 parent e57e4d6 commit f0d610a

File tree

3 files changed: +36 additions, -39 deletions

tests/v1/core/test_specialized_manager.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,9 @@ def get_sliding_window_manager(sliding_window_spec, block_pool):
1717

1818

1919
def test_sliding_window_possible_cached_prefix():
20+
block_size = 2
2021
sliding_window_spec = SlidingWindowSpec(
21-
block_size=2,
22+
block_size=block_size,
2223
num_kv_heads=1,
2324
head_size=1,
2425
dtype=torch.float32,
@@ -44,7 +45,9 @@ def run_one_case(block_is_cached, expect_length):
4445
i: block_pool.blocks[i + 10]
4546
}
4647

47-
computed_blocks = manager.find_longest_cache_hit(block_hash_list)
48+
computed_blocks = manager.find_longest_cache_hit(
49+
block_hash_list,
50+
len(block_hash_list) * block_size)
4851
assert len(computed_blocks) == expect_length
4952

5053
assert all(block == block_pool.null_block

vllm/v1/core/kv_cache_manager.py

Lines changed: 10 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -146,21 +146,16 @@ def get_computed_blocks(self,
146146
assert self.prefix_cache_stats is not None
147147
self.prefix_cache_stats.requests += 1
148148

149-
if len(block_hashes) * self.block_size == request.num_tokens:
150-
# When prompt length is divisible by the block size and all
151-
# blocks are cached, we need to recompute the last token. This
152-
# have to be achieved by re-computing an entire block because
153-
# allocate_slots() assumes num_computed_tokens is always a
154-
# multiple of the block size. To achieve this, remove the last
155-
# block hash from the block_hashes for find_longest_cache_hit
156-
# This limitation can potentially be removed in the future to
157-
# slightly improve the performance.
158-
last_block_hash = block_hashes.pop()
159-
else:
160-
last_block_hash = None
161-
162-
computed_blocks = (
163-
self.single_type_manager.find_longest_cache_hit(block_hashes))
149+
# NOTE: When all tokens hit the cache, we must recompute the last token
150+
# to obtain logits. Thus, set max_cache_hit_length to prompt_length - 1.
151+
# This can trigger recomputation of an entire block, rather than just
152+
# the single last token, because allocate_slots() requires
153+
# num_computed_tokens to be block-size aligned. Removing this limitation
154+
# could slightly improve performance in the future.
155+
max_cache_hit_length = request.num_tokens - 1
156+
157+
computed_blocks = self.single_type_manager.find_longest_cache_hit(
158+
block_hashes, max_cache_hit_length)
164159
# NOTE(woosuk): Since incomplete blocks are not eligible for
165160
# sharing, `num_computed_tokens` is always a multiple of
166161
# `block_size`.
@@ -171,12 +166,6 @@ def get_computed_blocks(self,
171166
self.prefix_cache_stats.queries += request.num_tokens
172167
self.prefix_cache_stats.hits += num_computed_tokens
173168

174-
if last_block_hash is not None:
175-
# Add back the last block hash if it was removed.
176-
# NOTE: Because block_hashes is cached in req_to_block_hashes,
177-
# we shouldn't modify it directly.
178-
block_hashes.append(last_block_hash)
179-
180169
return KVCacheBlocks(computed_blocks), num_computed_tokens
181170

182171
def allocate_slots(

vllm/v1/core/single_type_kv_cache_manager.py

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -187,17 +187,19 @@ def get_num_common_prefix_blocks(self, request_id: str,
187187
raise NotImplementedError
188188

189189
@abstractmethod
190-
def find_longest_cache_hit(
191-
self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]:
190+
def find_longest_cache_hit(self, block_hashes: list[BlockHashType],
191+
max_length: int) -> list[KVCacheBlock]:
192192
"""
193-
Get the longest cache hit prefix of the blocks. If no cache hit is
194-
found, return an empty list. if eagle is enabled, drop the last matched
195-
block to force recompute the last block to get the required hidden
196-
states for eagle drafting head. Need to be customized for each attention
197-
type.
193+
Get the longest cache hit prefix of the blocks that is not longer than
194+
`max_length`. If no cache hit is found, return an empty list.
195+
If eagle is enabled, drop the last matched block to force recompute the
196+
last block to get the required hidden states for eagle drafting head.
197+
Need to be customized for each attention type.
198198
199199
Args:
200200
block_hashes: The block hashes of the request.
201+
max_length: The maximum length of the cache hit prefix.
202+
201203
Returns:
202204
A list of cached blocks with skipped blocks replaced by null block.
203205
For example, sliding window manager should return a list like
@@ -226,10 +228,12 @@ def remove_skipped_blocks(self, request_id: str,
226228

227229
class FullAttentionManager(SingleTypeKVCacheManager):
228230

229-
def find_longest_cache_hit(
230-
self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]:
231+
def find_longest_cache_hit(self, block_hashes: list[BlockHashType],
232+
max_length: int) -> list[KVCacheBlock]:
231233
computed_blocks: list[KVCacheBlock] = []
232-
for block_hash in block_hashes:
234+
max_num_blocks = max_length // self.block_size
235+
for i in range(max_num_blocks):
236+
block_hash = block_hashes[i]
233237
# block_hashes is a chain of block hashes. If a block hash is not
234238
# in the cached_block_hash_to_id, the following block hashes are
235239
# not computed yet for sure.
@@ -276,19 +280,20 @@ def __init__(self, kv_cache_spec: SlidingWindowSpec, block_pool: BlockPool,
276280
self.sliding_window_contiguous_blocks += 1
277281
self._null_block = block_pool.null_block
278282

279-
def find_longest_cache_hit(
280-
self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]:
283+
def find_longest_cache_hit(self, block_hashes: list[BlockHashType],
284+
max_length: int) -> list[KVCacheBlock]:
281285
# TODO: reduce i by sliding_window_contiguous_blocks when cache miss, to
282-
# optimize the time complexity from O(len(block_hashes)) to
283-
# O(len(block_hashes) / sliding_window_contiguous_blocks +
286+
# optimize the time complexity from O(max_num_blocks) to
287+
# O(max_num_blocks / sliding_window_contiguous_blocks +
284288
# sliding_window_contiguous_blocks),
285289
# which is good for low cache hit rate scenarios.
286-
computed_blocks = [self._null_block] * len(block_hashes)
290+
max_num_blocks = max_length // self.block_size
291+
computed_blocks = [self._null_block] * max_num_blocks
287292
num_contiguous_blocks = 0
288293

289294
match_found = False
290295
# Search from right to left and early stop when a match is found.
291-
for i in range(len(block_hashes) - 1, -1, -1):
296+
for i in range(max_num_blocks - 1, -1, -1):
292297
if cached_block := self.block_pool.get_cached_block(
293298
block_hashes[i]):
294299
computed_blocks[i] = cached_block

Comments (0)