
Commit a0c3e9b

[Bugfix] Adjust inputbatch to be compatible with latest vllm (#945)
Adjust inputbatch to be compatible with the latest vllm, as the KV cache group feature has been redone in vllm-project/vllm#18593.

Signed-off-by: MengqingCao <cmq0113@163.com>
1 parent 1f9fb86 commit a0c3e9b
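
The compatibility switch used in every hunk below is vllm_ascend.utils.vllm_version_is, called with the installed vllm version string. As a rough illustration only (the real helper lives in vllm_ascend/utils.py and may be implemented differently), such a gate can be as simple as an exact match against vllm's reported version:

    # Sketch only: an exact-match version gate in the spirit of
    # vllm_ascend.utils.vllm_version_is; the real helper may differ.
    import vllm

    def vllm_version_is(target_version: str) -> bool:
        """Return True if the installed vllm is exactly `target_version`."""
        return vllm.__version__ == target_version

    # Used below as:
    #   if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"): ...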

File tree

3 files changed: +33 -33 lines changed

vllm_ascend/attention/attention_v1.py

Lines changed: 9 additions & 2 deletions
@@ -30,6 +30,7 @@
 from vllm.v1.worker.gpu_input_batch import InputBatch

 from vllm_ascend.ops.attention import vanilla_chunked_prefill
+from vllm_ascend.utils import vllm_version_is


 class AscendAttentionBackend(AttentionBackend):
@@ -141,8 +142,14 @@ def reorder_batch(self, input_batch: "InputBatch",

     def build(self, num_reqs, num_actual_tokens, max_query_len,
               common_prefix_len):
-        block_table = (
-            self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
+        if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
+            block_table = (self.runner.input_batch.block_table.
+                           get_device_tensor()[:num_reqs])
+        else:
+            block_table = self.runner.input_batch.block_table[
+                0].get_device_tensor()
+            block_table[:num_reqs, :self.runner.max_num_blocks_per_req] = (
+                block_table[:num_reqs])

         query_lens = self.runner.query_lens
         seq_lens = self.runner.seq_lens_cpu[:num_reqs]

vllm_ascend/attention/mla_v1.py

Lines changed: 9 additions & 2 deletions
@@ -16,6 +16,7 @@

 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
+from vllm_ascend.utils import vllm_version_is
 from vllm_ascend.worker.model_runner_v1 import NPUModelRunner

 if TYPE_CHECKING:
@@ -238,8 +239,14 @@ def build(self,
         # function. We should avoid GPU -> CPU sync as much as possible because
         # it blocks on all previous kernels.
         device = self.runner.device
-        block_table = (
-            self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
+        if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
+            block_table = (self.runner.input_batch.block_table.
+                           get_device_tensor()[:num_reqs])
+        else:
+            block_table = self.runner.input_batch.block_table[
+                0].get_device_tensor()
+            block_table[:num_reqs, :self.runner.max_num_blocks_per_req] = (
+                block_table[:num_reqs])
         slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
             device, non_blocking=True)
         input_positions = self.runner.positions_cpu[:num_actual_tokens].to(
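
Both attention metadata builders above now repeat the same block-table branch. Purely as a sketch (a hypothetical helper, not part of this commit), the shared pattern could be factored out like this, assuming that on newer vllm input_batch.block_table is indexed per KV cache group and each group exposes get_device_tensor():

    # Hypothetical helper, not part of this commit: the block-table lookup
    # that both builders above repeat, written once.
    from vllm_ascend.utils import vllm_version_is

    def get_device_block_table(input_batch, num_reqs):
        if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
            # vllm 0.8.5 / 0.8.5.post1: a single block table for the batch.
            return input_batch.block_table.get_device_tensor()[:num_reqs]
        # Newer vllm (after vllm-project/vllm#18593): one block table per
        # KV cache group; these backends read group 0.
        return input_batch.block_table[0].get_device_tensor()[:num_reqs]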

vllm_ascend/worker/model_runner_v1.py

Lines changed: 15 additions & 29 deletions
@@ -114,6 +114,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
     def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
+        self.cache_config = vllm_config.cache_config
         self.lora_config = vllm_config.lora_config
         self.scheduler_config = vllm_config.scheduler_config
         self.speculative_config = vllm_config.speculative_config
@@ -172,24 +173,6 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
             raise NotImplementedError(
                 "Non-Attention backend is not supported by V1 NPUModelRunner.")

-        self.attn_backend = get_attn_backend(
-            self.head_size,
-            self.dtype,
-            self.kv_cache_dtype,
-            self.block_size,
-            self.model_config.is_attention_free,
-            use_mla=self.model_config.use_mla,
-        )
-        if self.attn_backend is None:
-            error_msg = (
-                f"Error with get_att_backend: {self.head_size=}, "
-                f"{self.dtype=}, {self.kv_cache_dtype=}, {self.block_size=}, "
-                f"{self.model_config.is_attention_free=}, "
-                f"{self.model_config.use_mla=}")
-            logger.error(error_msg)
-            raise NotImplementedError(
-                "Non-Attention backend is not supported by V1 GPUModelRunner.")
-
         self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
             weakref.proxy(self))

@@ -237,16 +220,6 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
                 pin_memory=True,
                 vocab_size=self.model_config.get_vocab_size(),
             )
-        else:
-            self.input_batch = InputBatch(
-                max_num_reqs=self.max_num_reqs,
-                max_model_len=self.model_config.max_model_len,
-                max_num_blocks_per_req=self.max_num_blocks_per_req,
-                max_num_batched_tokens=self.max_num_tokens,
-                device=self.device,
-                pin_memory=True,
-                vocab_size=self.model_config.get_vocab_size(),
-            )

         self.input_ids = torch.zeros(self.max_num_tokens,
                                      dtype=torch.int32,
@@ -600,7 +573,10 @@ def _process_reqs(

         block_table_indices = (req_indices * self.max_num_blocks_per_req +
                                positions_np // self.block_size)
-        block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
+        if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
+            block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
+        else:
+            block_table_cpu = self.input_batch.block_table[0].get_cpu_tensor()
         block_numbers = block_table_cpu.flatten()[block_table_indices].numpy()
         block_offsets = positions_np % self.block_size
         np.add(block_numbers * self.block_size,
@@ -1206,6 +1182,16 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         """
         import torch_npu
         kv_caches: Dict[str, torch.Tensor] = {}
+        if not (vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1")):
+            self.input_batch = InputBatch(
+                max_num_reqs=self.max_num_reqs,
+                max_model_len=self.model_config.max_model_len,
+                max_num_batched_tokens=self.max_num_tokens,
+                device=self.device,
+                pin_memory=True,
+                vocab_size=self.model_config.get_vocab_size(),
+                block_size=self.cache_config.block_size,
+            )

         for kv_cache_group in kv_cache_config.kv_cache_groups:
             kv_cache_spec = kv_cache_group.kv_cache_spec
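
The last two hunks work together: __init__ now records cache_config, and on newer vllm the InputBatch is constructed in initialize_kv_cache, once the KV cache layout (and therefore block_size) is known; the new constructor takes block_size instead of max_num_blocks_per_req. A condensed, hypothetical consolidation of the two call sites (the commit itself keeps them separate; names mirror the diff):

    # Hypothetical consolidation of the InputBatch construction shown above;
    # not part of the commit, which keeps the two call sites separate.
    def _build_input_batch(self) -> InputBatch:
        kwargs = dict(
            max_num_reqs=self.max_num_reqs,
            max_model_len=self.model_config.max_model_len,
            max_num_batched_tokens=self.max_num_tokens,
            device=self.device,
            pin_memory=True,
            vocab_size=self.model_config.get_vocab_size(),
        )
        if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
            # Old constructor caps blocks per request directly.
            kwargs["max_num_blocks_per_req"] = self.max_num_blocks_per_req
        else:
            # New constructor takes the block size from the cache config.
            kwargs["block_size"] = self.cache_config.block_size
        return InputBatch(**kwargs)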
