Skip to content

Commit c530f82

Browse files
heheda12345 authored and mawong-amd committed
[v1] Move block management logic from KVCacheManager to SpecializedManager (vllm-project#17474)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
1 parent cf5161a commit c530f82

6 files changed

+269 additions, -155 deletions (lines changed)

tests/v1/core/test_kv_cache_utils.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -539,7 +539,7 @@ def test_allocate_with_lookahead():
539539
max_model_len=100)
540540
blocks = kv_cache_manager.allocate_slots(
541541
request,
542-
num_tokens=3,
542+
num_new_tokens=3,
543543
num_lookahead_tokens=2, # Total required: 3+2=5 tokens
544544
)
545545
assert len(blocks.blocks) == 2 # ceil(5/4)=2 blocks
@@ -550,7 +550,7 @@ def test_allocate_with_lookahead():
550550
# required_blocks = ceil((3 + 2) /4) = 2
551551
blocks = kv_cache_manager.allocate_slots(
552552
request,
553-
num_tokens=3,
553+
num_new_tokens=3,
554554
num_lookahead_tokens=2,
555555
)
556556
assert len(blocks.blocks) == 2
@@ -561,7 +561,7 @@ def test_allocate_with_lookahead():
561561
max_model_len=100)
562562
blocks = kv_cache_manager.allocate_slots(
563563
request,
564-
num_tokens=3,
564+
num_new_tokens=3,
565565
num_lookahead_tokens=4,
566566
)
567567
assert len(blocks.blocks) == 2

tests/v1/core/test_prefix_caching.py

Lines changed: 8 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -299,7 +299,8 @@ def test_decode():
299299
req0.append_output_token_ids(8)
300300
new_blocks = manager.allocate_slots(req0, 4)
301301
assert new_blocks is not None and len(new_blocks.blocks) == 0
302-
assert manager.req_to_blocks[req0.request_id][-1].block_hash is None
302+
assert manager.single_type_manager.req_to_blocks[
303+
req0.request_id][-1].block_hash is None
303304

304305
# Append slots with allocating a new block.
305306
req0.num_computed_tokens = 59
@@ -309,8 +310,10 @@ def test_decode():
309310
req0.append_output_token_ids(7)
310311
new_blocks = manager.allocate_slots(req0, 19)
311312
assert new_blocks is not None and len(new_blocks.blocks) == 1
312-
assert manager.req_to_blocks[req0.request_id][-2].block_hash is not None
313-
assert manager.req_to_blocks[req0.request_id][-1].block_hash is None
313+
assert manager.single_type_manager.req_to_blocks[
314+
req0.request_id][-2].block_hash is not None
315+
assert manager.single_type_manager.req_to_blocks[
316+
req0.request_id][-1].block_hash is None
314317

315318

316319
def test_evict():
@@ -689,15 +692,15 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
689692
assert not computed_blocks.blocks
690693
assert num_computed_tokens == 0
691694
manager.allocate_slots(req0, 48, computed_blocks)
692-
block_part0 = manager.req_to_blocks[req0.request_id]
695+
block_part0 = manager.single_type_manager.req_to_blocks[req0.request_id]
693696

694697
# | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... |
695698
req1 = make_request("1", common_token_ids * 2)
696699
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
697700
assert computed_blocks.blocks == block_part0
698701
assert num_computed_tokens == 3 * 16
699702
manager.allocate_slots(req1, 48, computed_blocks)
700-
block_part1 = manager.req_to_blocks[req1.request_id]
703+
block_part1 = manager.single_type_manager.req_to_blocks[req1.request_id]
701704
# | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) |
702705
# | Req1-5(F)| ... |
703706
manager.free(req1)

tests/v1/core/test_scheduler.py

