Skip to content

Commit 7094ea5

Browse files
authored
whisper : use flash attention (#2152)
* whisper : use flash attention in the encoder * whisper : add kv_pad * whisper : remove extra backend instance (huh?) * whisper : use FA for cross-attention * whisper : use FA for self-attention * whisper : simplify encoder FA * whisper : add flash_attn runtime parameter * scripts : add bench log * scripts : add M1 Pro bench log
1 parent 9d5771a commit 7094ea5

File tree

13 files changed

+658
-173
lines changed

13 files changed

+658
-173
lines changed

examples/bench/bench.cpp

+11-6
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ struct whisper_params {
1212

1313
std::string model = "models/ggml-base.en.bin";
1414

15-
bool use_gpu = true;
15+
bool use_gpu = true;
16+
bool flash_attn = false;
1617
};
1718

1819
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
@@ -25,10 +26,11 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
2526
whisper_print_usage(argc, argv, params);
2627
exit(0);
2728
}
28-
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
29-
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
30-
else if (arg == "-w" || arg == "--what") { params.what = atoi(argv[++i]); }
31-
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
29+
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
30+
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
31+
else if (arg == "-w" || arg == "--what") { params.what = atoi(argv[++i]); }
32+
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
33+
else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
3234
else {
3335
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
3436
whisper_print_usage(argc, argv, params);
@@ -49,6 +51,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
4951
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
5052
fprintf(stderr, " -w N, --what N [%-7d] what to benchmark:\n", params.what);
5153
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
54+
fprintf(stderr, " -fa, --flash-attn [%-7s] enable flash attention\n", params.flash_attn ? "true" : "false");
5255
fprintf(stderr, " %-7s 0 - whisper\n", "");
5356
fprintf(stderr, " %-7s 1 - memcpy\n", "");
5457
fprintf(stderr, " %-7s 2 - ggml_mul_mat\n", "");
@@ -59,7 +62,9 @@ int whisper_bench_full(const whisper_params & params) {
5962
// whisper init
6063

6164
struct whisper_context_params cparams = whisper_context_default_params();
62-
cparams.use_gpu = params.use_gpu;
65+
66+
cparams.use_gpu = params.use_gpu;
67+
cparams.flash_attn = params.flash_attn;
6368

6469
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
6570

examples/command/command.cpp

+6-1
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ struct whisper_params {
4444
bool print_energy = false;
4545
bool no_timestamps = true;
4646
bool use_gpu = true;
47+
bool flash_attn = false;
4748

4849
std::string language = "en";
4950
std::string model = "models/ggml-base.en.bin";
@@ -80,6 +81,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
8081
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
8182
else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
8283
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
84+
else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
8385
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
8486
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
8587
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
@@ -118,6 +120,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
118120
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
119121
fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
120122
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
123+
fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention\n", params.flash_attn ? "true" : "false");
121124
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
122125
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
123126
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
@@ -696,7 +699,9 @@ int main(int argc, char ** argv) {
696699
// whisper init
697700

698701
struct whisper_context_params cparams = whisper_context_default_params();
699-
cparams.use_gpu = params.use_gpu;
702+
703+
cparams.use_gpu = params.use_gpu;
704+
cparams.flash_attn = params.flash_attn;
700705

701706
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
702707

examples/lsp/lsp.cpp

+7-1
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ struct whisper_params {
3131
bool print_special = false;
3232
bool print_energy = false;
3333
bool use_gpu = true;
34+
bool flash_attn = false;
3435

3536
std::string language = "en";
3637
std::string model = "models/ggml-base.en.bin";
@@ -74,6 +75,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
7475
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
7576
else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
7677
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
78+
else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
7779
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
7880
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
7981
else {
@@ -105,6 +107,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
105107
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
106108
fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
107109
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
110+
fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention\n", params.flash_attn ? "true" : "false");
108111
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
109112
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
110113
fprintf(stderr, "\n");
@@ -436,7 +439,10 @@ int main(int argc, char ** argv) {
436439

437440
// whisper init
438441
struct whisper_context_params cparams = whisper_context_default_params();
439-
cparams.use_gpu = params.use_gpu;
442+
443+
cparams.use_gpu = params.use_gpu;
444+
cparams.flash_attn = params.flash_attn;
445+
440446
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
441447
// init audio
442448

examples/main/main.cpp

+7-2
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ struct whisper_params {
7070
bool no_timestamps = false;
7171
bool log_score = false;
7272
bool use_gpu = true;
73+
bool flash_attn = false;
7374

7475
std::string language = "en";
7576
std::string prompt;
@@ -168,7 +169,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
168169
else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; }
169170
else if (arg == "-ls" || arg == "--log-score") { params.log_score = true; }
170171
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
171-
else if ( arg == "--suppress-regex") { params.suppress_regex = argv[++i]; }
172+
else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
173+
else if ( arg == "--suppress-regex") { params.suppress_regex = argv[++i]; }
172174
else if ( arg == "--grammar") { params.grammar = argv[++i]; }
173175
else if ( arg == "--grammar-rule") { params.grammar_rule = argv[++i]; }
174176
else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); }
@@ -234,6 +236,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
234236
fprintf(stderr, " -dtw MODEL --dtw MODEL [%-7s] compute token-level timestamps\n", params.dtw.c_str());
235237
fprintf(stderr, " -ls, --log-score [%-7s] log best decoder scores of tokens\n", params.log_score?"true":"false");
236238
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
239+
fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention\n", params.flash_attn ? "true" : "false");
237240
fprintf(stderr, " --suppress-regex REGEX [%-7s] regular expression matching tokens to suppress\n", params.suppress_regex.c_str());
238241
fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str());
239242
fprintf(stderr, " --grammar-rule RULE [%-7s] top-level GBNF grammar rule name\n", params.grammar_rule.c_str());
@@ -977,7 +980,9 @@ int main(int argc, char ** argv) {
977980
// whisper init
978981

979982
struct whisper_context_params cparams = whisper_context_default_params();
980-
cparams.use_gpu = params.use_gpu;
983+
984+
cparams.use_gpu = params.use_gpu;
985+
cparams.flash_attn = params.flash_attn;
981986

982987
if (!params.dtw.empty()) {
983988
cparams.dtw_token_timestamps = true;

examples/server/server.cpp

+6-1
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ struct whisper_params {
7575
bool print_progress = false;
7676
bool no_timestamps = false;
7777
bool use_gpu = true;
78+
bool flash_attn = false;
7879

7980
std::string language = "en";
8081
std::string prompt = "";
@@ -178,6 +179,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params, serve
178179
else if (arg == "-oved" || arg == "--ov-e-device") { params.openvino_encode_device = argv[++i]; }
179180
else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; }
180181
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
182+
else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
181183
// server params
182184
else if ( arg == "--port") { sparams.port = std::stoi(argv[++i]); }
183185
else if ( arg == "--host") { sparams.hostname = argv[++i]; }
@@ -502,7 +504,10 @@ int main(int argc, char ** argv) {
502504
}
503505
// whisper init
504506
struct whisper_context_params cparams = whisper_context_default_params();
505-
cparams.use_gpu = params.use_gpu;
507+
508+
cparams.use_gpu = params.use_gpu;
509+
cparams.flash_attn = params.flash_attn;
510+
506511
if (!params.dtw.empty()) {
507512
cparams.dtw_token_timestamps = true;
508513
cparams.dtw_aheads_preset = WHISPER_AHEADS_NONE;

examples/stream/stream.cpp

+6-1
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ struct whisper_params {
3636
bool tinydiarize = false;
3737
bool save_audio = false; // save audio to wav file
3838
bool use_gpu = true;
39+
bool flash_attn = false;
3940

4041
std::string language = "en";
4142
std::string model = "models/ggml-base.en.bin";
@@ -72,6 +73,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
7273
else if (arg == "-tdrz" || arg == "--tinydiarize") { params.tinydiarize = true; }
7374
else if (arg == "-sa" || arg == "--save-audio") { params.save_audio = true; }
7475
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
76+
else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
7577

7678
else {
7779
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
@@ -109,6 +111,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
109111
fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false");
110112
fprintf(stderr, " -sa, --save-audio [%-7s] save the recorded audio to a file\n", params.save_audio ? "true" : "false");
111113
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU inference\n", params.use_gpu ? "false" : "true");
114+
fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention during inference\n", params.flash_attn ? "true" : "false");
112115
fprintf(stderr, "\n");
113116
}
114117

@@ -153,7 +156,9 @@ int main(int argc, char ** argv) {
153156
}
154157

155158
struct whisper_context_params cparams = whisper_context_default_params();
156-
cparams.use_gpu = params.use_gpu;
159+
160+
cparams.use_gpu = params.use_gpu;
161+
cparams.flash_attn = params.flash_attn;
157162

158163
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
159164

examples/talk-llama/talk-llama.cpp

+7-2
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ struct whisper_params {
6666
bool no_timestamps = true;
6767
bool verbose_prompt = false;
6868
bool use_gpu = true;
69+
bool flash_attn = false;
6970

7071
std::string person = "Georgi";
7172
std::string bot_name = "LLaMA";
@@ -105,6 +106,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
105106
else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
106107
else if (arg == "-vp" || arg == "--verbose-prompt") { params.verbose_prompt = true; }
107108
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
109+
else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
108110
else if (arg == "-p" || arg == "--person") { params.person = argv[++i]; }
109111
else if (arg == "-bn" || arg == "--bot-name") { params.bot_name = argv[++i]; }
110112
else if (arg == "--session") { params.path_session = argv[++i]; }
@@ -123,7 +125,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
123125
}
124126
}
125127
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
126-
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
127128
else {
128129
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
129130
whisper_print_usage(argc, argv, params);
@@ -154,6 +155,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
154155
fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
155156
fprintf(stderr, " -vp, --verbose-prompt [%-7s] print prompt at start\n", params.verbose_prompt ? "true" : "false");
156157
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
158+
fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention\n", params.flash_attn ? "true" : "false");
157159
fprintf(stderr, " -p NAME, --person NAME [%-7s] person name (for prompt selection)\n", params.person.c_str());
158160
fprintf(stderr, " -bn NAME, --bot-name NAME [%-7s] bot name (to display)\n", params.bot_name.c_str());
159161
fprintf(stderr, " -w TEXT, --wake-command T [%-7s] wake-up command to listen for\n", params.wake_cmd.c_str());
@@ -285,7 +287,9 @@ int main(int argc, char ** argv) {
285287
// whisper init
286288

287289
struct whisper_context_params cparams = whisper_context_default_params();
288-
cparams.use_gpu = params.use_gpu;
290+
291+
cparams.use_gpu = params.use_gpu;
292+
cparams.flash_attn = params.flash_attn;
289293

290294
struct whisper_context * ctx_wsp = whisper_init_from_file_with_params(params.model_wsp.c_str(), cparams);
291295
if (!ctx_wsp) {
@@ -316,6 +320,7 @@ int main(int argc, char ** argv) {
316320
lcparams.n_ctx = 2048;
317321
lcparams.seed = 1;
318322
lcparams.n_threads = params.n_threads;
323+
lcparams.flash_attn = params.flash_attn;
319324

320325
struct llama_context * ctx_llama = llama_new_context_with_model(model_llama, lcparams);
321326

examples/talk/talk.cpp

+6-1
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ struct whisper_params {
3232
bool print_energy = false;
3333
bool no_timestamps = true;
3434
bool use_gpu = true;
35+
bool flash_attn = false;
3536

3637
std::string person = "Santa";
3738
std::string language = "en";
@@ -64,6 +65,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
6465
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
6566
else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
6667
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
68+
else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
6769
else if (arg == "-p" || arg == "--person") { params.person = argv[++i]; }
6870
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
6971
else if (arg == "-mw" || arg == "--model-whisper") { params.model_wsp = argv[++i]; }
@@ -99,6 +101,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
99101
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
100102
fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
101103
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
104+
fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention\n", params.flash_attn ? "true" : "false");
102105
fprintf(stderr, " -p NAME, --person NAME [%-7s] person name (for prompt selection)\n", params.person.c_str());
103106
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
104107
fprintf(stderr, " -mw FILE, --model-whisper [%-7s] whisper model file\n", params.model_wsp.c_str());
@@ -188,7 +191,9 @@ int main(int argc, char ** argv) {
188191

189192
// whisper init
190193
struct whisper_context_params cparams = whisper_context_default_params();
191-
cparams.use_gpu = params.use_gpu;
194+
195+
cparams.use_gpu = params.use_gpu;
196+
cparams.flash_attn = params.flash_attn;
192197

193198
struct whisper_context * ctx_wsp = whisper_init_from_file_with_params(params.model_wsp.c_str(), cparams);
194199

0 commit comments

Comments
 (0)