1 file changed: +3 −1

@@ -396,6 +396,7 @@ def beam_search(
         beam_width: int,
         max_tokens: int,
         ignore_eos: bool = False,
+        temperature: float = 0.0,
     ) -> List[BeamSearchOutput]:
         """
         Generate sequences using beam search.
@@ -405,6 +406,7 @@ def beam_search(
             of token IDs.
             beam_width: The number of beams to keep at each step.
             max_tokens: The max number of tokens to generate for each prompt.
+            temperature: The temperature to use for generation.

         TODO: how does beam search work together with length penalty, frequency
         penalty, and stopping criteria, etc.?
@@ -416,7 +418,7 @@ def beam_search(
         # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
         beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                             max_tokens=1,
-                                            temperature=0.0)
+                                            temperature=temperature)
         instances: List[BeamSearchInstance] = []

         for prompt in prompts:
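
For context, a minimal usage sketch of the new parameter (not part of the diff): it assumes `llm` is an instance of the class exposing this `beam_search` method and that prompts are passed as lists of token IDs, as the docstring suggests; the token values and the way the outputs are consumed are illustrative only.

    # Hypothetical usage sketch; `llm`, the prompt token IDs, and the output
    # handling below are assumptions for illustration, not taken from the PR.
    prompts = [[101, 2023, 2003, 1037, 3231]]  # one prompt as a list of token IDs

    outputs = llm.beam_search(
        prompts,
        beam_width=4,
        max_tokens=16,
        temperature=0.7,  # new: a non-zero temperature now reaches SamplingParams
    )

    # With the default temperature=0.0 the behaviour matches the previously
    # hard-coded greedy expansion of each beam.
    for output in outputs:
        print(output)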