
Commit 0eef7e3

Author: wangxiaoxin (A)
Commit message: xx
1 parent 25f502e

File tree

tests/sample/test_sampler.py
vllm_ascend/patch/worker/patch_common/patch_sampler.py

2 files changed: +17 -14 lines changed


tests/sample/test_sampler.py

Lines changed: 16 additions & 13 deletions
@@ -6,16 +6,16 @@
 
 from typing import Optional
 
-import pytest
 import torch
 
-from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p  # noqa: F401
-from vllm.v1.sample.sampler import Sampler  # noqa: F401
+from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p  # noqa: F401
+from vllm.v1.sample.sampler import Sampler  # noqa: F401
 
 # Set tolerance to 1 for quant ops
 DEFAULT_ATOL = 1e-3
 DEFAULT_RTOL = 1e-3
 
+
 def apply_min_p_new(
     logits: torch.Tensor,
     min_p: torch.Tensor,
@@ -28,14 +28,13 @@ def apply_min_p_new(
     # Convert logits to probability distribution
     probability_values = torch.nn.functional.softmax(logits, dim=-1)
     # Calculate maximum probabilities per sequence
-    max_probabilities = torch.amax(probability_values,
-                                   dim=-1,
-                                   keepdim=True)
+    max_probabilities = torch.amax(probability_values, dim=-1, keepdim=True)
     # Reshape min_p for broadcasting
     adjusted_min_p = min_p.unsqueeze(1) * max_probabilities
     # Identify valid tokens using threshold comparison
     # Apply mask using boolean indexing
-    logits = logits.masked_fill(probability_values < adjusted_min_p, -float('inf'))
+    logits = logits.masked_fill(probability_values < adjusted_min_p,
+                                -float('inf'))
     return logits
 
 
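The min-p rule this hunk reformats keeps only tokens whose probability is at least min_p times the row's maximum probability; everything else is masked to -inf. A minimal CPU sketch of the same computation, with made-up values (not part of the commit):

import torch

# Hypothetical 1 x 4 logits; softmax gives roughly [0.64, 0.24, 0.09, 0.03].
logits = torch.tensor([[3.0, 2.0, 1.0, 0.0]])
min_p = torch.tensor([0.2])

probs = torch.nn.functional.softmax(logits, dim=-1)
# Per-row threshold: min_p * max probability (about 0.2 * 0.64 = 0.13 here).
threshold = min_p.unsqueeze(1) * torch.amax(probs, dim=-1, keepdim=True)
masked = logits.masked_fill(probs < threshold, -float('inf'))
print(masked)  # tensor([[3., 2., -inf, -inf]])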

@@ -46,21 +45,23 @@ def apply_top_k_top_p_new(
 ) -> torch.Tensor:
     batch_size, vocab_size = logits.shape
     logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
-
+
     # Apply top-k.
     boundary = logits_sort.gather(1, (vocab_size - k).unsqueeze(dim=1))
     top_k_mask = logits_sort < boundary
     logits_sort.masked_fill_(top_k_mask, -float("inf"))
-
-
+
     if p is not None:
         # Apply top-p.
         cutoff = top_k_mask.sum(dim=-1).min()
         probs_sort = logits_sort.softmax(dim=-1)[:, cutoff:]
         probs_sum = probs_sort.cumsum(dim=-1)
         top_p_mask = probs_sum > 1 - p.unsqueeze(dim=1)
         top_p_mask[:, -1] = True
-        strides = torch.arange(0, batch_size*vocab_size, vocab_size, device=logits.device)
+        strides = torch.arange(0,
+                               batch_size*vocab_size,
+                               vocab_size,
+                               device=logits.device)
         flatten_idx = logits_idx[:, cutoff:] + strides.unsqueeze(dim=1)
         valid_idx = torch.masked_select(flatten_idx, top_p_mask)
         logits_flatten = logits.flatten()
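
The top-k step above sorts each row in ascending order, gathers the value at index vocab_size - k as the per-row boundary, and masks everything below it. A small CPU sketch of just that step, with illustrative shapes:

import torch

logits = torch.tensor([[0.5, 2.0, -1.0, 3.0],
                       [1.0, 0.0, 4.0, 2.0]])
k = torch.tensor([2, 3])  # per-row k (int64, as gather requires)
batch_size, vocab_size = logits.shape

logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
# In ascending order, the k-th largest value sits at index vocab_size - k.
boundary = logits_sort.gather(1, (vocab_size - k).unsqueeze(dim=1))
logits_sort.masked_fill_(logits_sort < boundary, -float("inf"))
# Row 0 keeps its top-2 (2.0, 3.0); row 1 keeps its top-3 (1.0, 2.0, 4.0).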
@@ -69,11 +70,12 @@ def apply_top_k_top_p_new(
         logits[valid_idx] = valid_logits
     return logits.reshape(batch_size, vocab_size)
 
+
 # test with leading dimension and merge seqlen and batch_size as num_tokens
 @torch.inference_mode()
 def test_apply_min_p(
 ) -> None:
-    logits = torch.randn((128,7168)).npu()
+    logits = torch.randn((128, 7168)).npu()
     min_p = torch.Tensor([0.01]).npu()
     logits_new = apply_min_p_new(logits, min_p)
     sampler = Sampler()
@@ -84,11 +86,12 @@ def test_apply_min_p(
                           atol=DEFAULT_ATOL,
                           rtol=DEFAULT_RTOL)
 
+
 # test with leading dimension and merge seqlen and batch_size as num_tokens
 @torch.inference_mode()
 def test_apply_top_k_top_p(
 ) -> None:
-    logits = torch.randn((128,7168)).npu()
+    logits = torch.randn((128, 7168)).npu()
     k = torch.Tensor([-1]).int().npu()
     p = torch.Tensor([1]).int().npu()
     logits_new = apply_top_k_top_p_new(logits, k, p)
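
The tests pin tensors to the Ascend NPU via .npu(). Assuming apply_min_p_new and apply_top_k_top_p_new from this file are importable, the reference implementations also run on plain CPU tensors, which gives a quick sanity check without the hardware (a hypothetical smoke test, not part of the commit; k is chosen so that vocab_size - k is a valid index):

import torch

logits = torch.randn((128, 7168))
out_min_p = apply_min_p_new(logits.clone(), torch.tensor([0.01]))

k = torch.tensor([50])  # keep top-50 per row; int64 so gather accepts it
p = torch.tensor([0.9])
out_top_kp = apply_top_k_top_p_new(logits.clone(), k, p)
assert out_min_p.shape == out_top_kp.shape == (128, 7168)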

vllm_ascend/patch/worker/patch_common/patch_sampler.py

Lines changed: 1 addition & 1 deletion
@@ -61,7 +61,7 @@ def _apply_top_k_top_p(
         probs_sort = logits_sort.softmax(dim=-1)[:, cutoff:]
         probs_sum = probs_sort.cumsum(dim=-1)
         top_p_mask = probs_sum > 1 - p.unsqueeze(dim=1)
-
+
         top_p_mask[:, -1] = True
         strides = torch.arange(0,
                                batch_size * vocab_size,
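
For context, _apply_top_k_top_p reads the surviving sorted entries back out of logits.flatten(): adding row * vocab_size to each per-row index turns it into an index into the flattened tensor, which is exactly what the strides tensor in this hunk provides. A self-contained sketch of the indexing trick (illustrative shapes only):

import torch

batch_size, vocab_size = 2, 5
logits = torch.arange(10, dtype=torch.float32).reshape(batch_size, vocab_size)

# Per-row vocabulary indices we want to read from the flattened tensor.
keep_idx = torch.tensor([[1, 3],
                         [0, 4]])

# Offset row r by r * vocab_size, as the strides tensor above does.
strides = torch.arange(0, batch_size * vocab_size, vocab_size)
flatten_idx = keep_idx + strides.unsqueeze(dim=1)  # [[1, 3], [5, 9]]

kept = logits.flatten()[flatten_idx.flatten()]
print(kept)  # tensor([1., 3., 5., 9.])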
