
Commit 8c61b3a

Author: wangxiaoxin (A)
Commit message: add optimze of dsv3.
Parent: 01e3d59

6 files changed: +209 −3 lines


tests/sample/test_sampler.py

Lines changed: 103 additions & 0 deletions
@@ -0,0 +1,103 @@
+# Copyright 2023 The vLLM team.
+
+# Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved.
+# Adapted from
+# https://github.com/vllm-project/vllm/blob/main/vllm/tests/kernels/test_rotary_embedding.py
+
+from typing import Optional
+
+import torch
+
+from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p  # noqa: F401
+from vllm.v1.sample.sampler import Sampler  # noqa: F401
+
+# Set tolerance to 1 for quant ops
+DEFAULT_ATOL = 1e-3
+DEFAULT_RTOL = 1e-3
+
+
+def apply_min_p_new(
+    logits: torch.Tensor,
+    min_p: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Filters logits using adaptive probability thresholding.
+    """
+    if min_p == 0:
+        return logits
+    # Convert logits to probability distribution
+    probability_values = torch.nn.functional.softmax(logits, dim=-1)
+    # Calculate maximum probabilities per sequence
+    max_probabilities = torch.amax(probability_values, dim=-1, keepdim=True)
+    # Reshape min_p for broadcasting
+    adjusted_min_p = min_p.unsqueeze(1) * max_probabilities
+    # Identify valid tokens using threshold comparison
+    # Apply mask using boolean indexing
+    logits = logits.masked_fill(probability_values < adjusted_min_p,
+                                -float('inf'))
+    return logits
+
+
+def apply_top_k_top_p_new(
+    logits: torch.Tensor,
+    k: Optional[torch.Tensor],
+    p: Optional[torch.Tensor],
+) -> torch.Tensor:
+    batch_size, vocab_size = logits.shape
+    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
+
+    # Apply top-k.
+    boundary = logits_sort.gather(1, (vocab_size - k).unsqueeze(dim=1))
+    top_k_mask = logits_sort < boundary
+    logits_sort.masked_fill_(top_k_mask, -float("inf"))
+
+    if p is not None:
+        # Apply top-p.
+        cutoff = top_k_mask.sum(dim=-1).min()
+        probs_sort = logits_sort.softmax(dim=-1)[:, cutoff:]
+        probs_sum = probs_sort.cumsum(dim=-1)
+        top_p_mask = probs_sum > 1 - p.unsqueeze(dim=1)
+        top_p_mask[:, -1] = True
+        strides = torch.arange(0,
+                               batch_size * vocab_size,
+                               vocab_size,
+                               device=logits.device)
+        flatten_idx = logits_idx[:, cutoff:] + strides.unsqueeze(dim=1)
+        valid_idx = torch.masked_select(flatten_idx, top_p_mask)
+        logits_flatten = logits.flatten()
+        valid_logits = torch.index_select(logits_flatten, 0, valid_idx)
+        logits = torch.empty_like(logits_flatten).fill_(-float("inf"))
+        logits[valid_idx] = valid_logits
+    return logits.reshape(batch_size, vocab_size)
+
+
+# test with leading dimension and merge seqlen and batch_size as num_tokens
+@torch.inference_mode()
+def test_apply_min_p(
+) -> None:
+    logits = torch.randn((128, 7168)).npu()
+    min_p = torch.Tensor([0.01]).npu()
+    logits_new = apply_min_p_new(logits, min_p)
+    sampler = Sampler()
+    logits_old = sampler.apply_min_p(logits, min_p)
+    # Compare the results.
+    torch.testing.assert_close(logits_new,
+                               logits_old,
+                               atol=DEFAULT_ATOL,
+                               rtol=DEFAULT_RTOL)
+
+
+# test with leading dimension and merge seqlen and batch_size as num_tokens
+@torch.inference_mode()
+def test_apply_top_k_top_p(
+) -> None:
+    logits = torch.randn((128, 7168)).npu()
+    k = torch.Tensor([-1]).int().npu()
+    p = torch.Tensor([1]).int().npu()
+    logits_new = apply_top_k_top_p_new(logits, k, p)
+    logits_old = apply_top_k_top_p(logits, k, p)
+    # Compare the results.
+    torch.testing.assert_close(logits_new,
+                               logits_old,
+                               atol=DEFAULT_ATOL,
+                               rtol=DEFAULT_RTOL)
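
Note: the strided flat-index write in apply_top_k_top_p_new is the core of the rewrite. Rather than scattering the filtered logits back along dim=-1 of the sorted view, each row's surviving positions are offset by row * vocab_size and written through a single 1-D index assignment on the flattened tensor. A minimal CPU-only sketch of that idea follows (toy shapes and a hypothetical "keep" mask standing in for the combined top-k/top-p mask; not part of the commit). The tests themselves need an Ascend device, since the inputs are moved with .npu().

import torch

# Toy shapes: 2 rows, vocab of 4. "keep" stands in for the mask the real
# code derives from top-k and top-p filtering.
logits = torch.tensor([[0.1, 0.4, 0.2, 0.3],
                       [0.9, 0.0, 0.5, 0.6]])
batch_size, vocab_size = logits.shape
keep = torch.tensor([[False, True, False, True],
                     [True, False, False, True]])
# Offset each row's column indices by row * vocab_size so they address the
# flattened tensor, then keep only the surviving positions.
strides = torch.arange(0, batch_size * vocab_size, vocab_size).unsqueeze(1)
cols = torch.arange(vocab_size).expand(batch_size, vocab_size)
valid_idx = torch.masked_select(cols + strides, keep)
out = torch.full_like(logits.flatten(), -float("inf"))
out[valid_idx] = logits.flatten()[valid_idx]
out = out.reshape(batch_size, vocab_size)  # surviving logits kept, the rest -inf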

vllm_ascend/envs.py

Lines changed: 2 additions & 0 deletions
@@ -36,6 +36,8 @@
     lambda: bool(int(os.getenv("COMPILE_CUSTOM_KERNELS", "1"))),
     "VLLM_ENABLE_MC2":
     lambda: bool(int(os.getenv("VLLM_ENABLE_MC2", '0'))),
+    "VLLM_ENABLE_TOPK_OPTIMZE":
+    lambda: bool(int(os.getenv("VLLM_ENABLE_TOPK_OPTIMZE", '0'))),
     "USING_LCCL_COM":
     lambda: bool(int(os.getenv("USING_LCCL_COM", '0'))),
     "SOC_VERSION":

vllm_ascend/models/deepseek_v2.py

Lines changed: 4 additions & 2 deletions
@@ -224,8 +224,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         enable_force_load_balance = False
         num_tokens, hidden_dim = hidden_states.shape

-        if self.n_shared_experts is not None:
-            shared_output = self.shared_experts(hidden_states)
+        old_hidden_states = hidden_states.detach()

         if self.tp_size > 1:
             # pass
@@ -264,6 +263,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         else:
             final_hidden_states = router_hidden_states

+        if self.n_shared_experts is not None:
+            shared_output = self.shared_experts(old_hidden_states)
+
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output

vllm_ascend/ops/fused_moe.py

Lines changed: 1 addition & 1 deletion
@@ -371,7 +371,7 @@ def fused_experts(
                          num_experts)).to(topk_ids.dtype)

     # Sort by local expert IDs
-    sort_indices = torch.argsort(filtered_experts)
+    sort_indices = torch.argsort(filtered_experts.view(torch.float32))
     sorted_token_indices = token_indices[sort_indices]
     sorted_weights = filtered_weights[sort_indices]
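
Note: this one-line change reinterprets the expert-ID tensor's bits as float32 before argsort; the diff does not state the motivation, presumably so the sort dispatches to a floating-point kernel on the NPU. For non-negative int32 IDs the reinterpretation preserves ordering under IEEE-754 comparison, which a quick CPU check illustrates (illustrative only, not part of the commit):

import torch

# Non-negative int32 bit patterns map to zero or tiny positive float32
# values, and their relative order is unchanged.
ids = torch.tensor([3, 0, 7, 2, 5], dtype=torch.int32)
assert torch.equal(torch.argsort(ids), torch.argsort(ids.view(torch.float32)))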

vllm_ascend/patch/worker/patch_common/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -24,3 +24,4 @@
 import vllm_ascend.patch.worker.patch_common.patch_minicpm  # noqa
 import vllm_ascend.patch.worker.patch_common.patch_multi_step_worker  # noqa
 import vllm_ascend.patch.worker.patch_common.patch_spec_decode_worker  # noqa
+import vllm_ascend.patch.worker.patch_common.patch_sampler  # noqa
vllm_ascend/patch/worker/patch_common/patch_sampler.py

Lines changed: 98 additions & 0 deletions
@@ -0,0 +1,98 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# This file is a part of the vllm-ascend project.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import torch
+from vllm.v1.sample.sampler import Sampler
+from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler, random_sample
+from vllm_ascend import envs
+from typing import Optional
+
+
+def apply_min_p(
+    self,
+    logits: torch.Tensor,
+    min_p: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Filters logits using adaptive probability thresholding.
+    """
+    # Convert logits to probability distribution
+    probability_values = torch.nn.functional.softmax(logits, dim=-1)
+    # Calculate maximum probabilities per sequence
+    max_probabilities = torch.amax(probability_values, dim=-1, keepdim=True)
+    # Reshape min_p for broadcasting
+    adjusted_min_p = min_p.unsqueeze(1) * max_probabilities
+    # Identify valid tokens using threshold comparison
+    # Apply mask using boolean indexing
+    logits = logits.masked_fill(probability_values < adjusted_min_p,
+                                -float('inf'))
+    return logits
+
+
+def _apply_top_k_top_p(
+    logits: torch.Tensor,
+    p: torch.Tensor,
+    k: torch.Tensor,
+) -> torch.Tensor:
+    batch_size, vocab_size = logits.shape
+    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
+
+    # Apply top-k.
+    boundary = logits_sort.gather(1, (vocab_size - k).unsqueeze(dim=1))
+    top_k_mask = logits_sort < boundary
+    logits_sort.masked_fill_(top_k_mask, -float("inf"))
+
+    # Apply top-p.
+    cutoff = top_k_mask.sum(dim=-1).min()
+    probs_sort = logits_sort.softmax(dim=-1)[:, cutoff:]
+    probs_sum = probs_sort.cumsum(dim=-1)
+    top_p_mask = probs_sum > 1 - p.unsqueeze(dim=1)
+
+    top_p_mask[:, -1] = True
+    strides = torch.arange(0,
+                           batch_size * vocab_size,
+                           vocab_size,
+                           device=logits.device)
+    flatten_idx = logits_idx[:, cutoff:] + strides.unsqueeze(dim=1)
+    valid_idx = torch.masked_select(flatten_idx, top_p_mask)
+    logits_flatten = logits.flatten()
+    valid_logits = torch.index_select(logits_flatten, 0, valid_idx)
+    logits = torch.empty_like(logits_flatten).fill_(-float("inf"))
+    logits[valid_idx] = valid_logits
+    return logits.reshape(batch_size, vocab_size)
+
+
+def topk_topp_forward_native(
+    self,
+    logits: torch.Tensor,
+    generators: dict[int, torch.Generator],
+    k: Optional[torch.Tensor],
+    p: Optional[torch.Tensor],
+) -> torch.Tensor:
+    """
+    PyTorch-native implementation of top-k and top-p sampling.
+
+    The logits tensor may be updated in-place.
+    """
+    logits = _apply_top_k_top_p(logits, k, p)
+    probs = logits.softmax(dim=-1, dtype=torch.float32)
+    return random_sample(probs, generators)
+
+
+Sampler.apply_min_p = apply_min_p
+if envs.VLLM_ENABLE_TOPK_OPTIMZE:
+    TopKTopPSampler.forward_native = topk_topp_forward_native
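
Note: the patch is applied as a side effect of import: loading this module always rebinds Sampler.apply_min_p, while TopKTopPSampler.forward_native is swapped only when VLLM_ENABLE_TOPK_OPTIMZE is truthy at import time. A minimal sketch of checking that (hypothetical standalone snippet, assuming a working vllm-ascend install):

import os

os.environ["VLLM_ENABLE_TOPK_OPTIMZE"] = "1"  # set before the patch module is imported

from vllm.v1.sample.sampler import Sampler
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
import vllm_ascend.patch.worker.patch_common.patch_sampler as patch_sampler  # applies the patches on import

assert Sampler.apply_min_p is patch_sampler.apply_min_p
assert TopKTopPSampler.forward_native is patch_sampler.topk_topp_forward_native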
