
Commit 7aa4f85

[Bugfix][kvcache] revert multiple kv cache groups (#923)
Revert the multiple kv cache groups related changes, as this feature was reverted in vllm (vllm-project/vllm#18459).

Signed-off-by: MengqingCao <cmq0113@163.com>
1 parent b4d6672 commit 7aa4f85
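For context, the reverted feature had turned input_batch.block_table into a per-KV-cache-group container that this code indexed as block_table[0]; the commit restores the single flat block table access path. Below is a minimal sketch of the two access patterns, using a hypothetical stand-in class rather than the real vllm types:

import torch


class BlockTable:
    """Hypothetical stand-in for a per-batch block table."""

    def __init__(self, max_num_reqs: int, max_num_blocks_per_req: int):
        self._cpu = torch.zeros(max_num_reqs, max_num_blocks_per_req,
                                dtype=torch.int32)

    def get_cpu_tensor(self) -> torch.Tensor:
        return self._cpu


# After this commit (single flat table):
block_table = BlockTable(max_num_reqs=8, max_num_blocks_per_req=16)
cpu_tensor = block_table.get_cpu_tensor()

# Before this commit (one table per KV cache group, indexed by group id):
grouped_block_tables = [BlockTable(8, 16)]  # one entry per group
cpu_tensor_group0 = grouped_block_tables[0].get_cpu_tensor()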

File tree: 3 files changed, 34 additions & 30 deletions


vllm_ascend/attention/attention_v1.py

Lines changed: 2 additions & 9 deletions
@@ -30,7 +30,6 @@
 from vllm.v1.worker.gpu_input_batch import InputBatch
 
 from vllm_ascend.ops.attention import vanilla_chunked_prefill
-from vllm_ascend.utils import vllm_version_is
 
 
 class AscendAttentionBackend(AttentionBackend):
@@ -141,14 +140,8 @@ def reorder_batch(self, input_batch: "InputBatch",
 
     def build(self, num_reqs, num_actual_tokens, max_query_len,
               common_prefix_len):
-        if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
-            block_table = (self.runner.input_batch.block_table.
-                           get_device_tensor()[:num_reqs])
-        else:
-            block_table = self.runner.input_batch.block_table[
-                0].get_device_tensor()
-            block_table[:num_reqs, :self.runner.max_num_blocks_per_req] = (
-                block_table[:num_reqs])
+        block_table = (
+            self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
 
         query_lens = self.runner.query_lens
         seq_lens = self.runner.seq_lens_cpu[:num_reqs]
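The restored path simply slices the preallocated device block table down to the rows of the requests currently in the batch, instead of copying per-group rows into a second buffer. A rough standalone sketch of that slicing, with toy shapes that do not come from the runner:

import torch

num_reqs = 3                      # requests currently in the batch
max_num_reqs, max_num_blocks_per_req = 8, 16

# Preallocated [max_num_reqs, max_num_blocks_per_req] table of block ids.
device_block_table = torch.zeros(max_num_reqs, max_num_blocks_per_req,
                                 dtype=torch.int32)

# Keep only the rows of the active requests; this is a view, so no extra
# copy into a second buffer is needed.
block_table = device_block_table[:num_reqs]
assert block_table.shape == (num_reqs, max_num_blocks_per_req)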

vllm_ascend/attention/mla_v1.py

Lines changed: 2 additions & 7 deletions
@@ -16,7 +16,6 @@
 
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
-from vllm_ascend.utils import vllm_version_is
 from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
 
 if TYPE_CHECKING:
@@ -239,12 +238,8 @@ def build(self,
         # function. We should avoid GPU -> CPU sync as much as possible because
         # it blocks on all previous kernels.
         device = self.runner.device
-        if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
-            block_table = (self.runner.input_batch.block_table.
-                           get_device_tensor()[:num_reqs])
-        else:
-            block_table = (self.runner.input_batch.block_table[0].
-                           get_device_tensor()[:num_reqs])
+        block_table = (
+            self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
         slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
             device, non_blocking=True)
         input_positions = self.runner.positions_cpu[:num_actual_tokens].to(
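The surrounding context copies CPU-side metadata to the device with non_blocking=True, which, as the comment above it notes, avoids blocking on previously launched kernels. A minimal standalone sketch of that general PyTorch pattern; the buffer size here is made up, and the pinned-memory allocation is guarded so the snippet also runs on CPU-only builds:

import torch

# Pinned (page-locked) host memory lets .to(device, non_blocking=True) overlap
# the copy with compute; without an accelerator, fall back to pageable memory.
pin = torch.cuda.is_available()
slot_mapping_cpu = torch.zeros(1024, dtype=torch.int64, pin_memory=pin)

device = torch.device("cuda") if pin else torch.device("cpu")
num_actual_tokens = 256

# Only the prefix that is actually used gets copied, and the copy is queued
# asynchronously so it does not force a host/device synchronization.
slot_mapping = slot_mapping_cpu[:num_actual_tokens].to(device, non_blocking=True)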

vllm_ascend/worker/model_runner_v1.py

Lines changed: 30 additions & 14 deletions
@@ -164,6 +164,24 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
             raise NotImplementedError(
                 "Non-Attention backend is not supported by V1 NPUModelRunner.")
 
+        self.attn_backend = get_attn_backend(
+            self.head_size,
+            self.dtype,
+            self.kv_cache_dtype,
+            self.block_size,
+            self.model_config.is_attention_free,
+            use_mla=self.model_config.use_mla,
+        )
+        if self.attn_backend is None:
+            error_msg = (
+                f"Error with get_att_backend: {self.head_size=}, "
+                f"{self.dtype=}, {self.kv_cache_dtype=}, {self.block_size=}, "
+                f"{self.model_config.is_attention_free=}, "
+                f"{self.model_config.use_mla=}")
+            logger.error(error_msg)
+            raise NotImplementedError(
+                "Non-Attention backend is not supported by V1 GPUModelRunner.")
+
         self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
             weakref.proxy(self))
 
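The added error message relies on the Python 3.8+ f-string "=" specifier, which renders both the expression text and its value. A tiny standalone illustration with made-up values:

head_size, dtype = 128, "bfloat16"
print(f"Error with get_att_backend: {head_size=}, {dtype=}")
# -> Error with get_att_backend: head_size=128, dtype='bfloat16'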

@@ -196,6 +214,17 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
                 pin_memory=True,
                 vocab_size=self.model_config.get_vocab_size(),
             )
+        else:
+            self.input_batch = InputBatch(
+                max_num_reqs=self.max_num_reqs,
+                max_model_len=self.model_config.max_model_len,
+                max_num_blocks_per_req=self.max_num_blocks_per_req,
+                max_num_batched_tokens=self.max_num_tokens,
+                device=self.device,
+                pin_memory=True,
+                vocab_size=self.model_config.get_vocab_size(),
+            )
+
         self.input_ids = torch.zeros(self.max_num_tokens,
                                      dtype=torch.int32,
                                      device=self.device)
@@ -542,10 +571,7 @@ def _process_reqs(
 
         block_table_indices = (req_indices * self.max_num_blocks_per_req +
                                positions_np // self.block_size)
-        if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
-            block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
-        else:
-            block_table_cpu = self.input_batch.block_table[0].get_cpu_tensor()
+        block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
         block_numbers = block_table_cpu.flatten()[block_table_indices].numpy()
         block_offsets = positions_np % self.block_size
         np.add(block_numbers * self.block_size,
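The context around this hunk computes slot indices by flattening the CPU block table: each token looks up its block id at req_index * max_num_blocks_per_req + position // block_size, then combines it with the in-block offset. A small NumPy sketch of that arithmetic with toy values (none of these numbers come from the runner):

import numpy as np

block_size = 4
max_num_blocks_per_req = 8

# Toy block table: request 0 owns physical blocks [10, 11], request 1 owns [20].
block_table_cpu = np.zeros((2, max_num_blocks_per_req), dtype=np.int64)
block_table_cpu[0, :2] = [10, 11]
block_table_cpu[1, :1] = [20]

# One entry per scheduled token: owning request and position in its sequence.
req_indices = np.array([0, 0, 0, 0, 0, 1, 1])
positions_np = np.array([0, 1, 2, 3, 4, 0, 1])

block_table_indices = (req_indices * max_num_blocks_per_req +
                       positions_np // block_size)
block_numbers = block_table_cpu.flatten()[block_table_indices]
block_offsets = positions_np % block_size

slot_mapping = block_numbers * block_size + block_offsets
# Request 0: positions 0-3 land in block 10, position 4 in block 11;
# request 1: positions 0-1 land in block 20.
print(slot_mapping)  # [40 41 42 43 44 80 81]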
@@ -960,16 +986,6 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         """
         import torch_npu
         kv_caches: Dict[str, torch.Tensor] = {}
-        if not (vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1")):
-            self.input_batch = InputBatch(
-                max_num_reqs=self.max_num_reqs,
-                max_model_len=self.model_config.max_model_len,
-                max_num_batched_tokens=self.max_num_tokens,
-                device=self.device,
-                pin_memory=True,
-                vocab_size=self.model_config.get_vocab_size(),
-                kv_cache_config=kv_cache_config,
-            )
 
         for kv_cache_group in kv_cache_config.kv_cache_groups:
             kv_cache_spec = kv_cache_group.kv_cache_spec
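The loop that remains iterates over kv_cache_config.kv_cache_groups and allocates a cache tensor per layer from each group's spec. The sketch below mirrors only the shape of that loop, using hypothetical stand-in dataclasses and an assumed (2, num_blocks, block_size, num_kv_heads, head_size) layout; it is not the actual vllm-ascend allocation code:

from dataclasses import dataclass, field
from typing import Dict, List

import torch


@dataclass
class FullAttentionSpecStub:
    """Hypothetical stand-in for a KV cache spec."""
    block_size: int
    num_kv_heads: int
    head_size: int


@dataclass
class KVCacheGroupStub:
    kv_cache_spec: FullAttentionSpecStub
    layer_names: List[str] = field(default_factory=list)


def allocate_kv_caches(groups: List[KVCacheGroupStub],
                       num_blocks: int) -> Dict[str, torch.Tensor]:
    # One tensor per layer: index 0 holds keys, index 1 holds values.
    kv_caches: Dict[str, torch.Tensor] = {}
    for group in groups:
        spec = group.kv_cache_spec
        for layer_name in group.layer_names:
            kv_caches[layer_name] = torch.zeros(
                2, num_blocks, spec.block_size, spec.num_kv_heads,
                spec.head_size)
    return kv_caches


caches = allocate_kv_caches(
    [KVCacheGroupStub(FullAttentionSpecStub(16, 2, 64), ["layer.0", "layer.1"])],
    num_blocks=8)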

0 commit comments