
Commit 7a325b2

[Bugfix][Model] Fix fusedmoe and make modelrunner_v1 compatible with latest vllm (#867)
### What this PR does / why we need it?

This PR fixes the CI failure caused by recent vllm changes:

1. Add moe_config for fused_moe.
2. Adjust to vllm's kv cache group change. vllm-ascend does not support this feature yet, so this is just a quick fix for backward compatibility.

Fixes: #872

---------

Signed-off-by: MengqingCao <cmq0113@163.com>
1 parent fd515cd commit 7a325b2

File tree

4 files changed: +138 -80 lines changed

vllm_ascend/attention/attention_v1.py

Lines changed: 10 additions & 2 deletions
@@ -30,6 +30,7 @@
 from vllm.v1.worker.gpu_input_batch import InputBatch

 from vllm_ascend.ops.attention import vanilla_chunked_prefill
+from vllm_ascend.utils import vllm_version_is


 class AscendAttentionBackend(AttentionBackend):
@@ -140,8 +141,15 @@ def reorder_batch(self, input_batch: "InputBatch",

     def build(self, num_reqs, num_actual_tokens, max_query_len,
               common_prefix_len):
-        block_table = (
-            self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
+        if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
+            block_table = (self.runner.input_batch.block_table.
+                           get_device_tensor()[:num_reqs])
+        else:
+            block_table = self.runner.input_batch.block_table[
+                0].get_device_tensor()
+            block_table[:num_reqs, :self.runner.max_num_blocks_per_req] = (
+                block_table[:num_reqs])
+
         query_lens = self.runner.query_lens
         seq_lens = self.runner.seq_lens_cpu[:num_reqs]
         slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
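Every hunk in this commit gates behavior on the installed vllm release via vllm_version_is from vllm_ascend.utils. As a rough illustration of what such an exact-match version gate can look like (the real helper lives in vllm_ascend/utils.py and may be implemented differently), a minimal sketch in Python:

# Minimal sketch, not the actual vllm_ascend.utils implementation.
from importlib.metadata import PackageNotFoundError, version

def vllm_version_is(target: str) -> bool:
    """Return True if the installed vllm distribution is exactly `target`."""
    try:
        return version("vllm") == target
    except PackageNotFoundError:
        # vllm is not installed; treat every exact-match check as a miss.
        return False

Call sites can then branch once per difference, e.g. vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"), exactly as the hunks in this commit do.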

vllm_ascend/attention/mla_v1.py

Lines changed: 8 additions & 3 deletions
@@ -16,6 +16,7 @@

 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
+from vllm_ascend.utils import vllm_version_is
 from vllm_ascend.worker.model_runner_v1 import NPUModelRunner

 if TYPE_CHECKING:
@@ -238,8 +239,12 @@ def build(self,
         # function. We should avoid GPU -> CPU sync as much as possible because
         # it blocks on all previous kernels.
         device = self.runner.device
-        block_table = (
-            self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
+        if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
+            block_table = (self.runner.input_batch.block_table.
+                           get_device_tensor()[:num_reqs])
+        else:
+            block_table = (self.runner.input_batch.block_table[0].
+                           get_device_tensor()[:num_reqs])
         slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
             device, non_blocking=True)
         input_positions = self.runner.positions_cpu[:num_actual_tokens].to(
@@ -795,4 +800,4 @@ def forward(
         output[:num_decode_tokens] = self._forward_decode(
             decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe,
             kv_cache, attn_metadata)
-        return output_padded
+        return output_padded
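The same gated block-table lookup now appears in both attention metadata builders. A hypothetical consolidation (not part of this commit) could centralize the branch; runner is assumed to be the NPUModelRunner instance each builder already holds, and the body mirrors the mla_v1.py variant above:

# Hypothetical shared accessor, not part of this commit.
from vllm_ascend.utils import vllm_version_is

def get_device_block_table(runner, num_reqs: int):
    if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
        # Older vllm keeps a single block table for the whole batch.
        return runner.input_batch.block_table.get_device_tensor()[:num_reqs]
    # Newer vllm keeps one block table per KV cache group; vllm-ascend
    # currently supports only a single group, so index 0 is the only
    # valid entry.
    return runner.input_batch.block_table[0].get_device_tensor()[:num_reqs]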

vllm_ascend/ops/fused_moe.py

Lines changed: 103 additions & 45 deletions
@@ -20,12 +20,22 @@
 import torch
 import torch_npu
 from vllm.config import get_current_vllm_config
-from vllm.distributed import tensor_model_parallel_all_reduce
+from vllm.distributed import (get_tensor_model_parallel_world_size,
+                              tensor_model_parallel_all_reduce)
 from vllm.distributed.parallel_state import get_dp_group
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map)
-from vllm.model_executor.layers.quantization.base_config import \
-    QuantizeMethodBase
+
+from vllm_ascend.utils import vllm_version_is
+
+if not (vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1")):
+    from vllm.model_executor.layers.fused_moe.layer import (
+        FusedMoEParallelConfig, MoEConfig)
+else:
+    MoEConfig = None
+
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig, QuantizeMethodBase)

 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group
@@ -437,8 +447,11 @@ def select_experts(

 class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):

-    def __init__(self):
-        super().__init__()
+    def __init__(self, moe: MoEConfig = None):
+        if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
+            super().__init__()
+        else:
+            super().__init__(moe=moe)
         vllm_config = get_current_vllm_config()

         ep_group = get_ep_group()
@@ -535,37 +548,54 @@ def apply(

 class AscendFusedMoE(FusedMoE):

-    def __init__(self,
-                 num_experts,
-                 top_k,
-                 hidden_size,
-                 intermediate_size,
-                 params_dtype=None,
-                 reduce_results=False,
-                 renormalize=True,
-                 use_grouped_topk=False,
-                 num_expert_group=None,
-                 topk_group=None,
-                 quant_config=None,
-                 tp_size=None,
-                 ep_size=None,
-                 dp_size=None,
-                 prefix="",
-                 custom_routing_function=None,
-                 scoring_func="softmax",
-                 e_score_correction_bias=None,
-                 activation="silu"):
+    def __init__(
+        self,
+        num_experts: int,  # Global number of experts
+        top_k: int,
+        hidden_size: int,
+        intermediate_size: int,
+        params_dtype: Optional[torch.dtype] = None,
+        reduce_results: bool = False,
+        renormalize: bool = True,
+        use_grouped_topk: bool = False,
+        num_expert_group: Optional[int] = None,
+        topk_group: Optional[int] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        tp_size: Optional[int] = None,
+        ep_size: Optional[int] = None,
+        dp_size: Optional[int] = None,
+        prefix: str = "",
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
+    ):
+        # TODO: This could not initialize FusedMoE baseclass,
+        # fixme and make __init__() of AscendFusedMoE more clear
         super(FusedMoE, self).__init__()

         if params_dtype is None:
             params_dtype = torch.get_default_dtype()

-        self.ep_size = get_ep_group().world_size
-        self.tp_size = get_etp_group().world_size
-        self.dp_size = (dp_size
-                        if dp_size is not None else get_dp_group().world_size)
-        self.dp_rank = (0
-                        if self.dp_size == 1 else get_dp_group().rank_in_group)
+        vllm_config = get_current_vllm_config()
+        if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
+            self.ep_size = get_ep_group().world_size
+            self.tp_size = get_etp_group().world_size
+            self.dp_size = (dp_size if dp_size is not None else
+                            get_dp_group().world_size)
+            self.dp_rank = (0 if self.dp_size == 1 else
+                            get_dp_group().rank_in_group)
+        else:
+            self.moe_parallel_config: FusedMoEParallelConfig = (
+                FusedMoEParallelConfig.make(
+                    tp_size_=(tp_size if tp_size is not None else
+                              get_tensor_model_parallel_world_size()),
+                    dp_size_=(dp_size if dp_size is not None else
+                              get_dp_group().world_size),
+                    vllm_parallel_config=vllm_config.parallel_config))
+
+            self.moe_parallel_config.ep_size = get_ep_group().world_size

         self.top_k = top_k
         self.num_experts = num_experts
@@ -590,27 +620,55 @@ def __init__(self,
             self.local_num_experts, self.expert_map = determine_expert_map(
                 self.ep_size,
                 get_ep_group().rank_in_group, self.global_num_experts)
-            self.tp_rank = get_etp_group().rank_in_group
-            self.ep_rank = get_ep_group().rank_in_group
+            if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
+                self.tp_rank = get_etp_group().rank_in_group
+                self.ep_rank = get_ep_group().rank_in_group
+            else:
+                self.moe_parallel_config.tp_rank = get_etp_group(
+                ).rank_in_group
+                self.moe_parallel_config.ep_rank = get_ep_group().rank_in_group
+
         else:
             # Adjust TP size for DP attention
             # haven't test its functionality yet, may remove in the future
-            self.tp_rank = self.tp_size * self.dp_rank
-            self.ep_rank = 0
-            self.tp_size = self.tp_size * self.dp_size
-            self.ep_size = 1
-            self.local_num_experts = self.global_num_experts
-            self.expert_map = None
-
+            if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
+                self.tp_rank = self.tp_size * self.dp_rank
+                self.ep_rank = 0
+                self.tp_size = self.tp_size * self.dp_size
+                self.ep_size = 1
+            else:
+                self.moe_parallel_config.tp_rank = self.tp_size * self.dp_rank
+                self.moe_parallel_config.ep_rank = 0
+                self.moe_parallel_config.tp_size = self.tp_size * self.dp_size
+                self.moe_parallel_config.ep_size = 1
+
+            self.local_num_experts, self.expert_map = (self.global_num_experts,
                                                        None)
         if self.scoring_func != "softmax" and not self.use_grouped_topk:
             raise ValueError("Only softmax scoring function is supported for "
                              "non-grouped topk.")
-
-        if quant_config is None:
-            self.quant_method: Optional[QuantizeMethodBase] = (
-                AscendUnquantizedFusedMoEMethod())
+        if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
+            if quant_config is None:
+                self.quant_method: Optional[QuantizeMethodBase] = (
+                    AscendUnquantizedFusedMoEMethod())
+            else:
+                self.quant_method = quant_config.get_quant_method(self, prefix)
         else:
-            self.quant_method = quant_config.get_quant_method(self, prefix)
+            moe = MoEConfig(
+                num_experts=self.global_num_experts,
+                experts_per_token=top_k,
+                hidden_dim=hidden_size,
+                num_local_experts=self.local_num_experts,
+                moe_parallel_config=self.moe_parallel_config,
+                # TODO (bnell): this needs to be fixed for quantized types.
+                in_dtype=params_dtype,
+            )
+
+            if quant_config is None:
+                self.quant_method = AscendUnquantizedFusedMoEMethod(moe)
+            else:
+                self.quant_method = quant_config.get_quant_method(self, prefix)
+
         assert self.quant_method is not None

         local_num_experts = torch.sum(self.expert_map != -1) \
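The switch from assigning self.tp_size / self.ep_size directly to writing into self.moe_parallel_config reflects that the newer FusedMoE base class appears to expose its parallel sizes as views over a single FusedMoEParallelConfig object, so that object becomes the one place to mutate. A toy sketch of that delegation pattern, with illustrative stand-in names rather than vllm's real classes:

# Toy illustration of property delegation over a shared parallel config.
from dataclasses import dataclass

@dataclass
class ToyMoEParallelConfig:
    tp_size: int = 1
    ep_size: int = 1
    tp_rank: int = 0
    ep_rank: int = 0

class ToyFusedMoELayer:
    def __init__(self, moe_parallel_config: ToyMoEParallelConfig):
        self.moe_parallel_config = moe_parallel_config

    @property
    def tp_size(self) -> int:
        return self.moe_parallel_config.tp_size

    @property
    def ep_size(self) -> int:
        return self.moe_parallel_config.ep_size

cfg = ToyMoEParallelConfig(tp_size=4, ep_size=2)
layer = ToyFusedMoELayer(cfg)
layer.moe_parallel_config.ep_size = 8  # mirrors the post-construction tweak above
assert layer.tp_size == 4 and layer.ep_size == 8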

vllm_ascend/worker/model_runner_v1.py

Lines changed: 17 additions & 30 deletions
@@ -111,8 +111,10 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.scheduler_config = vllm_config.scheduler_config
         self.chunked_prefill_enabled = vllm_config.scheduler_config.chunked_prefill_enabled
         self.device = device
+
         self.is_multimodal_model = self.model_config.is_multimodal_model
         self.block_size = vllm_config.cache_config.block_size
+
         self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len,
                                            self.block_size)
         self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
@@ -155,24 +157,6 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
             raise NotImplementedError(
                 "Non-Attention backend is not supported by V1 NPUModelRunner.")

-        self.attn_backend = get_attn_backend(
-            self.head_size,
-            self.dtype,
-            self.kv_cache_dtype,
-            self.block_size,
-            self.model_config.is_attention_free,
-            use_mla=self.model_config.use_mla,
-        )
-        if self.attn_backend is None:
-            error_msg = (
-                f"Error with get_att_backend: {self.head_size=}, "
-                f"{self.dtype=}, {self.kv_cache_dtype=}, {self.block_size=}, "
-                f"{self.model_config.is_attention_free=}, "
-                f"{self.model_config.use_mla=}")
-            logger.error(error_msg)
-            raise NotImplementedError(
-                "Non-Attention backend is not supported by V1 GPUModelRunner.")
-
         self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
             weakref.proxy(self))

@@ -205,17 +189,6 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
                 pin_memory=True,
                 vocab_size=self.model_config.get_vocab_size(),
             )
-        else:
-            self.input_batch = InputBatch(
-                max_num_reqs=self.max_num_reqs,
-                max_model_len=self.model_config.max_model_len,
-                max_num_blocks_per_req=self.max_num_blocks_per_req,
-                max_num_batched_tokens=self.max_num_tokens,
-                device=self.device,
-                pin_memory=True,
-                vocab_size=self.model_config.get_vocab_size(),
-            )
-
         self.input_ids = torch.zeros(self.max_num_tokens,
                                      dtype=torch.int32,
                                      device=self.device)
@@ -562,7 +535,10 @@ def _process_reqs(

         block_table_indices = (req_indices * self.max_num_blocks_per_req +
                                positions_np // self.block_size)
-        block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
+        if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
+            block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
+        else:
+            block_table_cpu = self.input_batch.block_table[0].get_cpu_tensor()
         block_numbers = block_table_cpu.flatten()[block_table_indices].numpy()
         block_offsets = positions_np % self.block_size
         np.add(block_numbers * self.block_size,
@@ -976,6 +952,17 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         """
         import torch_npu
         kv_caches: Dict[str, torch.Tensor] = {}
+        if not (vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1")):
+            self.input_batch = InputBatch(
+                max_num_reqs=self.max_num_reqs,
+                max_model_len=self.model_config.max_model_len,
+                max_num_batched_tokens=self.max_num_tokens,
+                device=self.device,
+                pin_memory=True,
+                vocab_size=self.model_config.get_vocab_size(),
+                kv_cache_config=kv_cache_config,
+            )
+
         for kv_cache_group in kv_cache_config.kv_cache_groups:
             kv_cache_spec = kv_cache_group.kv_cache_spec
             for layer_name in kv_cache_group.layer_names:
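Moving the InputBatch construction out of __init__ and into initialize_kv_cache follows from the newer signature shown above: the batch now takes kv_cache_config, which only exists once the KV cache layout has been planned. A toy sketch of that deferred-construction ordering, with stand-in names rather than vllm's real classes:

# Toy sketch of the ordering constraint: build the batch only after the
# KV cache config is known. All names here are illustrative stand-ins.
from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class ToyKVCacheConfig:
    kv_cache_groups: List[str] = field(default_factory=lambda: ["group0"])

@dataclass
class ToyInputBatch:
    max_num_reqs: int
    kv_cache_config: ToyKVCacheConfig  # newer-vllm style dependency

class ToyModelRunner:
    def __init__(self, max_num_reqs: int):
        self.max_num_reqs = max_num_reqs
        # Cannot build the batch here: the KV cache layout is not known yet.
        self.input_batch: Optional[ToyInputBatch] = None

    def initialize_kv_cache(self, kv_cache_config: ToyKVCacheConfig) -> None:
        self.input_batch = ToyInputBatch(self.max_num_reqs, kv_cache_config)

runner = ToyModelRunner(max_num_reqs=8)
runner.initialize_kv_cache(ToyKVCacheConfig())
assert runner.input_batch is not None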
