@@ -23,6 +23,7 @@ struct whisper_params {
23
23
int32_t capture_id = -1 ;
24
24
int32_t max_tokens = 32 ;
25
25
int32_t audio_ctx = 0 ;
26
+ int32_t beam_size = -1 ;
26
27
27
28
float vad_thold = 0 .6f ;
28
29
float freq_thold = 100 .0f ;
@@ -59,6 +60,7 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params
59
60
else if (arg == " -c" || arg == " --capture" ) { params.capture_id = std::stoi (argv[++i]); }
60
61
else if (arg == " -mt" || arg == " --max-tokens" ) { params.max_tokens = std::stoi (argv[++i]); }
61
62
else if (arg == " -ac" || arg == " --audio-ctx" ) { params.audio_ctx = std::stoi (argv[++i]); }
63
+ else if (arg == " -bs" || arg == " --beam-size" ) { params.beam_size = std::stoi (argv[++i]); }
62
64
else if (arg == " -vth" || arg == " --vad-thold" ) { params.vad_thold = std::stof (argv[++i]); }
63
65
else if (arg == " -fth" || arg == " --freq-thold" ) { params.freq_thold = std::stof (argv[++i]); }
64
66
else if (arg == " -tr" || arg == " --translate" ) { params.translate = true ; }
@@ -96,6 +98,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
96
98
fprintf (stderr, " -c ID, --capture ID [%-7d] capture device ID\n " , params.capture_id );
97
99
fprintf (stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n " , params.max_tokens );
98
100
fprintf (stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n " , params.audio_ctx );
101
+ fprintf (stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n " , params.beam_size );
99
102
fprintf (stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n " , params.vad_thold );
100
103
fprintf (stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n " , params.freq_thold );
101
104
fprintf (stderr, " -tr, --translate [%-7s] translate from source language to english\n " , params.translate ? " true" : " false" );
@@ -298,7 +301,7 @@ int main(int argc, char ** argv) {
298
301
299
302
// run the inference
300
303
{
301
- whisper_full_params wparams = whisper_full_default_params (WHISPER_SAMPLING_GREEDY);
304
+ whisper_full_params wparams = whisper_full_default_params (params. beam_size > 1 ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY);
302
305
303
306
wparams.print_progress = false ;
304
307
wparams.print_special = params.print_special ;
@@ -309,6 +312,7 @@ int main(int argc, char ** argv) {
309
312
wparams.max_tokens = params.max_tokens ;
310
313
wparams.language = params.language .c_str ();
311
314
wparams.n_threads = params.n_threads ;
315
+ wparams.beam_search .beam_size = params.beam_size ;
312
316
313
317
wparams.audio_ctx = params.audio_ctx ;
314
318
0 commit comments