|
18 | 18 | from typing import Callable, Optional
|
19 | 19 |
|
20 | 20 | import torch
|
| 21 | +import torch.distributed as dist |
21 | 22 | import torch_npu
|
22 | 23 | from vllm.config import get_current_vllm_config
|
23 |
| -from vllm.distributed import (get_tensor_model_parallel_world_size, |
| 24 | +from vllm.distributed import (GroupCoordinator, |
| 25 | + get_tensor_model_parallel_world_size, |
24 | 26 | tensor_model_parallel_all_reduce)
|
25 | 27 | from vllm.distributed.parallel_state import get_dp_group
|
26 | 28 | from vllm.model_executor.layers.fused_moe.layer import (
|
@@ -154,6 +156,143 @@ def fused_experts_with_mc2(
|
154 | 156 | return hidden_states
|
155 | 157 |
|
156 | 158 |
|
| 159 | +# Currently, expert parallelism implemented with all-to-all |
| 160 | +# is under-optimized. |
| 161 | +def fused_experts_with_all2all( |
| 162 | + hidden_states: torch.Tensor, |
| 163 | + w1: torch.Tensor, |
| 164 | + w2: torch.Tensor, |
| 165 | + topk_weights: torch.Tensor, |
| 166 | + topk_ids: torch.Tensor, |
| 167 | + top_k: int, |
| 168 | + expert_map: Optional[torch.Tensor] = None, |
| 169 | + ep_group: Optional[GroupCoordinator] = None, |
| 170 | +): |
| 171 | + original_shape = hidden_states.shape |
| 172 | + if len(original_shape) == 3: |
| 173 | + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) |
| 174 | + |
| 175 | + num_tokens, _ = hidden_states.shape |
| 176 | + num_experts = w1.shape[0] |
| 177 | + device = hidden_states.device |
| 178 | + |
| 179 | + if expert_map is not None: |
| 180 | + global_num_experts = len(expert_map) |
| 181 | + local_num_experts = global_num_experts // ep_group.world_size |
| 182 | + row_idx_len = num_tokens * top_k |
| 183 | + row_idx = (torch.arange(0, |
| 184 | + row_idx_len, |
| 185 | + dtype=torch.int32, |
| 186 | + device=device).view(top_k, -1).permute( |
| 187 | + 1, 0).contiguous()) |
| 188 | + hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( |
| 189 | + hidden_states, |
| 190 | + row_idx=row_idx, |
| 191 | + expert_idx=topk_ids, |
| 192 | + active_num=num_tokens) |
| 193 | + |
| 194 | + global_expert_tokens = torch.bincount(expanded_expert_idx, |
| 195 | + minlength=global_num_experts) |
| 196 | + scatter_sizes = global_expert_tokens.view(ep_group.world_size, |
| 197 | + -1).sum(-1) |
| 198 | + |
| 199 | + gather_sizes = torch.empty_like(scatter_sizes) |
| 200 | + dist.all_to_all_single(gather_sizes, |
| 201 | + scatter_sizes, |
| 202 | + group=ep_group.device_group) |
| 203 | + scatter_size_list = scatter_sizes.cpu().tolist() |
| 204 | + gather_size_list = gather_sizes.cpu().tolist() |
| 205 | + |
| 206 | + expanded_expert_idx = expanded_expert_idx % local_num_experts |
| 207 | + hidden_states = ep_group.all_to_all(hidden_states, 0, 0, |
| 208 | + scatter_size_list, |
| 209 | + gather_size_list) |
| 210 | + local_expert_idx = ep_group.all_to_all(expanded_expert_idx, 0, 0, |
| 211 | + scatter_size_list, |
| 212 | + gather_size_list) |
| 213 | + |
| 214 | + sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx) |
| 215 | + |
| 216 | + expert_tokens = torch_npu.npu_moe_compute_expert_tokens( |
| 217 | + sorted_local_expert_idx, local_num_experts).to(torch.int64) |
| 218 | + |
| 219 | + hidden_states = hidden_states[sorted_idx] |
| 220 | + else: |
| 221 | + row_idx_len = num_tokens * top_k |
| 222 | + row_idx = torch.arange(0, |
| 223 | + row_idx_len, |
| 224 | + dtype=torch.int32, |
| 225 | + device=topk_weights.device).view( |
| 226 | + top_k, -1).permute(1, 0).contiguous() |
| 227 | + hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( |
| 228 | + hidden_states, |
| 229 | + row_idx=row_idx, |
| 230 | + expert_idx=topk_ids, |
| 231 | + active_num=num_tokens) |
| 232 | + |
| 233 | + expert_tokens = torch_npu.npu_moe_compute_expert_tokens( |
| 234 | + expanded_expert_idx, num_experts) |
| 235 | + expert_tokens = expert_tokens.to(torch.int64) |
| 236 | + |
| 237 | + w1 = w1.transpose(1, 2) |
| 238 | + gate_up_out_list = torch_npu.npu_grouped_matmul( |
| 239 | + x=[hidden_states], |
| 240 | + weight=[w1], |
| 241 | + split_item=2, |
| 242 | + group_list_type=0, |
| 243 | + group_type=0, |
| 244 | + group_list=expert_tokens, |
| 245 | + ) |
| 246 | + |
| 247 | + # TODO: Remove this cat in the future; npu_grouped_matmul returns a list of output tensors. |
| 248 | + hidden_states = torch.cat(gate_up_out_list, dim=0) |
| 249 | + hidden_states = torch_npu.npu_swiglu(hidden_states) |
| 250 | + |
| 251 | + w2 = w2.transpose(1, 2) |
| 252 | + down_out_list = torch_npu.npu_grouped_matmul( |
| 253 | + x=[hidden_states], |
| 254 | + weight=[w2], |
| 255 | + split_item=2, |
| 256 | + group_list_type=0, |
| 257 | + group_type=0, |
| 258 | + group_list=expert_tokens, |
| 259 | + ) |
| 260 | + |
| 261 | + hidden_states = torch.cat(down_out_list, dim=0) |
| 262 | + |
| 263 | + if expert_map is not None: |
| 264 | + resorted_idx = torch.argsort(sorted_idx) |
| 265 | + hidden_states = hidden_states[resorted_idx] |
| 266 | + hidden_states = ep_group.all_to_all(hidden_states, 0, 0, |
| 267 | + gather_size_list, |
| 268 | + scatter_size_list) |
| 269 | + |
| 270 | + final_hidden_states = torch_npu.npu_moe_finalize_routing( |
| 271 | + hidden_states, |
| 272 | + skip1=None, |
| 273 | + skip2=None, |
| 274 | + bias=None, |
| 275 | + scales=topk_weights, |
| 276 | + expanded_src_to_dst_row=expanded_row_idx, |
| 277 | + export_for_source_row=topk_ids, |
| 278 | + ) |
| 279 | + else: |
| 280 | + # TODO: This path reorders device memory twice; replace the current |
| 281 | + # implementation when suitable operators become available. |
| 282 | + final_hidden_states = torch_npu.npu_moe_finalize_routing( |
| 283 | + hidden_states, |
| 284 | + skip1=None, |
| 285 | + skip2=None, |
| 286 | + bias=None, |
| 287 | + scales=topk_weights, |
| 288 | + expanded_src_to_dst_row=expanded_row_idx, |
| 289 | + export_for_source_row=topk_ids, |
| 290 | + ) |
| 291 | + if len(original_shape) == 3: |
| 292 | + final_hidden_states = final_hidden_states.view(original_shape) |
| 293 | + return final_hidden_states |
| 294 | + |
| 295 | + |
157 | 296 | def fused_experts(
|
158 | 297 | hidden_states: torch.Tensor,
|
159 | 298 | w1: torch.Tensor,
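For intuition, two index tricks in the new `fused_experts_with_all2all` above carry most of the bookkeeping: per-expert token counts are folded into per-rank scatter sizes (`bincount` → `view(world_size, -1).sum(-1)`), and a sort/argsort round trip groups rows by expert id and later restores their original order. A self-contained toy run with made-up values (nothing below is part of the PR itself):

```python
import torch

# 8 global experts across an EP group of 2 ranks -> 4 local experts per rank.
world_size, global_num_experts = 2, 8
expanded_expert_idx = torch.tensor([0, 3, 5, 1, 6, 6, 2])

# Count routed rows per global expert, then fold to per-destination-rank sizes.
counts = torch.bincount(expanded_expert_idx, minlength=global_num_experts)
scatter_sizes = counts.view(world_size, -1).sum(-1)
print(scatter_sizes.tolist())  # [4, 3]: experts 0-3 live on rank 0, 4-7 on rank 1

# Sort rows by expert id for grouped compute, then undo the permutation later.
rows = torch.arange(7, dtype=torch.float32).unsqueeze(-1)  # stand-in hidden states
_, sorted_idx = torch.sort(expanded_expert_idx)
grouped = rows[sorted_idx]                # rows grouped by expert id
resorted_idx = torch.argsort(sorted_idx)  # inverse permutation
assert torch.equal(grouped[resorted_idx], rows)  # original order restored
```

This mirrors the `scatter_sizes` computation before the size exchange and the `sorted_idx`/`resorted_idx` pair that brackets the grouped matmuls in the function above.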
|
@@ -494,7 +633,7 @@ def apply(
|
494 | 633 | custom_routing_function: Optional[Callable] = None,
|
495 | 634 | scoring_func: str = "softmax",
|
496 | 635 | e_score_correction_bias: Optional[torch.Tensor] = None,
|
497 |
| - is_prefill=False, |
| 636 | + is_prefill: bool = False, |
498 | 637 | **kwargs,
|
499 | 638 | ):
|
500 | 639 | # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
|
@@ -536,14 +675,27 @@ def apply(
|
536 | 675 | top_k=top_k,
|
537 | 676 | expert_map=expert_map,
|
538 | 677 | moe_all_to_all_group_name=self.moe_all_to_all_group_name)
|
539 |
| - else: |
| 678 | + elif get_ep_group().world_size == 1: |
540 | 679 | return fused_experts(hidden_states=x,
|
541 | 680 | w1=layer.w13_weight,
|
542 | 681 | w2=layer.w2_weight,
|
543 | 682 | topk_weights=topk_weights,
|
544 | 683 | topk_ids=topk_ids,
|
545 | 684 | top_k=top_k,
|
546 | 685 | expert_map=expert_map)
|
| 686 | + else: |
| 687 | + # The current implementation of DeepSeek MoE splits hidden_states |
| 688 | + # according to tp_size before it is fed into the fused_moe module. |
| 689 | + # Therefore, all-to-all is needed regardless of how dp/tp is set, |
| 690 | + # in order to dispatch/combine tokens. |
| 691 | + return fused_experts_with_all2all(hidden_states=x, |
| 692 | + w1=layer.w13_weight, |
| 693 | + w2=layer.w2_weight, |
| 694 | + topk_weights=topk_weights, |
| 695 | + topk_ids=topk_ids, |
| 696 | + top_k=top_k, |
| 697 | + expert_map=expert_map, |
| 698 | + ep_group=get_ep_group()) |
547 | 699 |
|
548 | 700 |
|
549 | 701 | class AscendFusedMoE(FusedMoE):
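When the all2all branch above is taken, `fused_experts_with_all2all` performs a two-phase exchange: ranks first swap per-peer row counts, then swap the rows themselves with variable split sizes. A minimal sketch of that pattern in plain `torch.distributed` (the helper name is hypothetical and the PR routes this through `ep_group.all_to_all`; assumes an initialized process group):

```python
import torch
import torch.distributed as dist


def all_to_all_varsize(rows: torch.Tensor, scatter_sizes: torch.Tensor,
                       group=None) -> torch.Tensor:
    # Phase 1: every rank learns how many rows each peer will send it.
    gather_sizes = torch.empty_like(scatter_sizes)
    dist.all_to_all_single(gather_sizes, scatter_sizes, group=group)

    # Phase 2: exchange the rows with per-peer split sizes. `rows` must
    # already be grouped by destination rank, which the bincount/sort
    # bookkeeping in fused_experts_with_all2all ensures.
    input_splits = scatter_sizes.cpu().tolist()
    output_splits = gather_sizes.cpu().tolist()
    received = rows.new_empty((sum(output_splits), rows.shape[-1]))
    dist.all_to_all_single(received, rows,
                           output_split_sizes=output_splits,
                           input_split_sizes=input_splits,
                           group=group)
    return received
```

The same exchange runs twice per layer: once to dispatch rows to the ranks that own their experts, and once with the split lists swapped to combine results, which matches the `gather_size_list, scatter_size_list` argument order in the combine call of the function above.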
|
@@ -721,8 +873,7 @@ def forward(self,
|
721 | 873 | scoring_func=self.scoring_func,
|
722 | 874 | e_score_correction_bias=self.e_score_correction_bias,
|
723 | 875 | is_prefill=is_prefill,
|
724 |
| - enable_force_load_balance=enable_force_load_balance, |
725 |
| - dp_size=self.dp_size) |
| 876 | + enable_force_load_balance=enable_force_load_balance) |
726 | 877 |
|
727 | 878 | if VLLM_ENABLE_MC2 and not is_prefill:
|
728 | 879 | ...
|
|