Commit 1f9fb86

angazenn authored
[BugFix] Fix accuracy bugs for unquantized deepseekv3 models (#897)
### What this PR does / why we need it?

This PR fixes two accuracy bugs introduced by PR #819 when running deepseekv3 series models:

1. #819 adds `all_to_all` communication in the quantized case, but removes `all_gather` and `reduce_scatter` in both the quantized and unquantized cases. When running unquantized deepseekv3 models with `ep_size == world_size`, the moe modules therefore fail to communicate. This PR adds `all_to_all` communication in the unquantized case as well, which resolves the accuracy issue.
2. Use `ep_size` rather than `dp_size` to decide whether to use `all_to_all` in moe.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

CI passed with newly added and existing tests.

---------

Signed-off-by: angazenn <zengyanjia@huawei.com>
Co-authored-by: angazenn <zengyanjia@huawei.com>
1 parent 17f05b1 commit 1f9fb86
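The core of the fix is which condition gates the all-to-all MoE path. As a minimal, hedged sketch (the helper `select_moe_path` and its arguments are illustrative, not symbols from this diff), the dispatch choice should be keyed on the expert-parallel group size rather than `dp_size`:

```python
# Illustrative sketch only: mirrors the branch order in the diff below,
# but the function name and arguments are hypothetical.
def select_moe_path(use_mc2: bool, is_prefill: bool, ep_world_size: int) -> str:
    if use_mc2 and not is_prefill:
        # MC2 fused dispatch/combine kernel path.
        return "fused_experts_with_mc2"
    if ep_world_size == 1:
        # No expert parallelism: every rank holds all experts, no dispatch needed.
        return "fused_experts"
    # ep_size > 1: tokens must be dispatched to and combined from the ranks
    # that own each expert, regardless of how dp/tp are configured.
    return "fused_experts_with_all2all"
```

Before this PR, the unquantized path fell straight through to `fused_experts` whenever MC2 was not used, and the quantized path keyed the decision on `dp_size == 1`; both miss the `ep_size > 1` case.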

File tree

3 files changed: +162 additions, -9 deletions

vllm_ascend/ops/fused_moe.py

Lines changed: 156 additions & 5 deletions
@@ -18,9 +18,11 @@
 from typing import Callable, Optional

 import torch
+import torch.distributed as dist
 import torch_npu
 from vllm.config import get_current_vllm_config
-from vllm.distributed import (get_tensor_model_parallel_world_size,
+from vllm.distributed import (GroupCoordinator,
+                              get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_reduce)
 from vllm.distributed.parallel_state import get_dp_group
 from vllm.model_executor.layers.fused_moe.layer import (
@@ -154,6 +156,143 @@ def fused_experts_with_mc2(
     return hidden_states


+# currently expert parallelism implemented with all2all
+# is under-optimized.
+def fused_experts_with_all2all(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    top_k: int,
+    expert_map: torch.Tensor = None,
+    ep_group: GroupCoordinator = None,
+):
+    original_shape = hidden_states.shape
+    if len(original_shape) == 3:
+        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+
+    num_tokens, _ = hidden_states.shape
+    num_experts = w1.shape[0]
+    device = hidden_states.device
+
+    if expert_map is not None:
+        global_num_experts = len(expert_map)
+        local_num_experts = global_num_experts // ep_group.world_size
+        row_idx_len = num_tokens * top_k
+        row_idx = (torch.arange(0,
+                                row_idx_len,
+                                dtype=torch.int32,
+                                device=device).view(top_k, -1).permute(
+                                    1, 0).contiguous())
+        hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
+            hidden_states,
+            row_idx=row_idx,
+            expert_idx=topk_ids,
+            active_num=num_tokens)
+
+        global_expert_tokens = torch.bincount(expanded_expert_idx,
+                                              minlength=global_num_experts)
+        scatter_sizes = global_expert_tokens.view(ep_group.world_size,
+                                                  -1).sum(-1)
+
+        gather_sizes = torch.empty_like(scatter_sizes)
+        dist.all_to_all_single(gather_sizes,
+                               scatter_sizes,
+                               group=ep_group.device_group)
+        scatter_size_list = scatter_sizes.cpu().tolist()
+        gather_size_list = gather_sizes.cpu().tolist()
+
+        expanded_expert_idx = expanded_expert_idx % local_num_experts
+        hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
+                                            scatter_size_list,
+                                            gather_size_list)
+        local_expert_idx = ep_group.all_to_all(expanded_expert_idx, 0, 0,
+                                               scatter_size_list,
+                                               gather_size_list)
+
+        sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx)
+
+        expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
+            sorted_local_expert_idx, local_num_experts).to(torch.int64)
+
+        hidden_states = hidden_states[sorted_idx]
+    else:
+        row_idx_len = num_tokens * top_k
+        row_idx = torch.arange(0,
+                               row_idx_len,
+                               dtype=torch.int32,
+                               device=topk_weights.device).view(
+                                   top_k, -1).permute(1, 0).contiguous()
+        hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
+            hidden_states,
+            row_idx=row_idx,
+            expert_idx=topk_ids,
+            active_num=num_tokens)
+
+        expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
+            expanded_expert_idx, num_experts)
+        expert_tokens = expert_tokens.to(torch.int64)
+
+    w1 = w1.transpose(1, 2)
+    gate_up_out_list = torch_npu.npu_grouped_matmul(
+        x=[hidden_states],
+        weight=[w1],
+        split_item=2,
+        group_list_type=0,
+        group_type=0,
+        group_list=expert_tokens,
+    )
+
+    # TODO: Remove this in the future.
+    hidden_states = torch.cat(gate_up_out_list, dim=0)
+    hidden_states = torch_npu.npu_swiglu(hidden_states)
+
+    w2 = w2.transpose(1, 2)
+    down_out_list = torch_npu.npu_grouped_matmul(
+        x=[hidden_states],
+        weight=[w2],
+        split_item=2,
+        group_list_type=0,
+        group_type=0,
+        group_list=expert_tokens,
+    )
+
+    hidden_states = torch.cat(down_out_list, dim=0)
+
+    if expert_map is not None:
+        resorted_idx = torch.argsort(sorted_idx)
+        hidden_states = hidden_states[resorted_idx]
+        hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
+                                            gather_size_list,
+                                            scatter_size_list)
+
+        final_hidden_states = torch_npu.npu_moe_finalize_routing(
+            hidden_states,
+            skip1=None,
+            skip2=None,
+            bias=None,
+            scales=topk_weights,
+            expanded_src_to_dst_row=expanded_row_idx,
+            export_for_source_row=topk_ids,
+        )
+    else:
+        # TODO: Reorder device memory 2 times here, replace the current
+        # implementation here when suitable operators become available.
+        final_hidden_states = torch_npu.npu_moe_finalize_routing(
+            hidden_states,
+            skip1=None,
+            skip2=None,
+            bias=None,
+            scales=topk_weights,
+            expanded_src_to_dst_row=expanded_row_idx,
+            export_for_source_row=topk_ids,
+        )
+    if len(original_shape) == 3:
+        final_hidden_states = final_hidden_states.view(original_shape)
+    return final_hidden_states
+
+
 def fused_experts(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
@@ -494,7 +633,7 @@ def apply(
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
         e_score_correction_bias: Optional[torch.Tensor] = None,
-        is_prefill=False,
+        is_prefill: bool = False,
         **kwargs,
     ):
         # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
@@ -536,14 +675,27 @@ def apply(
                 top_k=top_k,
                 expert_map=expert_map,
                 moe_all_to_all_group_name=self.moe_all_to_all_group_name)
-        else:
+        elif get_ep_group().world_size == 1:
             return fused_experts(hidden_states=x,
                                  w1=layer.w13_weight,
                                  w2=layer.w2_weight,
                                  topk_weights=topk_weights,
                                  topk_ids=topk_ids,
                                  top_k=top_k,
                                  expert_map=expert_map)
+        else:
+            # The current implementation of deepseek moe splits hidden_states
+            # according to tp_size before they are feed into fused_moe module.
+            # Therefore, all2all is needed no matter how dp/tp is set so as to
+            # dispatch/combine tokens.
+            return fused_experts_with_all2all(hidden_states=x,
+                                              w1=layer.w13_weight,
+                                              w2=layer.w2_weight,
+                                              topk_weights=topk_weights,
+                                              topk_ids=topk_ids,
+                                              top_k=top_k,
+                                              expert_map=expert_map,
+                                              ep_group=get_ep_group())


 class AscendFusedMoE(FusedMoE):
@@ -721,8 +873,7 @@ def forward(self,
             scoring_func=self.scoring_func,
             e_score_correction_bias=self.e_score_correction_bias,
             is_prefill=is_prefill,
-            enable_force_load_balance=enable_force_load_balance,
-            dp_size=self.dp_size)
+            enable_force_load_balance=enable_force_load_balance)

         if VLLM_ENABLE_MC2 and not is_prefill:
...
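Aside: the exchange of `scatter_sizes`/`gather_sizes` before the variable-length `all_to_all` in `fused_experts_with_all2all` above is a standard two-step pattern. Below is a self-contained sketch of that pattern using plain `torch.distributed`; it assumes an initialized process group and contiguously sharded experts, and the function name and arguments are illustrative rather than taken from this diff.

```python
import torch
import torch.distributed as dist


def dispatch_tokens_by_expert(tokens: torch.Tensor, expert_idx: torch.Tensor,
                              num_global_experts: int) -> torch.Tensor:
    """Sketch: send each row of `tokens` to the EP rank that owns its expert."""
    world_size = dist.get_world_size()
    # Sort rows by expert id so the rows bound for each rank are contiguous.
    expert_idx, order = torch.sort(expert_idx)
    tokens = tokens[order]
    # Per-rank send counts (experts assumed sharded contiguously across ranks).
    per_expert = torch.bincount(expert_idx, minlength=num_global_experts)
    send_sizes = per_expert.view(world_size, -1).sum(-1)
    # Step 1: exchange counts so every rank knows how much it will receive.
    recv_sizes = torch.empty_like(send_sizes)
    dist.all_to_all_single(recv_sizes, send_sizes)
    send_list = send_sizes.cpu().tolist()
    recv_list = recv_sizes.cpu().tolist()
    # Step 2: exchange the tokens themselves with variable split sizes.
    out = tokens.new_empty((sum(recv_list), tokens.shape[-1]))
    dist.all_to_all_single(out, tokens,
                           output_split_sizes=recv_list,
                           input_split_sizes=send_list)
    return out
```

The combine step reverses the same exchange with the send/receive lists swapped, which is what the second `ep_group.all_to_all(hidden_states, 0, 0, gather_size_list, scatter_size_list)` call in the diff above does.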

vllm_ascend/quantization/quant_config.py

Lines changed: 1 addition & 2 deletions
@@ -323,14 +323,13 @@ def apply(
         e_score_correction_bias: Optional[torch.Tensor] = None,
         is_prefill: bool = True,
         enable_force_load_balance: bool = False,
-        dp_size: int = 1,
         **kwargs,
     ) -> torch.Tensor:
         return self.quant_method.apply(
             layer, x, router_logits, top_k, renormalize, use_grouped_topk,
             global_num_experts, expert_map, topk_group, num_expert_group,
             custom_routing_function, scoring_func, e_score_correction_bias,
-            is_prefill, enable_force_load_balance, dp_size)
+            is_prefill, enable_force_load_balance)

     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         if hasattr(self.quant_method, "process_weights_after_loading"):

vllm_ascend/quantization/w8a8_dynamic.py

Lines changed: 5 additions & 2 deletions
@@ -582,7 +582,6 @@ def apply(
         e_score_correction_bias: Optional[torch.Tensor] = None,
         is_prefill: bool = True,
         enable_force_load_balance: bool = True,
-        dp_size: int = 1,
         **kwargs,
     ) -> torch.Tensor:
         assert router_logits.shape[
@@ -635,7 +634,7 @@ def apply(
                 top_k=top_k,
                 expert_map=expert_map,
                 moe_all_to_all_group_name=self.moe_all_to_all_group_name)
-        elif dp_size == 1:
+        elif self.ep_group.world_size == 1:
             return fused_experts(hidden_states=x,
                                  w1=layer.w13_weight,
                                  w1_scale=layer.w13_weight_scale,
@@ -646,6 +645,10 @@ def apply(
                                  top_k=top_k,
                                  expert_map=expert_map)
         else:
+            # The current implementation of deepseek moe splits hidden_states
+            # according to tp_size before they are feed into fused_moe module.
+            # Therefore, all2all is needed no matter how dp/tp is set so as to
+            # dispatch/combine tokens.
             return fused_experts_with_all2all(hidden_states=x,
                                               w1=layer.w13_weight,
                                               w1_scale=layer.w13_weight_scale,
