Commit 9f5ab59

Angazenn and angazenn authored
[WIP][BugFix]Fix accuracy issues caused by wrong etp_size passed into FusedMoEParallelConfig when using vLLM 0.9.0 (#961)
### What this PR does / why we need it?

This PR fixes accuracy issues caused by the code that adapts to `FusedMoEParallelConfig` in vLLM 0.9.0: the `tp_size` used to split weights is passed incorrectly. The root cause is that the vLLM community and vLLM-Ascend use different methods to decide whether to use Expert Parallel.

vLLM: vLLM uses the flag `enable_expert_parallel` to indicate whether to use EP, and the following code to decide `ep_size`:

```
use_ep = (dp_size_ * tp_size_ > 1
          and vllm_parallel_config.enable_expert_parallel)

dp_size = dp_size_
dp_rank = get_dp_group().rank_in_group if dp_size > 1 else 0
tp_size, tp_rank = flatten_tp_across_dp(dp_rank)

if not use_ep:
    return FusedMoEParallelConfig(tp_size=tp_size,
                                  tp_rank=tp_rank,
                                  dp_size=dp_size,
                                  dp_rank=dp_rank,
                                  ep_size=1,
                                  ep_rank=0,
                                  use_ep=False)

# DP + EP / TP + EP / DP + TP + EP
assert use_ep
# In EP, each device owns a set of experts fully. There is no tensor
# parallel; update tp_size, tp_rank, ep_size and ep_rank to reflect that.
ep_size = tp_size
ep_rank = tp_rank
return FusedMoEParallelConfig(tp_size=1,
                              tp_rank=0,
                              dp_size=dp_size,
                              dp_rank=dp_rank,
                              ep_size=ep_size,
                              ep_rank=ep_rank,
                              use_ep=True)
```

vLLM-Ascend: vLLM-Ascend uses `etp` to specify Tensor Parallel in MoE:

```
self.ep_size = get_ep_group().world_size
self.tp_size = get_etp_group().world_size
self.dp_size = (dp_size if dp_size is not None else
                get_dp_group().world_size)
```

So there is a conflict if we simply combine these two pieces of code.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

Signed-off-by: angazenn <zengyanjia@huawei.com>
Co-authored-by: angazenn <zengyanjia@huawei.com>
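To see the mismatch in isolation, below is a minimal, self-contained Python sketch. It does not import vLLM or vLLM-Ascend; `MoESizes`, `upstream_moe_sizes`, and `ascend_moe_sizes` are hypothetical names that only mirror the two size calculations quoted above, and the 8-device / ep=4 / etp=2 numbers are made up purely for illustration.

```python
# Hypothetical sketch (no vLLM imports): mirrors only the decision logic
# quoted above to show where the two size calculations diverge once
# expert parallelism is enabled.
from dataclasses import dataclass


@dataclass
class MoESizes:
    tp_size: int
    dp_size: int
    ep_size: int


def upstream_moe_sizes(tp_size: int, dp_size: int,
                       enable_expert_parallel: bool) -> MoESizes:
    # Mirrors the upstream branch: with EP enabled, TP ranks are folded
    # into EP and the MoE tp_size collapses to 1.
    use_ep = dp_size * tp_size > 1 and enable_expert_parallel
    if not use_ep:
        return MoESizes(tp_size=tp_size, dp_size=dp_size, ep_size=1)
    return MoESizes(tp_size=1, dp_size=dp_size, ep_size=tp_size)


def ascend_moe_sizes(ep_world_size: int, etp_world_size: int,
                     dp_world_size: int) -> MoESizes:
    # vLLM-Ascend derives EP and MoE tensor parallelism from its own
    # ep/etp process groups rather than the enable_expert_parallel flag.
    return MoESizes(tp_size=etp_world_size, dp_size=dp_world_size,
                    ep_size=ep_world_size)


if __name__ == "__main__":
    # Illustrative numbers only: 8-way TP, dp=1, with an Ascend split of
    # ep=4 and etp=2. Upstream reports tp_size=1 while the Ascend side
    # expects tp_size=2, so the config fields must be overridden.
    print(upstream_moe_sizes(tp_size=8, dp_size=1,
                             enable_expert_parallel=True))
    # -> MoESizes(tp_size=1, dp_size=1, ep_size=8)
    print(ascend_moe_sizes(ep_world_size=4, etp_world_size=2,
                           dp_world_size=1))
    # -> MoESizes(tp_size=2, dp_size=1, ep_size=4)
```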
1 parent 01e3d59 commit 9f5ab59

File tree

1 file changed: +1 −0 lines changed


vllm_ascend/ops/fused_moe.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -748,6 +748,7 @@ def __init__(
                 vllm_parallel_config=vllm_config.parallel_config))
 
         self.moe_parallel_config.ep_size = get_ep_group().world_size
+        self.moe_parallel_config.tp_size = get_etp_group().world_size
 
         self.top_k = top_k
         self.num_experts = num_experts
```
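To make concrete why reporting the wrong `tp_size` affects accuracy, here is a generic illustration. It is not vLLM-Ascend's actual weight-sharding code; the function name and the 1408-wide intermediate size are invented for the example. The point is only that a per-rank shard width derived from the `tp_size` the config reports changes when the config carries 1 instead of the etp group size.

```python
# Generic illustration only; not vLLM-Ascend's actual sharding code.
def per_rank_shard(intermediate_size: int, tp_size: int) -> int:
    # A TP-sharded expert weight is typically split evenly along one
    # dimension across tp_size ranks.
    assert intermediate_size % tp_size == 0, "size must divide evenly"
    return intermediate_size // tp_size


# With etp=2 each rank should hold half of the (made-up) 1408-wide weight;
# with a reported tp_size of 1, the shard is computed as the full width.
print(per_rank_shard(1408, tp_size=2))  # 704
print(per_rank_shard(1408, tp_size=1))  # 1408
```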
