@@ -6,16 +6,16 @@
 
 from typing import Optional
 
-import pytest
 import torch
 
-from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p  # noqa: F401
-from vllm.v1.sample.sampler import Sampler  # noqa: F401
+from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p  # noqa: F401
+from vllm.v1.sample.sampler import Sampler  # noqa: F401
 
 # Set tolerance to 1 for quant ops
 DEFAULT_ATOL = 1e-3
 DEFAULT_RTOL = 1e-3
 
+
 def apply_min_p_new(
     logits: torch.Tensor,
     min_p: torch.Tensor,
@@ -28,14 +28,13 @@ def apply_min_p_new(
     # Convert logits to probability distribution
     probability_values = torch.nn.functional.softmax(logits, dim=-1)
     # Calculate maximum probabilities per sequence
-    max_probabilities = torch.amax(probability_values,
-                                   dim=-1,
-                                   keepdim=True)
+    max_probabilities = torch.amax(probability_values, dim=-1, keepdim=True)
     # Reshape min_p for broadcasting
     adjusted_min_p = min_p.unsqueeze(1) * max_probabilities
     # Identify valid tokens using threshold comparison
     # Apply mask using boolean indexing
-    logits = logits.masked_fill(probability_values < adjusted_min_p, -float('inf'))
+    logits = logits.masked_fill(probability_values < adjusted_min_p,
+                                -float('inf'))
     return logits
 
 
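For reference, `apply_min_p_new` keeps only the tokens whose probability reaches at least `min_p` times the per-row maximum and masks everything else to `-inf`. A minimal CPU sketch of the same steps, assuming plain `torch` with no NPU device (the tensor values below are illustrative, not taken from the tests):

```python
import torch

# Illustrative inputs (hypothetical values, not from the test file)
logits = torch.tensor([[2.0, 1.0, -3.0]])
min_p = torch.tensor([0.2])

# Same steps as apply_min_p_new: softmax, per-row max, scaled threshold, mask
probs = torch.softmax(logits, dim=-1)  # ~[0.73, 0.27, 0.005]
threshold = min_p.unsqueeze(1) * probs.amax(dim=-1, keepdim=True)  # ~0.145
filtered = logits.masked_fill(probs < threshold, -float("inf"))
print(filtered)  # only the third token falls below the threshold and is masked
```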
@@ -46,21 +45,23 @@ def apply_top_k_top_p_new(
 ) -> torch.Tensor:
     batch_size, vocab_size = logits.shape
     logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
-
+
     # Apply top-k.
     boundary = logits_sort.gather(1, (vocab_size - k).unsqueeze(dim=1))
     top_k_mask = logits_sort < boundary
     logits_sort.masked_fill_(top_k_mask, -float("inf"))
-
-
+
     if p is not None:
         # Apply top-p.
         cutoff = top_k_mask.sum(dim=-1).min()
         probs_sort = logits_sort.softmax(dim=-1)[:, cutoff:]
         probs_sum = probs_sort.cumsum(dim=-1)
         top_p_mask = probs_sum > 1 - p.unsqueeze(dim=1)
         top_p_mask[:, -1] = True
-        strides = torch.arange(0, batch_size * vocab_size, vocab_size, device=logits.device)
+        strides = torch.arange(0,
+                               batch_size * vocab_size,
+                               vocab_size,
+                               device=logits.device)
         flatten_idx = logits_idx[:, cutoff:] + strides.unsqueeze(dim=1)
         valid_idx = torch.masked_select(flatten_idx, top_p_mask)
         logits_flatten = logits.flatten()
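The top-p branch above works on ascending-sorted logits: after the top-k mask, it keeps the suffix of tokens whose running probability mass exceeds `1 - p`, always retaining the highest-probability token. A small CPU sketch of that cumulative mask, assuming plain `torch` (the probabilities are illustrative):

```python
import torch

# One row of ascending-sorted probabilities (hypothetical example values)
probs_sort = torch.tensor([[0.05, 0.10, 0.35, 0.50]])
p = torch.tensor([0.8])

# Keep tokens once the cumulative sum crosses 1 - p = 0.2
probs_sum = probs_sort.cumsum(dim=-1)            # [0.05, 0.15, 0.50, 1.00]
top_p_mask = probs_sum > 1 - p.unsqueeze(dim=1)  # [False, False, True, True]
top_p_mask[:, -1] = True  # the top token always survives
print(top_p_mask)         # the two lowest-probability tokens are dropped
```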
@@ -69,11 +70,12 @@ def apply_top_k_top_p_new(
         logits[valid_idx] = valid_logits
     return logits.reshape(batch_size, vocab_size)
 
+
 # test with leading dimension and merge seqlen and batch_size as num_tokens
 @torch.inference_mode()
 def test_apply_min_p(
 ) -> None:
-    logits = torch.randn((128,7168)).npu()
+    logits = torch.randn((128, 7168)).npu()
     min_p = torch.Tensor([0.01]).npu()
     logits_new = apply_min_p_new(logits, min_p)
     sampler = Sampler()
@@ -84,11 +86,12 @@ def test_apply_min_p(
                           atol=DEFAULT_ATOL,
                           rtol=DEFAULT_RTOL)
 
+
 # test with leading dimension and merge seqlen and batch_size as num_tokens
 @torch.inference_mode()
 def test_apply_top_k_top_p(
 ) -> None:
-    logits = torch.randn((128,7168)).npu()
+    logits = torch.randn((128, 7168)).npu()
     k = torch.Tensor([-1]).int().npu()
     p = torch.Tensor([1]).int().npu()
     logits_new = apply_top_k_top_p_new(logits, k, p)