Skip to content

Commit b3c5d05

Browse files
committed
change input
1 parent ac5520c commit b3c5d05

File tree

3 files changed

+5
-5
lines changed

3 files changed

+5
-5
lines changed

vllm/engine/async_llm_engine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1058,7 +1058,7 @@ async def beam_search(
10581058
tokenizedLength = len(tokenizedPrompt)
10591059

10601060
sort_beams_key = create_sort_beams_key_function(
1061-
tokenizer, length_penalty=length_penalty)
1061+
tokenizer.eos_token_id, length_penalty)
10621062

10631063
beam_search_params = SamplingParams(logprobs=2 * beam_width,
10641064
max_tokens=1,

vllm/engine/multiprocessing/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -463,7 +463,7 @@ async def beam_search(
463463
tokenizedLength = len(tokenizedPrompt)
464464

465465
sort_beams_key = create_sort_beams_key_function(
466-
tokenizer, length_penalty=length_penalty)
466+
tokenizer.eos_token_id, length_penalty)
467467

468468
beam_search_params = SamplingParams(logprobs=2 * beam_width,
469469
max_tokens=1,

vllm/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1414,10 +1414,10 @@ def get_beam_search_score(
14141414
return cumulative_logprob / (seq_len**length_penalty)
14151415

14161416

1417-
def create_sort_beams_key_function(tokenizer, length_penalty):
1417+
def create_sort_beams_key_function(eos_token_id: int, length_penalty):
14181418

14191419
def sort_beams_key(x: BeamSearchSequence) -> float:
1420-
return get_beam_search_score(x.tokens, x.cum_logprob,
1421-
tokenizer.eos_token_id, length_penalty)
1420+
return get_beam_search_score(x.tokens, x.cum_logprob, eos_token_id,
1421+
length_penalty)
14221422

14231423
return sort_beams_key

0 commit comments

Comments
 (0)