Commit e60f550

[v1] Support multiple KV cache groups in GPU model runner (#17945)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
1 parent f25e0d1 · commit e60f550

16 files changed: +482 -215 lines

tests/v1/core/test_kv_cache_utils.py (+68 -3)
@@ -19,7 +19,8 @@
                                           hash_request_tokens,
                                           unify_kv_cache_configs)
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
-                                        KVCacheGroupSpec, KVCacheTensor)
+                                        KVCacheGroupSpec, KVCacheTensor,
+                                        SlidingWindowSpec)
 from vllm.v1.metrics.stats import PrefixCacheStats
 from vllm.v1.request import Request

@@ -54,12 +55,14 @@ def new_kv_cache_spec(block_size=16,
                       num_kv_heads=2,
                       head_size=64,
                       dtype=torch.float32,
-                      use_mla=False):
+                      use_mla=False,
+                      sliding_window=None):
     return FullAttentionSpec(block_size=block_size,
                              num_kv_heads=num_kv_heads,
                              head_size=head_size,
                              dtype=dtype,
-                             use_mla=use_mla)
+                             use_mla=use_mla,
+                             sliding_window=sliding_window)


 def test_none_hash():
@@ -471,6 +474,68 @@ def test_unify_kv_cache_configs():
         unify_kv_cache_configs(diff_kv_cache_config)


+def test_merge_kv_cache_spec():
+    same_layer_specs = [
+        new_kv_cache_spec(num_kv_heads=32),
+        new_kv_cache_spec(num_kv_heads=32),
+    ]
+    merged_layer_spec = same_layer_specs[0].merge(same_layer_specs)
+    assert merged_layer_spec.block_size == 16
+    assert merged_layer_spec.num_kv_heads == 32
+    assert merged_layer_spec.head_size == 64
+    assert merged_layer_spec.dtype == torch.float32
+    assert merged_layer_spec.sliding_window is None
+
+    different_layer_specs = [
+        new_kv_cache_spec(num_kv_heads=32),
+        new_kv_cache_spec(num_kv_heads=16),
+    ]
+    with pytest.raises(AssertionError):
+        different_layer_specs[0].merge(different_layer_specs)
+
+    full_spec = new_kv_cache_spec(num_kv_heads=32)
+    different_type_layer_specs = [
+        full_spec,
+        SlidingWindowSpec(
+            block_size=full_spec.block_size,
+            num_kv_heads=full_spec.num_kv_heads,
+            head_size=full_spec.head_size,
+            dtype=full_spec.dtype,
+            use_mla=full_spec.use_mla,
+            sliding_window=1,
+        ),
+    ]
+    with pytest.raises(AssertionError):
+        different_type_layer_specs[0].merge(different_type_layer_specs)
+    with pytest.raises(AssertionError):
+        different_type_layer_specs[1].merge(different_type_layer_specs)
+
+    different_sliding_window_layer_specs = [
+        new_kv_cache_spec(num_kv_heads=32),
+        new_kv_cache_spec(num_kv_heads=32, sliding_window=1),
+        new_kv_cache_spec(num_kv_heads=32, sliding_window=2),
+    ]
+    with pytest.raises(ValueError):
+        different_sliding_window_layer_specs[0].merge(
+            different_sliding_window_layer_specs)
+
+    same_sliding_window_layer_specs = [
+        new_kv_cache_spec(num_kv_heads=32, sliding_window=1),
+        new_kv_cache_spec(num_kv_heads=32, sliding_window=1),
+    ]
+    merged_layer_spec = same_sliding_window_layer_specs[0].merge(
+        same_sliding_window_layer_specs)
+    assert merged_layer_spec.sliding_window == 1
+
+    same_sliding_window_layer_spec_with_none = [
+        new_kv_cache_spec(num_kv_heads=32, sliding_window=1),
+        new_kv_cache_spec(num_kv_heads=32, sliding_window=None),
+    ]
+    merged_layer_spec = same_sliding_window_layer_spec_with_none[0].merge(
+        same_sliding_window_layer_spec_with_none)
+    assert merged_layer_spec.sliding_window == 1
+
+
 @pytest.mark.parametrize(
     ("model_id", "max_model_len", "want_estimated_max_len"), [
         ("Qwen/Qwen1.5-7B", 16385, 16384),

tests/v1/core/test_prefix_caching.py (+18 -18)
@@ -84,7 +84,7 @@ def test_prefill(hash_algo):
     blocks = manager.allocate_slots(req0, 55,
                                     len(computed_blocks.blocks) * 16,
                                     computed_blocks)
-    assert blocks.get_block_ids() == [1, 2, 3, 4]
+    assert blocks.get_block_ids() == [[1, 2, 3, 4]]

     # Check full block metadata
     parent_block_hash = None
@@ -107,13 +107,13 @@ def test_prefill(hash_algo):
     req1 = make_request("1", common_token_ids + unique_token_ids)
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
     assert len(manager.req_to_block_hashes[req1.request_id]) == 3
-    assert computed_blocks.get_block_ids() == [1, 2, 3]
+    assert computed_blocks.get_block_ids() == [[1, 2, 3]]
     assert num_computed_tokens == 3 * 16
     num_new_tokens = 53 - 3 * 16
     blocks = manager.allocate_slots(req1, num_new_tokens,
                                     len(computed_blocks.blocks) * 16,
                                     computed_blocks)
-    assert blocks.get_block_ids() == [5]
+    assert blocks.get_block_ids() == [[5]]
     for block in computed_blocks.blocks:
         assert block.ref_cnt == 2

@@ -141,13 +141,13 @@ def test_prefill(hash_algo):
     req2 = make_request("2", common_token_ids + unique_token_ids)
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
     assert len(manager.req_to_block_hashes[req2.request_id]) == 3
-    assert computed_blocks.get_block_ids() == [1, 2, 3]
+    assert computed_blocks.get_block_ids() == [[1, 2, 3]]
     assert num_computed_tokens == 3 * 16
     num_new_tokens = 53 - 3 * 16
     blocks = manager.allocate_slots(req2, num_new_tokens,
                                     len(computed_blocks.blocks) * 16,
                                     computed_blocks)
-    assert blocks.get_block_ids() == [6]
+    assert blocks.get_block_ids() == [[6]]

     # Although we only have 6 free blocks, we have 8 blocks in
     # the free block queue due to lazy removal.
@@ -171,7 +171,7 @@ def test_prefill(hash_algo):
                                     len(computed_blocks.blocks) * 16,
                                     computed_blocks)
     # This block ID order also checks the eviction order.
-    assert blocks.get_block_ids() == [7, 8, 9, 10, 4, 5, 6, 3, 2, 1]
+    assert blocks.get_block_ids() == [[7, 8, 9, 10, 4, 5, 6, 3, 2, 1]]
     assert manager.block_pool.free_block_queue.num_free_blocks == 0
     assert manager.block_pool.free_block_queue.free_list_head is None
     assert manager.block_pool.free_block_queue.free_list_tail is None
@@ -208,7 +208,7 @@ def test_prefill_plp():
     blocks = manager.allocate_slots(req0, 55,
                                     len(computed_blocks.blocks) * 16,
                                     computed_blocks)
-    assert blocks.get_block_ids() == [1, 2, 3, 4]
+    assert blocks.get_block_ids() == [[1, 2, 3, 4]]
     req0_block_hashes = [b.block_hash for b in blocks.blocks]

     # Check full block metadata
@@ -233,13 +233,13 @@ def test_prefill_plp():
     req1 = make_request("1", common_token_ids + unique_token_ids)
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
     assert len(manager.req_to_block_hashes[req1.request_id]) == 3
-    assert computed_blocks.get_block_ids() == [1, 2, 3]
+    assert computed_blocks.get_block_ids() == [[1, 2, 3]]
     assert num_computed_tokens == 3 * 16
     num_new_tokens = 53 - 3 * 16
     blocks = manager.allocate_slots(req1, num_new_tokens,
                                     len(computed_blocks.blocks) * 16,
                                     computed_blocks)
-    assert blocks.get_block_ids() == [5]
+    assert blocks.get_block_ids() == [[5]]
     for block in computed_blocks.blocks:
         assert block.ref_cnt == 2

@@ -277,11 +277,11 @@ def test_prefill_plp():
     block_ids = blocks.get_block_ids()
     # Duplicate cached blocks have different ids but same hashes vs request #0
     assert [b.block_hash for b in blocks.blocks] == req0_block_hashes
-    assert block_ids != [1, 2, 3, 4]
+    assert block_ids != [[1, 2, 3, 4]]

     # Request #2 block hashes are valid since request #0 hashes are.
     # Check block reference counts.
-    for block_id in block_ids:
+    for block_id in block_ids[0]:
         assert manager.block_pool.blocks[block_id].ref_cnt == 1

     manager.free(req2)
@@ -307,7 +307,7 @@ def test_decode():
     blocks = manager.allocate_slots(req0, 55,
                                     len(computed_blocks.blocks) * 16,
                                     computed_blocks)
-    assert blocks.get_block_ids() == [1, 2, 3, 4]
+    assert blocks.get_block_ids() == [[1, 2, 3, 4]]

     # Append slots without allocating a new block.
     req0.num_computed_tokens = 55
@@ -379,12 +379,12 @@ def test_evict():
     # Touch the first 2 blocks.
     req2 = make_request("2", list(range(2 * 16 + 3)))
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
-    assert computed_blocks.get_block_ids() == [1, 2]
+    assert computed_blocks.get_block_ids() == [[1, 2]]
     assert num_computed_tokens == 2 * 16
     blocks = manager.allocate_slots(req2, 3,
                                     len(computed_blocks.blocks) * 16,
                                     computed_blocks)
-    assert blocks.get_block_ids() == [10]
+    assert blocks.get_block_ids() == [[10]]
     assert manager.block_pool.free_block_queue.num_free_blocks == 7


@@ -625,7 +625,7 @@ def test_mm_prefix_caching():
     blocks = manager.allocate_slots(req0, 59,
                                     len(computed_blocks.blocks) * 16,
                                     computed_blocks)
-    assert blocks.get_block_ids() == [1, 2, 3, 4]
+    assert blocks.get_block_ids() == [[1, 2, 3, 4]]
     req0.num_computed_tokens = 59

     # Append slots without allocating a new block.
@@ -686,7 +686,7 @@ def test_cache_key_salting():
     blocks = manager.allocate_slots(req0, 59,
                                     len(computed_blocks.blocks) * 16,
                                     computed_blocks)
-    assert blocks.get_block_ids() == [1, 2, 3, 4]
+    assert blocks.get_block_ids() == [[1, 2, 3, 4]]
     req0.num_computed_tokens = 59

     # Append slots without allocating a new block.
@@ -797,7 +797,7 @@ def test_reset_prefix_cache():
     all_token_ids = full_block_token_ids + unique_token_ids
     req0 = make_request("0", all_token_ids)
     blocks = manager.allocate_slots(req0, 55)
-    assert blocks.get_block_ids() == [1, 2, 3, 4]
+    assert blocks.get_block_ids() == [[1, 2, 3, 4]]

     unique_token_ids = [4] * 7
     all_token_ids = full_block_token_ids + unique_token_ids
@@ -808,7 +808,7 @@ def test_reset_prefix_cache():
     blocks = manager.allocate_slots(req1, 7,
                                     len(computed_blocks.blocks) * 16,
                                     computed_blocks)
-    assert blocks.get_block_ids() == [5]
+    assert blocks.get_block_ids() == [[5]]

     # Failed to reset prefix cache because some blocks are not freed yet.
     assert not manager.reset_prefix_cache()
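Every change in this file is the same mechanical update: with multiple KV cache groups, get_block_ids() returns list[list[int]], one inner list per group, instead of a flat list[int]. A hypothetical illustration of the new shape (values made up for illustration, not code from this PR):

# Before: one flat list of block IDs for the single KV cache.
block_ids = [1, 2, 3, 4]

# After: list[list[int]], one inner list per KV cache group. With a single
# full-attention group the outer list has exactly one element, which is why
# the updated test iterates block_ids[0] to reach the per-group IDs.
block_ids = [[1, 2, 3, 4]]
for block_id in block_ids[0]:  # blocks belonging to group 0
    print(block_id)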

tests/v1/worker/test_gpu_input_batch.py (+33 -6)
@@ -9,9 +9,11 @@

 from vllm.sampling_params import SamplingParams
 from vllm.utils import is_pin_memory_available, make_tensor_with_pad
+from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
+                                        KVCacheGroupSpec, KVCacheTensor)
 from vllm.v1.sample.metadata import SamplingMetadata
-from vllm.v1.worker.gpu_input_batch import (BlockTable, CachedRequestState,
-                                            InputBatch)
+from vllm.v1.worker.block_table import BlockTable, MultiGroupBlockTable
+from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch

 VOCAB_SIZE = 1024
 NUM_OUTPUT_TOKENS = 20
@@ -22,6 +24,27 @@
 MAX_NUM_PROMPT_TOKENS = 64


+def get_kv_cache_config() -> KVCacheConfig:
+    return KVCacheConfig(
+        num_blocks=10,
+        tensors={
+            "layer.0": KVCacheTensor(size=1024),
+        },
+        kv_cache_groups=[
+            KVCacheGroupSpec(
+                layer_names=["layer.0"],
+                kv_cache_spec=FullAttentionSpec(
+                    block_size=1,
+                    num_kv_heads=1,
+                    head_size=16,
+                    dtype=torch.float16,
+                    use_mla=False,
+                ),
+            ),
+        ],
+    )
+
+
 def _compare_objs(obj1, obj2):
     attrs = inspect.getmembers(obj1, lambda a: not (inspect.isroutine(a)))
     attr_names = set([
@@ -41,6 +64,10 @@ def _compare_objs(obj1, obj2):
         elif isinstance(a, np.ndarray):
             if np.allclose(a, b):
                 is_same = True
+        elif isinstance(a, MultiGroupBlockTable):
+            for a_i, b_i in zip(a.block_tables, b.block_tables):
+                _compare_objs(a_i, b_i)
+            is_same = True
         elif isinstance(a, (BlockTable, SamplingMetadata)):
             _compare_objs(a, b)
             is_same = True  # if we make it here must be same
@@ -198,7 +225,7 @@ def _construct_cached_request_state(req_id_suffix: int):
         sampling_params=_create_sampling_params(),
         mm_inputs=[],
         mm_positions=[],
-        block_ids=[],
+        block_ids=[[]],
         generator=None,
         num_computed_tokens=len(output_token_ids),
         output_token_ids=output_token_ids,
@@ -220,11 +247,11 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
     input_batch: InputBatch = InputBatch(
         max_num_reqs=batch_size,
         max_model_len=1024,
-        max_num_blocks_per_req=10,
         max_num_batched_tokens=1024,
         device=torch.device(device),
         pin_memory=is_pin_memory_available(),
         vocab_size=1024,
+        kv_cache_config=get_kv_cache_config(),
     )
     reqs: list[CachedRequestState] = []
     req_id_reqs = {}
@@ -310,20 +337,20 @@ def test_swap_states_in_input_batch(device: str, batch_size: int,
     input_batch: InputBatch = InputBatch(
         max_num_reqs=batch_size,
         max_model_len=1024,
-        max_num_blocks_per_req=10,
         max_num_batched_tokens=1024,
         device=torch.device(device),
         pin_memory=is_pin_memory_available(),
         vocab_size=1024,
+        kv_cache_config=get_kv_cache_config(),
     )
     ref_input_batch: InputBatch = InputBatch(
         max_num_reqs=batch_size,
         max_model_len=1024,
-        max_num_blocks_per_req=10,
         max_num_batched_tokens=1024,
         device=torch.device(device),
         pin_memory=is_pin_memory_available(),
         vocab_size=1024,
+        kv_cache_config=get_kv_cache_config(),
     )

     reqs: list[CachedRequestState] = []
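InputBatch now derives its block-table geometry from a KVCacheConfig rather than a single max_num_blocks_per_req, and its block table becomes a MultiGroupBlockTable holding one BlockTable per KV cache group (the .block_tables attribute the comparison helper walks). A minimal sketch of that wrapper shape, assuming only what the test exercises; the real class in vllm/v1/worker/block_table.py carries more state, and the method names below are illustrative:

class MultiGroupBlockTableSketch:
    """Hypothetical stand-in for MultiGroupBlockTable; only the surface
    the test relies on (.block_tables) plus a broadcast pattern is shown."""

    def __init__(self, block_tables):
        # One BlockTable per KV cache group, in kv_cache_groups order.
        self.block_tables = block_tables

    def append_row(self, block_ids, row_idx):
        # block_ids is list[list[int]]: inner list i belongs to group i,
        # so each per-group table receives only its own slice.
        for table, ids in zip(self.block_tables, block_ids):
            table.append_row(ids, row_idx)

This broadcast-per-group design lets the rest of the input batch keep a single block-table handle while each group (e.g. full attention vs. sliding window) tracks its own block layout.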
