
Commit 8551c44

context : always use non-causal attention for encoder graphs (ggml-org#12447)
* context : always use non-causal attention for encoder graphs

ggml-ci

* context : move the change to llama_context::encode()

ggml-ci
1 parent 35cae5b commit 8551c44

1 file changed: +9, -0 lines changed

src/llama-context.cpp

Lines changed: 9 additions & 0 deletions
@@ -1057,13 +1057,22 @@ int llama_context::encode(llama_batch & inp_batch) {
     ggml_backend_sched_reset(sched.get());
     ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
 
+    const auto causal_attn_org = cparams.causal_attn;
+
+    // always use non-causal attention for encoder graphs
+    // TODO: this is a tmp solution until we have a proper way to support enc-dec models
+    // ref: https://github.com/ggml-org/llama.cpp/pull/12181#issuecomment-2730451223
+    cparams.causal_attn = false;
+
     auto * gf = graph_init();
     auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_ENCODER);
 
     ggml_backend_sched_alloc_graph(sched.get(), gf);
 
     res->set_inputs(&ubatch);
 
+    cparams.causal_attn = causal_attn_org;
+
     const auto compute_status = graph_compute(gf, n_tokens > 1);
     switch (compute_status) {
         case GGML_STATUS_SUCCESS:
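The diff overrides cparams.causal_attn by hand and restores it after the encoder graph has been built and its inputs set, so the override only needs to be live during graph construction. Below is a minimal standalone sketch of the same save/restore idea expressed as a small RAII guard; the scoped_flag_override type and the surrounding toy code are illustrative assumptions, not llama.cpp API or part of this commit.

#include <cassert>

// Sketch: temporarily override a boolean flag and restore the original
// value on scope exit, mirroring the manual save/restore in the commit.
// The type name is hypothetical, not part of llama.cpp.
struct scoped_flag_override {
    bool & flag;     // the flag being overridden (e.g. cparams.causal_attn)
    bool original;   // value captured on entry, restored on exit

    scoped_flag_override(bool & f, bool tmp_value) : flag(f), original(f) {
        flag = tmp_value;
    }
    ~scoped_flag_override() {
        flag = original;  // restored even on early return or exception
    }
};

int main() {
    bool causal_attn = true;
    {
        // encoder graphs always run with non-causal attention
        scoped_flag_override guard(causal_attn, false);
        assert(causal_attn == false);
        // ... build the encoder graph and set its inputs here ...
    }
    assert(causal_attn == true);  // original setting is back
    return 0;
}

A guard like this would also restore the flag on any early-return path, whereas the commit uses an explicit restore placed right before graph_compute(); either way, the original causal_attn setting is back in effect for subsequent decode calls.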
