Skip to content

Commit 3c6f8fc

Browse files
Add a new param to set the self attn text context factor ('max-decoders')
Add a param to set the text context factor. No change of behavior: same default (3). Resolve: ggerganov#2334
1 parent 69339af commit 3c6f8fc

File tree

3 files changed

+14
-4
lines changed

3 files changed

+14
-4
lines changed

examples/main/main.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ struct whisper_params {
3838
int32_t max_len = 0;
3939
int32_t best_of = whisper_full_default_params(WHISPER_SAMPLING_GREEDY).greedy.best_of;
4040
int32_t beam_size = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH).beam_search.beam_size;
41+
int32_t max_decoders = whisper_context_default_params().max_decoders;
4142
int32_t audio_ctx = 0;
4243

4344
float word_thold = 0.01f;
@@ -131,6 +132,7 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params
131132
else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); }
132133
else if (arg == "-bo" || arg == "--best-of") { params.best_of = std::stoi(argv[++i]); }
133134
else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(argv[++i]); }
135+
else if (arg == "-md" || arg == "--max-decoders") { params.max_decoders = std::stoi(argv[++i]); }
134136
else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); }
135137
else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); }
136138
else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); }
@@ -198,6 +200,7 @@ static void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params
198200
fprintf(stderr, " -sow, --split-on-word [%-7s] split on word rather than on token\n", params.split_on_word ? "true" : "false");
199201
fprintf(stderr, " -bo N, --best-of N [%-7d] number of best candidates to keep\n", params.best_of);
200202
fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size);
203+
fprintf(stderr, " -md N, --max-decoders N [%-7d] Max decoders, used to set the text context cache factor\n", params.max_decoders);
201204
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
202205
fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
203206
fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold);
@@ -981,6 +984,7 @@ int main(int argc, char ** argv) {
981984

982985
cparams.use_gpu = params.use_gpu;
983986
cparams.flash_attn = params.flash_attn;
987+
cparams.max_decoders = params.max_decoders;
984988

985989
if (!params.dtw.empty()) {
986990
cparams.dtw_token_timestamps = true;

include/whisper.h

+2
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,8 @@ extern "C" {
126126
struct whisper_aheads dtw_aheads;
127127

128128
size_t dtw_mem_size; // TODO: remove
129+
130+
int max_decoders; // to be used to setup text context factor
129131
};
130132

131133
typedef struct whisper_token_data {

src/whisper.cpp

+8-4
Original file line numberDiff line numberDiff line change
@@ -1005,7 +1005,7 @@ static bool whisper_kv_cache_find_slot(
10051005
}
10061006

10071007
if (n_tested >= n_ctx) {
1008-
//WHISPER_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
1008+
WHISPER_LOG_ERROR("%s: failed to find a slot for %d tokens. n_tested=%d n_ctx=%d cache.head=%d\n", __func__, n_tokens, n_tested, n_ctx, cache.head);
10091009
return false;
10101010
}
10111011
}
@@ -3408,9 +3408,11 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
34083408
whisper_mel_init(state->mel, state->backends[0], n_len, n_len, n_mel);
34093409
}
34103410

3411-
// at this point, we don't know yet how many decoders will be used, so we overallocate 3x ctx
3412-
// in theory, there can be a case where this is not enough, but in practice it should always be enough
3413-
const int factor = 3;
3411+
// at this point, we don't know yet how many decoders will be used, so we overallocate 3x ctx (default value)
3412+
// Note: there are cases where 3 is not enough specially when increasing beamsize
3413+
const int factor = ctx->params.max_decoders;
3414+
3415+
WHISPER_LOG_DEBUG("%s: init self-attn cache: n_ctx: %d factor: %d\n", __func__, factor*ctx->model.hparams.n_text_ctx, factor);
34143416

34153417
if (!whisper_kv_cache_init(state->kv_self, state->backends[0], ctx->itype,
34163418
ctx->model.hparams.n_text_state,
@@ -3635,6 +3637,7 @@ struct whisper_context_params whisper_context_default_params() {
36353637
/*.heads =*/ NULL,
36363638
},
36373639
/*.dtw_mem_size =*/ 1024*1024*128,
3640+
/* max_decoders =*/ 3
36383641
};
36393642
return result;
36403643
}
@@ -3732,6 +3735,7 @@ struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_
37323735
WHISPER_LOG_INFO("%s: flash attn = %d\n", __func__, params.flash_attn);
37333736
WHISPER_LOG_INFO("%s: gpu_device = %d\n", __func__, params.gpu_device);
37343737
WHISPER_LOG_INFO("%s: dtw = %d\n", __func__, params.dtw_token_timestamps);
3738+
WHISPER_LOG_INFO("%s: max-decoders = %d\n", __func__, params.max_decoders);
37353739

37363740
whisper_context * ctx = new whisper_context;
37373741
ctx->params = params;

0 commit comments

Comments
 (0)