Commit 1e67089

Authored by Angazenn and angazenn
[BugFix]add all2all when dp_size > 1 && downgrade npu_dequant_swiglu_quant (vllm-project#819)
### What this PR does / why we need it?
1. This PR introduces a native `all_to_all` communication operator to fix `allgather` bugs when dp_size > 1. It also adds a naive implementation of force-load-balance for profile runs.
2. The operator `npu_dequant_swiglu_quant` only supports input hidden_states with dtype `torch.int32`. This tensor occupies space of `global_bs * seq_len * topk * hidden_size`, which can become very large as `ep_size` grows. We therefore disable this operator and fall back to the original `swiglu` and `quantize` path.

### Does this PR introduce any user-facing change?
No.

### How was this patch tested?
By performing offline inference:
![image](https://github.com/user-attachments/assets/e003d5dc-0753-41ae-9303-e87f73ac6828)

---------

Signed-off-by: angazenn <zengyanjia@huawei.com>
Co-authored-by: angazenn <zengyanjia@huawei.com>
1 parent: 68fb634 · commit: 1e67089
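For the second point of the description, the fallback amounts to a plain SwiGLU activation followed by re-quantization instead of the fused kernel. The sketch below is only an illustration of that fallback under an assumed symmetric per-token int8 scheme and made-up shapes; it is not the backend code changed in this PR.

```python
from typing import Tuple

import torch
import torch.nn.functional as F


def swiglu_then_quant(gate_up: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Illustrative fallback for the fused npu_dequant_swiglu_quant kernel:
    plain SwiGLU followed by per-token int8 quantization (assumed scheme)."""
    gate, up = gate_up.chunk(2, dim=-1)       # split the fused gate/up projection
    activated = F.silu(gate) * up             # SwiGLU activation
    scale = activated.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    quantized = torch.round(activated / scale).clamp(-128, 127).to(torch.int8)
    return quantized, scale.squeeze(-1)


# e.g. 4 tokens with a fused gate/up projection of width 2 * 64
q, s = swiglu_then_quant(torch.randn(4, 128))
assert q.shape == (4, 64) and s.shape == (4,)
```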

File tree: 7 files changed, +313 −76 lines changed

vllm_ascend/distributed/communicator.py

Lines changed: 44 additions & 4 deletions
```diff
@@ -14,22 +14,62 @@
 # limitations under the License.
 # This file is a part of the vllm-ascend project.
 #
-from typing import Optional
+from typing import List, Optional
 
 import torch
-from torch.distributed import ProcessGroup
+import torch.distributed as dist
 from vllm.distributed.device_communicators.base_device_communicator import \
     DeviceCommunicatorBase
 
 
 class NPUCommunicator(DeviceCommunicatorBase):
 
     def __init__(self,
-                 cpu_group: ProcessGroup,
+                 cpu_group: dist.ProcessGroup,
                  device: Optional[torch.device] = None,
-                 device_group: Optional[ProcessGroup] = None,
+                 device_group: Optional[dist.ProcessGroup] = None,
                  unique_name: str = ""):
         super().__init__(cpu_group, device, device_group, unique_name)
         # TODO(hz): Refer to CudaCommunicator's implementation to integrate PyHcclCommunicator
         # init device according to rank
         self.device = torch.npu.current_device()
+
+    def all_to_all(self,
+                   input_: torch.Tensor,
+                   scatter_dim: int = 0,
+                   gather_dim: int = -1,
+                   scatter_sizes: Optional[List[int]] = None,
+                   gather_sizes: Optional[List[int]] = None) -> torch.Tensor:
+
+        if scatter_dim < 0:
+            scatter_dim += input_.dim()
+        if gather_dim < 0:
+            gather_dim += input_.dim()
+
+        if scatter_sizes is not None and gather_sizes is not None:
+            input_list = [
+                t.contiguous()
+                for t in torch.split(input_, scatter_sizes, scatter_dim)
+            ]
+            output_list = []
+            tensor_shape_base = input_list[self.rank].size()
+            for i in range(self.world_size):
+                tensor_shape = list(tensor_shape_base)
+                tensor_shape[gather_dim] = gather_sizes[i]
+                output_list.append(
+                    torch.empty(tensor_shape,
+                                dtype=input_.dtype,
+                                device=input_.device))
+
+        else:
+            input_list = [
+                t.contiguous() for t in torch.tensor_split(
+                    input_, self.world_size, scatter_dim)
+            ]
+            output_list = [
+                torch.empty_like(input_list[i]) for i in range(self.world_size)
+            ]
+
+        dist.all_to_all(output_list, input_list, group=self.device_group)
+        output_tensor = torch.cat(output_list, dim=gather_dim).contiguous()
+        return output_tensor
```
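A minimal usage sketch of the new communicator method, assuming an initialized NPU process group and an `NPUCommunicator` instance named `comm` (both hypothetical here); shapes and sizes are illustrative only.

```python
import torch

# even split: dim 0 is cut into world_size equal chunks, each peer's chunk is
# exchanged, and the received chunks are concatenated back along dim 0
tokens = torch.randn(8, 16).npu()
exchanged = comm.all_to_all(tokens, scatter_dim=0, gather_dim=0)

# uneven split: explicit per-rank sizes let each rank pre-allocate receive
# buffers of the right shape. Seen from rank 0 with world_size == 2, if every
# rank scatters [3, 5] rows, rank 0 receives 3 rows from each peer.
exchanged = comm.all_to_all(tokens,
                            scatter_dim=0,
                            gather_dim=0,
                            scatter_sizes=[3, 5],
                            gather_sizes=[3, 3])
```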

vllm_ascend/models/deepseek_v2.py

Lines changed: 40 additions & 24 deletions
```diff
@@ -205,50 +205,66 @@ def __init__(
         )
         CustomDeepseekV2MoE.top_k = config.num_experts_per_tok
 
-        vllm_config = get_current_vllm_config()
         self.dp_size = get_dp_group().world_size
-        batch_size = vllm_config.scheduler_config.max_num_seqs
 
-        params_dtype = torch.get_default_dtype()
-        self.final_hidden_states = torch.zeros(
-            [batch_size, config.hidden_size], dtype=params_dtype, device="npu")
         self.tp_group = get_tp_group().device_group
+        self.tp_rank = get_tp_group().rank_in_group
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         attn_metadata = get_forward_context().attn_metadata
+        # when profile runs, force experts to load balanced tokens
+        # to avoid high memory consumption on a single rank.
+        # TODO: need a better flag to indicate whether in profile run or not.
         if attn_metadata is None:
             # for profile run
             is_prefill = True
+            enable_force_load_balance = True
         else:
             is_prefill = attn_metadata.num_prefills > 0
+            enable_force_load_balance = False
         num_tokens, hidden_dim = hidden_states.shape
-        hidden_states = hidden_states.view(-1, hidden_dim)
 
-        if (self.tp_size > 1 and VLLM_ENABLE_MC2 and not is_prefill):
-            chunks = torch.chunk(hidden_states,
-                                 get_tp_group().world_size,
-                                 dim=0)
-            hidden_states = chunks[get_tp_group().rank_in_group]
+        if self.n_shared_experts is not None:
+            shared_output = self.shared_experts(hidden_states)
+
+        if self.tp_size > 1:
+            # pass
+            num_tokens, hidden_size = hidden_states.shape
+            if num_tokens < self.tp_size:
+                target_size = self.tp_size
+                new_hidden_states = torch.empty([target_size, hidden_size],
+                                                dtype=hidden_states.dtype,
+                                                device=hidden_states.device)
+                new_hidden_states[:num_tokens] = hidden_states
+                hidden_states = new_hidden_states
+            chunk_hidden_states = torch.tensor_split(hidden_states,
+                                                     self.tp_size,
+                                                     dim=0)
+            local_hidden_states = chunk_hidden_states[self.tp_rank]
+        else:
+            local_hidden_states = hidden_states
 
         # router_logits: (num_tokens, n_experts)
-        router_logits, _ = self.gate(hidden_states)
+        router_logits, _ = self.gate(local_hidden_states)
 
-        final_hidden_states = self.experts(
-            hidden_states=hidden_states,
+        router_hidden_states = self.experts(
+            hidden_states=local_hidden_states,
             router_logits=router_logits,
             is_prefill=is_prefill,
-            top_k=CustomDeepseekV2MoE.top_k) * self.routed_scaling_factor
+            top_k=CustomDeepseekV2MoE.top_k,
+            enable_force_load_balance=enable_force_load_balance,
+        ) * self.routed_scaling_factor
 
         if self.tp_size > 1:
-            if VLLM_ENABLE_MC2 and not is_prefill:
-                dist.all_gather_into_tensor(self.final_hidden_states,
-                                            final_hidden_states, self.tp_group)
-                final_hidden_states = self.final_hidden_states
-            else:
-                final_hidden_states = tensor_model_parallel_all_reduce(
-                    final_hidden_states)
-        if self.n_shared_experts is not None:
-            shared_output = self.shared_experts(hidden_states)
+            dist.all_gather(list(chunk_hidden_states), router_hidden_states,
+                            self.tp_group)
+            final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
+            if num_tokens < self.tp_size:
+                final_hidden_states = final_hidden_states[:num_tokens]
+        else:
+            final_hidden_states = router_hidden_states
+
+        if shared_output is not None:
            final_hidden_states = final_hidden_states + shared_output
 
         return final_hidden_states.view(num_tokens, hidden_dim)
```
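The rewritten tensor-parallel branch pads the token dimension up to `tp_size` before `tensor_split` so every rank gets a non-empty chunk, then trims the padding after the `all_gather`. Below is a single-process sketch of just that pad/split/trim pattern; the collective is replaced by a local concatenation so the snippet runs standalone and is not the model code itself.

```python
import torch


def pad_split_gather(hidden_states: torch.Tensor, tp_size: int) -> torch.Tensor:
    """Mimics the pad -> tensor_split -> gather -> trim flow of the new
    forward(); the real code runs experts on its own chunk and calls
    dist.all_gather, replaced here by a plain concatenation."""
    num_tokens, hidden_size = hidden_states.shape
    if num_tokens < tp_size:
        padded = torch.empty([tp_size, hidden_size],
                             dtype=hidden_states.dtype,
                             device=hidden_states.device)
        padded[:num_tokens] = hidden_states
        hidden_states = padded
    chunks = torch.tensor_split(hidden_states, tp_size, dim=0)
    gathered = torch.cat(chunks, dim=0)  # stand-in for dist.all_gather
    return gathered[:num_tokens]         # drop the padding rows again


assert pad_split_gather(torch.randn(3, 8), tp_size=4).shape == (3, 8)
```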

vllm_ascend/ops/fused_moe.py

Lines changed: 9 additions & 23 deletions
```diff
@@ -18,7 +18,6 @@
 from typing import Callable, Optional
 
 import torch
-import torch.distributed as dist
 import torch_npu
 from vllm.config import get_current_vllm_config
 from vllm.distributed import tensor_model_parallel_all_reduce
@@ -636,6 +635,7 @@ def forward(self,
                 hidden_states: torch.Tensor,
                 router_logits: torch.Tensor,
                 is_prefill: bool,
+                enable_force_load_balance: bool = False,
                 top_k=None):
         assert self.quant_method is not None
 
@@ -644,17 +644,8 @@ def forward(self,
         else:
             real_top_k = self.top_k
 
-        if self.dp_size > 1:
-            if VLLM_ENABLE_MC2 and not is_prefill:
-                ...
-            elif USING_LCCL_COM:  # type: ignore
-                hidden_states = get_dp_group().all_gather(
-                    hidden_states, 0, False)
-                router_logits = get_dp_group().all_gather(
-                    router_logits, 0, False)
-            else:
-                hidden_states = get_dp_group().all_gather(hidden_states, 0)
-                router_logits = get_dp_group().all_gather(router_logits, 0)
+        if VLLM_ENABLE_MC2 and not is_prefill:
+            ...
 
         # Matrix multiply.
         final_hidden_states = self.quant_method.apply(
@@ -671,17 +662,12 @@ def forward(self,
             custom_routing_function=self.custom_routing_function,
             scoring_func=self.scoring_func,
             e_score_correction_bias=self.e_score_correction_bias,
-            is_prefill=is_prefill)
-
-        if self.dp_size > 1:
-            if VLLM_ENABLE_MC2 and not is_prefill:
-                ...
-            else:
-                final_hidden_states = dist._functional_collectives.reduce_scatter_tensor(
-                    final_hidden_states,
-                    "sum",
-                    scatter_dim=0,
-                    group=get_dp_group().device_group)
+            is_prefill=is_prefill,
+            enable_force_load_balance=enable_force_load_balance,
+            dp_size=self.dp_size)
+
+        if VLLM_ENABLE_MC2 and not is_prefill:
+            ...
 
         if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
             final_hidden_states = tensor_model_parallel_all_reduce(
```
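The `enable_force_load_balance` flag added above is forwarded into the quantized MoE `apply`; per the PR description, profile runs then use a naively balanced token-to-expert assignment. A hedged sketch of one such round-robin assignment follows (not necessarily the exact scheme implemented in the Ascend quantization backend).

```python
import torch


def force_balanced_topk_ids(num_tokens: int, top_k: int, num_experts: int,
                            device: torch.device) -> torch.Tensor:
    """Round-robin the (token, top-k) slots over experts so no expert sees a
    disproportionate number of tokens during a profile run."""
    slots = torch.arange(num_tokens * top_k, device=device)
    return (slots % num_experts).view(num_tokens, top_k)


# 8 tokens, top_k=2, 4 experts -> every expert is selected exactly 4 times
ids = force_balanced_topk_ids(8, 2, 4, torch.device("cpu"))
assert torch.bincount(ids.flatten(), minlength=4).eq(4).all()
```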

vllm_ascend/patch/worker/patch_common/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -14,9 +14,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+
 # patch_utils should be the first import, because it will be used by other
 # patch files.
 import vllm_ascend.patch.worker.patch_common.patch_utils  # noqa isort:skip
+import vllm_ascend.patch.worker.patch_common.patch_distributed  # noqa
 import vllm_ascend.patch.worker.patch_common.patch_metrics  # noqa
 import vllm_ascend.patch.worker.patch_common.patch_minicpm  # noqa
 import vllm_ascend.patch.worker.patch_common.patch_multi_step_worker  # noqa
```
vllm_ascend/patch/worker/patch_common/patch_distributed.py (new file)

Lines changed: 49 additions & 0 deletions
```diff
@@ -0,0 +1,49 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# This file is a part of the vllm-ascend project.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import List, Optional
+
+import torch
+import vllm
+from vllm.distributed.parallel_state import GroupCoordinator
+
+
+class GroupCoordinatorPatch(GroupCoordinator):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def all_to_all(self,
+                   input_: torch.Tensor,
+                   scatter_dim: int = 0,
+                   gather_dim: int = -1,
+                   scatter_sizes: Optional[List[int]] = None,
+                   gather_sizes: Optional[List[int]] = None) -> torch.Tensor:
+        if self.world_size == 1:
+            return input_
+        assert -input_.dim() <= scatter_dim < input_.dim(), (
+            f"Invalid scatter dim ({scatter_dim}) for input tensor with shape {input_.size()}"
+        )
+        assert -input_.dim() <= gather_dim < input_.dim(), (
+            f"Invalid gather dim ({gather_dim}) for input tensor with shape {input_.size()}"
+        )
+        return self.device_communicator.all_to_all(input_, scatter_dim,
+                                                   gather_dim, scatter_sizes,
+                                                   gather_sizes)
+
+
+vllm.distributed.parallel_state.GroupCoordinator = GroupCoordinatorPatch  # Note: check the GroupCoordinator with online serving
```
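Because the patch swaps in `GroupCoordinatorPatch` before the process groups are constructed, every coordinator created afterwards (including the DP group used by `fused_moe.py`) exposes `all_to_all`. A hypothetical call site is shown below, assuming the patch module has already been imported and the distributed environment is initialized; the shapes are illustrative only.

```python
import torch
from vllm.distributed.parallel_state import get_dp_group

# scatter hidden states along dim 0 to all DP ranks and gather the received
# chunks back along dim 0
hidden_states = torch.randn(16, 1024).npu()
hidden_states = get_dp_group().all_to_all(hidden_states,
                                          scatter_dim=0,
                                          gather_dim=0)
```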

vllm_ascend/quantization/quant_config.py

Lines changed: 7 additions & 6 deletions
```diff
@@ -321,14 +321,15 @@ def apply(
         scoring_func: str = "softmax",
         e_score_correction_bias: Optional[torch.Tensor] = None,
         is_prefill: bool = True,
+        enable_force_load_balance: bool = False,
+        dp_size: int = 1,
         **kwargs,
     ) -> torch.Tensor:
-        return self.quant_method.apply(layer, x, router_logits, top_k,
-                                       renormalize, use_grouped_topk,
-                                       global_num_experts, expert_map,
-                                       topk_group, num_expert_group,
-                                       custom_routing_function, scoring_func,
-                                       e_score_correction_bias, is_prefill)
+        return self.quant_method.apply(
+            layer, x, router_logits, top_k, renormalize, use_grouped_topk,
+            global_num_experts, expert_map, topk_group, num_expert_group,
+            custom_routing_function, scoring_func, e_score_correction_bias,
+            is_prefill, enable_force_load_balance, dp_size)
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         if hasattr(self.quant_method, "process_weights_after_loading"):
```
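Downstream quantization methods are expected to accept the two new positional arguments. A hypothetical signature-only stub is sketched below; the real Ascend implementation lives in the quantization backend and is not part of this diff.

```python
import torch


class DummyMoEQuantMethod:
    """Signature-only stand-in for an Ascend MoE quant method."""

    def apply(self, layer, x, router_logits, top_k, renormalize,
              use_grouped_topk, global_num_experts, expert_map, topk_group,
              num_expert_group, custom_routing_function, scoring_func,
              e_score_correction_bias, is_prefill,
              enable_force_load_balance=False, dp_size=1,
              **kwargs) -> torch.Tensor:
        # no-op: echo the activations back unchanged
        return x
```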
