Skip to content

Commit c6ac399

Browse files
authored
[Bugfix] Fix the method of importing environment variables in DeepSee… (vllm-project#817)
### What this PR does / why we need it? Fix the method of importing environment variables in DeepSeek model to support successful compilation via aclgraph. Signed-off-by: rjg-lyh <1318825571@qq.com>
1 parent 6193ba6 commit c6ac399

File tree

3 files changed

+17
-12
lines changed

3 files changed

+17
-12
lines changed

vllm_ascend/envs.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@
3434
lambda: os.getenv("CMAKE_BUILD_TYPE"),
3535
"COMPILE_CUSTOM_KERNELS":
3636
lambda: bool(int(os.getenv("COMPILE_CUSTOM_KERNELS", "1"))),
37+
"VLLM_ENABLE_MC2":
38+
lambda: bool(int(os.getenv("VLLM_ENABLE_MC2", '0'))),
39+
"USING_LCCL_COM":
40+
lambda: bool(int(os.getenv("USING_LCCL_COM", '0'))),
3741
"SOC_VERSION":
3842
lambda: os.getenv("SOC_VERSION", "ASCEND910B1"),
3943
# If set, vllm-ascend will print verbose logs during compilation

vllm_ascend/models/deepseek_v2.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
# # vllm-project/vllm/vllm/model_executor/models/deepseek_v2.py
2626
# """Inference-only DeepseekV2/DeepseekV3 model."""
2727

28-
import os
2928
from typing import Any, Dict, List, Optional, Union
3029

3130
import torch
@@ -66,9 +65,12 @@
6665
maybe_prefix)
6766
from vllm.sequence import IntermediateTensors
6867

68+
import vllm_ascend.envs as envs_ascend
6969
from vllm_ascend.ops.fused_moe import AscendFusedMoE
7070
from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
7171

72+
VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
73+
7274

7375
class CustomDeepseekV2MLP(nn.Module):
7476

@@ -206,7 +208,6 @@ def __init__(
206208
vllm_config = get_current_vllm_config()
207209
self.dp_size = get_dp_group().world_size
208210
batch_size = vllm_config.scheduler_config.max_num_seqs
209-
self.enable_mc2 = int(os.environ.get("VLLM_ENABLE_MC2", '0')) == 1
210211

211212
params_dtype = torch.get_default_dtype()
212213
self.final_hidden_states = torch.zeros(
@@ -223,7 +224,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
223224
num_tokens, hidden_dim = hidden_states.shape
224225
hidden_states = hidden_states.view(-1, hidden_dim)
225226

226-
if (self.tp_size > 1 and self.enable_mc2 and not is_prefill):
227+
if (self.tp_size > 1 and VLLM_ENABLE_MC2 and not is_prefill):
227228
chunks = torch.chunk(hidden_states,
228229
get_tp_group().world_size,
229230
dim=0)
@@ -239,7 +240,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
239240
top_k=CustomDeepseekV2MoE.top_k) * self.routed_scaling_factor
240241

241242
if self.tp_size > 1:
242-
if self.enable_mc2 and not is_prefill:
243+
if VLLM_ENABLE_MC2 and not is_prefill:
243244
dist.all_gather_into_tensor(self.final_hidden_states,
244245
final_hidden_states, self.tp_group)
245246
final_hidden_states = self.final_hidden_states

vllm_ascend/ops/fused_moe.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
# This file is a part of the vllm-ascend project.
1616
# Adapted from vllm/tests/kernels/test_moe.py
1717

18-
import os
1918
from typing import Callable, Optional
2019

2120
import torch
@@ -29,8 +28,12 @@
2928
from vllm.model_executor.layers.quantization.base_config import \
3029
QuantizeMethodBase
3130

31+
import vllm_ascend.envs as envs_ascend
3232
from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group
3333

34+
VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
35+
USING_LCCL_COM: bool = envs_ascend.USING_LCCL_COM
36+
3437

3538
def fused_experts_with_mc2(
3639
hidden_states: torch.Tensor,
@@ -493,7 +496,7 @@ def apply(
493496
e_score_correction_bias=e_score_correction_bias,
494497
)
495498

496-
if os.environ.get("VLLM_ENABLE_MC2", '0') == "1" and not is_prefill:
499+
if VLLM_ENABLE_MC2 and not is_prefill:
497500
return fused_experts_with_mc2(
498501
hidden_states=x,
499502
w1=layer.w13_weight,
@@ -624,11 +627,9 @@ def forward(self,
624627
real_top_k = self.top_k
625628

626629
if self.dp_size > 1:
627-
if int(os.environ.get("VLLM_ENABLE_MC2", '0') # type: ignore
628-
) == 1 and not is_prefill:
630+
if VLLM_ENABLE_MC2 and not is_prefill:
629631
...
630-
elif int(os.environ.get("USING_LCCL_COM",
631-
'0')) == 1: # type: ignore
632+
elif USING_LCCL_COM: # type: ignore
632633
hidden_states = get_dp_group().all_gather(
633634
hidden_states, 0, False)
634635
router_logits = get_dp_group().all_gather(
@@ -655,8 +656,7 @@ def forward(self,
655656
is_prefill=is_prefill)
656657

657658
if self.dp_size > 1:
658-
if int(os.environ.get("VLLM_ENABLE_MC2", '0') # type: ignore
659-
) == 1 and not is_prefill:
659+
if VLLM_ENABLE_MC2 and not is_prefill:
660660
...
661661
else:
662662
final_hidden_states = dist._functional_collectives.reduce_scatter_tensor(

0 commit comments

Comments
 (0)