Lines changed: 8 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -812,10 +812,11 @@ def _assert_right_kv_cache_manager(
812812
# Make sure the request stats are right.
813813
EXPECTED_TOTAL_BLOCKS = num_tokens // block_size
814814
for req_id in req_ids:
815-
blocks = scheduler.kv_cache_manager.req_to_blocks[req_id]
815+
blocks = (scheduler.kv_cache_manager.single_type_manager.
816+
req_to_blocks[req_id])
816817
hashes = scheduler.kv_cache_manager.req_to_block_hashes[req_id]
817-
assert (scheduler.kv_cache_manager.num_cached_block[req_id] ==
818-
EXPECTED_TOTAL_BLOCKS)
818+
assert (scheduler.kv_cache_manager.single_type_manager.
819+
num_cached_block[req_id] == EXPECTED_TOTAL_BLOCKS)
819820
assert len(blocks) == EXPECTED_TOTAL_BLOCKS
820821
assert len(hashes) == EXPECTED_TOTAL_BLOCKS
821822

@@ -1195,9 +1196,11 @@ def assert_scheduler_empty(scheduler: Scheduler):
11951196
assert len(scheduler.encoder_cache_manager.cached) == 0
11961197

11971198
# KVCache Manager.
1198-
assert len(scheduler.kv_cache_manager.req_to_blocks) == 0
1199+
assert len(
1200+
scheduler.kv_cache_manager.single_type_manager.req_to_blocks) == 0
11991201
assert len(scheduler.kv_cache_manager.req_to_block_hashes) == 0
1200-
assert len(scheduler.kv_cache_manager.num_cached_block) == 0
1202+
assert len(
1203+
scheduler.kv_cache_manager.single_type_manager.num_cached_block) == 0
12011204
num_free_blocks = (
12021205
scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks)
12031206
assert num_free_blocks == (

tests/v1/core/test_specialized_manager.py

Lines changed: 18 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,14 @@
88
from vllm.v1.kv_cache_interface import SlidingWindowSpec
99

1010

11+
def get_sliding_window_manager(sliding_window_spec, block_pool):
12+
return SlidingWindowManager(sliding_window_spec,
13+
block_pool,
14+
use_eagle=False,
15+
num_kv_cache_groups=1,
16+
caching_hash_fn=lambda x: x)
17+
18+
1119
def test_sliding_window_possible_cached_prefix():
1220
sliding_window_spec = SlidingWindowSpec(
1321
block_size=2,
@@ -19,9 +27,7 @@ def test_sliding_window_possible_cached_prefix():
1927
)
2028

2129
block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
22-
manager = SlidingWindowManager(sliding_window_spec,
23-
block_pool,
24-
use_eagle=False)
30+
manager = get_sliding_window_manager(sliding_window_spec, block_pool)
2531

2632
def run_one_case(block_is_cached, expect_length):
2733
block_hash_list = [
@@ -81,9 +87,7 @@ def test_sliding_window_remove_skipped_blocks():
8187

8288
block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True)
8389

84-
manager = SlidingWindowManager(sliding_window_spec,
85-
block_pool,
86-
use_eagle=False)
90+
manager = get_sliding_window_manager(sliding_window_spec, block_pool)
8791

8892
null_block_id = block_pool.null_block.block_id
8993

@@ -104,39 +108,35 @@ def assert_block_id(block_table, ids):
104108
1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010
105109
]
106110
block_table = id_to_block_table(original_block_ids)
107-
removed = manager.remove_skipped_blocks(block_table, 0)
108-
assert_block_id(removed, [])
111+
manager.req_to_blocks["test"] = block_table
112+
113+
manager.remove_skipped_blocks("test", 0)
109114
assert_block_id(block_table, original_block_ids)
110115

111116
# 4 tokens are computed. Only token 0 is out of the sliding window. As
112117
# block 1000 also contains token 1 that is in the sliding window, block 1000
113118
# cannot be removed.
114-
removed = manager.remove_skipped_blocks(block_table, 4)
115-
assert_block_id(removed, [])
119+
manager.remove_skipped_blocks("test", 4)
116120
assert_block_id(block_table, original_block_ids)
117121

118122
# 5 tokens are computed. Token 0 & 1 are out of the sliding window.
119123
# Block 1000 can be removed.
120-
removed = manager.remove_skipped_blocks(block_table, 5)
121-
assert_block_id(removed, [original_block_ids[0]])
124+
manager.remove_skipped_blocks("test", 5)
122125
assert_block_id(block_table, [null_block_id] + original_block_ids[1:])
123126

124127
# 6 tokens are computed. Token 0-2 are out of the sliding window.
125128
# Cannot remove new block as the block 1001 is still used by token 3.
126-
removed = manager.remove_skipped_blocks(block_table, 6)
127-
assert_block_id(removed, [])
129+
manager.remove_skipped_blocks("test", 6)
128130
assert_block_id(block_table, [null_block_id] + original_block_ids[1:])
129131

130132
# 7 tokens are computed. Token 0-3 are out of the sliding window.
131133
# Block 1001 can be removed and block 1000 is already removed.
132-
removed = manager.remove_skipped_blocks(block_table, 7)
133-
assert_block_id(removed, [original_block_ids[1]])
134+
manager.remove_skipped_blocks("test", 7)
134135
assert_block_id(block_table, [null_block_id] * 2 + original_block_ids[2:])
135136

136137
# 11 tokens are computed. Token 0-7 are out of the sliding window.
137138
# Block 1002 & 1003 can be removed now. Block 1003 represents a longer
138139
# sequence, and is expected to be evicted earlier than 1002, so the order
139140
# of removed blocks should be [1003, 1002].
140-
removed = manager.remove_skipped_blocks(block_table, 11)
141-
assert_block_id(removed, [original_block_ids[3], original_block_ids[2]])
141+
manager.remove_skipped_blocks("test", 11)
142142
assert_block_id(block_table, [null_block_id] * 4 + original_block_ids[4:])

0 commit comments

Comments
 (0)