
Commit 973f993

[Misc] fix initialize_kv_cache (#1102)
The KV cache manager has been changed by vllm-project/vllm@f8a1a2d. This PR adapts that change in vllm-ascend to make CI happy.

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
1 parent c94afd7 commit 973f993
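
The change is the same in all three files: each divergent code path is gated on the installed vLLM version with vllm_ascend.utils.vllm_version_is. The helper's implementation is not part of this commit; a minimal sketch of the idea, assuming it simply does an exact match on the installed vllm package version, could look like this:

# Sketch only; the real vllm_ascend.utils.vllm_version_is may be implemented differently.
from importlib.metadata import PackageNotFoundError, version


def vllm_version_is(target_version: str) -> bool:
    """Return True if the installed vllm package is exactly `target_version`."""
    try:
        return version("vllm") == target_version
    except PackageNotFoundError:
        return False


# Usage pattern, as in the diffs below:
# if vllm_version_is("0.9.0"):
#     ...  # pre-change KV cache manager API
# else:
#     ...  # API after vllm-project/vllm@f8a1a2d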

File tree

3 files changed: +54 -16 lines changed

tests/singlecard/test_scheduler.py
vllm_ascend/core/scheduler.py
vllm_ascend/worker/model_runner_v1.py


tests/singlecard/test_scheduler.py

Lines changed: 21 additions & 9 deletions

@@ -25,7 +25,7 @@
 from vllm.sampling_params import SamplingParams
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
-                                        KVCacheGroupSpec)
+                                        KVCacheGroupSpec, KVCacheTensor)
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
 from vllm.v1.structured_output import StructuredOutputManager
@@ -88,14 +88,26 @@ def create_scheduler(
         model_config=model_config,
         cache_config=cache_config)
 
-    kv_cache_config = KVCacheConfig(
-        num_blocks=10000,  # A large number of blocks to hold all requests
-        tensors={},
-        kv_cache_groups=[
-            KVCacheGroupSpec(['layer'],
-                             FullAttentionSpec(16, 1, 1, torch.float32, False))
-        ],
-    )
+    if vllm_version_is("0.9.0"):
+        kv_cache_config = KVCacheConfig(
+            num_blocks=10000,  # A large number of blocks to hold all requests
+            tensors={},
+            kv_cache_groups=[
+                KVCacheGroupSpec(['layer'],
+                                 FullAttentionSpec(16, 1, 1, torch.float32,
+                                                   False))
+            ],
+        )
+    else:
+        kv_cache_config = KVCacheConfig(
+            num_blocks=10000,  # A large number of blocks to hold all requests
+            kv_cache_tensors=[KVCacheTensor(size=1024, shared_by=[1])],
+            kv_cache_groups=[
+                KVCacheGroupSpec(['layer'],
+                                 FullAttentionSpec(16, 1, 1, torch.float32,
+                                                   False, None))
+            ],
+        )
     cache_config.num_gpu_blocks = 10000
     return AscendScheduler(
         vllm_config,

vllm_ascend/core/scheduler.py

Lines changed: 17 additions & 4 deletions

@@ -29,6 +29,8 @@
 from vllm.v1.request import Request, RequestStatus
 from vllm.v1.structured_output import StructuredOutputManager
 
+from vllm_ascend.utils import vllm_version_is
+
 
 class AscendScheduler(Scheduler):
     """This Scheduler extends vllm's original v1 scheduler
@@ -127,10 +129,15 @@ def skip_cur_request():
                 continue
 
             assert num_new_tokens > 0
+
+            if vllm_version_is("0.9.0"):
+                blocks = computed_blocks.blocks
+            else:
+                blocks = computed_blocks.blocks[0]
+
             watermark = getattr(self.scheduler_config, "watermark", 0.01)
             if not self._check_watermark_for_prefill(request, num_new_tokens,
-                                                     computed_blocks.blocks,
-                                                     watermark):
+                                                     blocks, watermark):
                 # Scheduling would exceed watermark, skip.
                 skip_cur_request()
                 continue
@@ -323,8 +330,14 @@ def _check_watermark_for_prefill(self,
                           len(computed_blocks) * self.block_size)
         num_required_blocks = cdiv(num_new_tokens + num_computed_tokens,
                                    self.block_size)
-        req_blocks = self.kv_cache_manager.single_type_manager.req_to_blocks[
-            request.request_id]
+
+        if vllm_version_is("0.9.0"):
+            req_blocks = self.kv_cache_manager.single_type_manager.req_to_blocks[
+                request.request_id]
+        else:
+            req_blocks = self.kv_cache_manager.coordinator.get_blocks(
+                request.request_id)
+
         num_new_blocks = (num_required_blocks - len(req_blocks) -
                           len(computed_blocks))
         num_evictable_computed_blocks = sum(1 for blk in computed_blocks
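
Both branches of the last hunk return the same per-request block list, just through different attributes. Purely as an illustration (not code from this PR), the version gate could be folded into a single helper:

# Hypothetical helper, mirroring the branch in _check_watermark_for_prefill above.
def get_req_blocks(kv_cache_manager, request_id):
    if vllm_version_is("0.9.0"):
        # vLLM 0.9.0: the single-type manager keeps a per-request block list.
        return kv_cache_manager.single_type_manager.req_to_blocks[request_id]
    # Newer vLLM: the KV cache coordinator owns the per-request block lookup.
    return kv_cache_manager.coordinator.get_blocks(request_id)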

vllm_ascend/worker/model_runner_v1.py

Lines changed: 16 additions & 3 deletions

@@ -1321,12 +1321,25 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
             block_sizes=[self.cache_config.block_size],
         )
 
+        if not vllm_version_is("0.9.0"):
+            kv_cache_sizes = {}
+            for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
+                assert len(kv_cache_tensor.shared_by) == 1, (
+                    "KV cache tensor shared by multiple layers is not supported in "
+                    "NPU.")
+                kv_cache_sizes[
+                    kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size
+
         for kv_cache_group in kv_cache_config.kv_cache_groups:
             kv_cache_spec = kv_cache_group.kv_cache_spec
             for layer_name in kv_cache_group.layer_names:
-                tensor_config = kv_cache_config.tensors[layer_name]
-                assert tensor_config.size % kv_cache_spec.page_size_bytes == 0
-                num_blocks = tensor_config.size // kv_cache_spec.page_size_bytes
+                if vllm_version_is("0.9.0"):
+                    tensor_size = kv_cache_config.tensors[layer_name].size
+                else:
+                    tensor_size = kv_cache_sizes[layer_name]
+                assert tensor_size % kv_cache_spec.page_size_bytes == 0
+                num_blocks = tensor_size // kv_cache_spec.page_size_bytes
+
                 # `num_blocks` is the number of blocks the model runner can use.
                 # `kv_cache_config.num_blocks` is the number of blocks that
                 # KVCacheManager may allocate.
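
To make the block arithmetic in the last hunk concrete, here is a worked example with made-up sizes; the real values come from the KVCacheConfig tensors and each layer's kv_cache_spec:

# Illustrative numbers only.
tensor_size = 64 * 1024 * 1024       # bytes reserved for one layer's KV cache
page_size_bytes = 2 * 1024 * 1024    # kv_cache_spec.page_size_bytes
assert tensor_size % page_size_bytes == 0
num_blocks = tensor_size // page_size_bytes  # 32 blocks usable by the model runner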
