
Commit c6798ba

Change top_k to be disabled with 0 (still accept -1 for now) (#17773)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
1 parent 5b2dcbf commit c6798ba

File tree

6 files changed: +14 additions, -13 deletions

tests/samplers/test_sampler.py

Lines changed: 1 addition & 1 deletion
@@ -478,7 +478,7 @@ def test_sampler_mixed(seed: int, device: str):
         sampling_params = SamplingParams(
             temperature=random.random() + 0.1,
             top_p=min(random.random() + 0.1, 1),
-            top_k=random.randint(0, 10) or -1,
+            top_k=random.randint(0, 10),
             n=n,
             presence_penalty=random.randint(0, 1),
         )
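Why the `or -1` dance goes away: before this commit, top_k=0 failed validation, so the test had to remap a random 0 to -1. With 0 now meaning "disabled", the raw randint value is always valid. A minimal sketch (the loop is illustrative, not from the test):

import random

from vllm import SamplingParams

for _ in range(5):
    k = random.randint(0, 10)  # 0 now means "consider all tokens"
    params = SamplingParams(temperature=1.0, top_k=k)
    print(params.top_k)  # 0..10, no remapping to -1 required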

vllm/entrypoints/openai/protocol.py

Lines changed: 3 additions & 3 deletions
@@ -409,7 +409,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
         "repetition_penalty": 1.0,
         "temperature": 1.0,
         "top_p": 1.0,
-        "top_k": -1,
+        "top_k": 0,
         "min_p": 0.0,
     }

@@ -853,7 +853,7 @@ class CompletionRequest(OpenAIBaseModel):
         "repetition_penalty": 1.0,
         "temperature": 1.0,
         "top_p": 1.0,
-        "top_k": -1,
+        "top_k": 0,
         "min_p": 0.0,
     }

@@ -1679,7 +1679,7 @@ class TranscriptionRequest(OpenAIBaseModel):
         "repetition_penalty": 1.0,
         "temperature": 1.0,
         "top_p": 1.0,
-        "top_k": -1,
+        "top_k": 0,
         "min_p": 0.0,
     }
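For API clients, the change is that the server's default sampling parameters now advertise top_k as 0, so omitting top_k and sending 0 mean the same thing: no top-k filtering. A hedged sketch of a request against the OpenAI-compatible endpoint (assumes a vLLM server on localhost:8000; the model name is a placeholder):

import requests

payload = {
    "model": "my-model",  # hypothetical model name
    "prompt": "Hello",
    "top_k": 0,           # explicitly disabled; equivalent to omitting it
}
resp = requests.post("http://localhost:8000/v1/completions", json=payload)
print(resp.json())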

vllm/model_executor/sampling_metadata.py

Lines changed: 1 addition & 1 deletion
@@ -416,7 +416,7 @@ def from_sampling_metadata(

         # k should not be greater than the vocab size.
         top_k = min(sampling_params.top_k, vocab_size)
-        top_k = vocab_size if top_k == -1 else top_k
+        top_k = vocab_size if top_k < 1 else top_k
         if temperature < _SAMPLING_EPS:
             # NOTE: Zero temperature means deterministic sampling
             # (i.e., greedy sampling or beam search).
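The `< 1` comparison is what lets both spellings coexist: any non-positive top_k (0 or the legacy -1) expands to the full vocabulary, making top-k filtering a no-op. A standalone sketch of the normalization above:

def normalize_top_k(top_k: int, vocab_size: int) -> int:
    # k should not be greater than the vocab size.
    top_k = min(top_k, vocab_size)
    # Any value below 1 means "disabled": use the full vocabulary.
    return vocab_size if top_k < 1 else top_k

assert normalize_top_k(0, 32000) == 32000    # new "disabled" spelling
assert normalize_top_k(-1, 32000) == 32000   # legacy spelling still works
assert normalize_top_k(40, 32000) == 40      # ordinary top-k unchanged
assert normalize_top_k(10**9, 32000) == 32000  # clamped to vocab size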

vllm/sampling_params.py

Lines changed: 7 additions & 6 deletions
@@ -149,7 +149,7 @@ class SamplingParams(
     top_p: Float that controls the cumulative probability of the top tokens
         to consider. Must be in (0, 1]. Set to 1 to consider all tokens.
     top_k: Integer that controls the number of top tokens to consider. Set
-        to -1 to consider all tokens.
+        to 0 (or -1) to consider all tokens.
     min_p: Float that represents the minimum probability for a token to be
         considered, relative to the probability of the most likely token.
         Must be in [0, 1]. Set to 0 to disable this.

@@ -209,7 +209,7 @@ class SamplingParams(
     repetition_penalty: float = 1.0
     temperature: float = 1.0
     top_p: float = 1.0
-    top_k: int = -1
+    top_k: int = 0
     min_p: float = 0.0
     seed: Optional[int] = None
     stop: Optional[Union[str, list[str]]] = None

@@ -256,7 +256,7 @@ def from_optional(
         repetition_penalty: Optional[float] = 1.0,
         temperature: Optional[float] = 1.0,
         top_p: Optional[float] = 1.0,
-        top_k: int = -1,
+        top_k: int = 0,
         min_p: float = 0.0,
         seed: Optional[int] = None,
         stop: Optional[Union[str, list[str]]] = None,

@@ -376,7 +376,7 @@ def __post_init__(self) -> None:
         if self.temperature < _SAMPLING_EPS:
             # Zero temperature means greedy sampling.
             self.top_p = 1.0
-            self.top_k = -1
+            self.top_k = 0
             self.min_p = 0.0
             self._verify_greedy_sampling()

@@ -404,8 +404,9 @@ def _verify_args(self) -> None:
                 f"temperature must be non-negative, got {self.temperature}.")
         if not 0.0 < self.top_p <= 1.0:
             raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.")
-        if self.top_k < -1 or self.top_k == 0:
-            raise ValueError(f"top_k must be -1 (disable), or at least 1, "
+        # quietly accept -1 as disabled, but prefer 0
+        if self.top_k < -1:
+            raise ValueError(f"top_k must be 0 (disable), or at least 1, "
                              f"got {self.top_k}.")
         if not isinstance(self.top_k, int):
             raise TypeError(
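The validator now treats -1 as quietly accepted for backwards compatibility while 0 becomes the documented "disabled" value; only values below -1 are rejected. A small sketch of the resulting accept/reject matrix (a standalone function mirroring _verify_args above, for illustration):

def check_top_k(top_k: int) -> None:
    # -1 is quietly accepted as "disabled"; 0 is the preferred spelling.
    # Anything below -1 has no meaning and is an error.
    if top_k < -1:
        raise ValueError(f"top_k must be 0 (disable), or at least 1, "
                         f"got {top_k}.")

for k in (0, -1, 1, 50):
    check_top_k(k)   # all accepted
try:
    check_top_k(-2)  # rejected
except ValueError as e:
    print(e)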

vllm/worker/neuron_model_runner.py

Lines changed: 1 addition & 1 deletion
@@ -348,7 +348,7 @@ def _convert_to_neuron_sampling_params(
         if temperature == 0.0:
             # Enable greedy sampling on zero temperature
             return (1, 1.0, 1.0)
-        if top_k < 0 or top_k > self._MAX_NEURON_SAMPLING_TOP_K:
+        if top_k < 1 or top_k > self._MAX_NEURON_SAMPLING_TOP_K:
             top_k = self._MAX_NEURON_SAMPLING_TOP_K

         return (top_k, top_p, temperature)
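On Neuron, a "disabled" top_k cannot be passed through directly, so it is clamped to the backend's maximum; the `< 1` check makes both 0 and -1 fall into that branch. A hedged sketch as a pure function (the 256 constant is a made-up placeholder, not the real value of self._MAX_NEURON_SAMPLING_TOP_K):

MAX_TOP_K = 256  # hypothetical stand-in for self._MAX_NEURON_SAMPLING_TOP_K

def convert_top_k(top_k: int, top_p: float, temperature: float):
    if temperature == 0.0:
        # Greedy sampling: only the single best token matters.
        return (1, 1.0, 1.0)
    if top_k < 1 or top_k > MAX_TOP_K:
        # "Disabled" (0 or -1) and oversized values clamp to the backend max.
        top_k = MAX_TOP_K
    return (top_k, top_p, temperature)

print(convert_top_k(0, 0.9, 1.0))   # (256, 0.9, 1.0)
print(convert_top_k(-1, 0.9, 1.0))  # (256, 0.9, 1.0)
print(convert_top_k(40, 0.9, 1.0))  # (40, 0.9, 1.0)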

vllm/worker/tpu_model_runner.py

Lines changed: 1 addition & 1 deletion
@@ -525,7 +525,7 @@ def _prepare_sample(
                     "Top-p sampling is currently disabled for the TPU backend "
                     "due to performance issues.")
             p.append(sampling_params.top_p)
-            if sampling_params.top_k != -1:
+            if sampling_params.top_k > 0:
                 raise NotImplementedError(
                     "Top-k sampling is currently disabled for the TPU backend "
                     "due to performance issues.")
