
Commit eb4c17e

llama : add n_enc_outputs field in llm_build_context containing the number of embeddings generated by the encoder
1 parent: 684160a

File tree: 1 file changed (llama.cpp), +7 -6 lines


Diff for: llama.cpp

@@ -7347,6 +7347,7 @@ struct llm_build_context {
     const int32_t n_tokens;
     const int32_t n_kv;     // size of KV cache to consider (n_kv <= kv_self.size)
     const int32_t n_outputs;
+    const int32_t n_enc_outputs;
     const int32_t kv_head;  // index of where we store new KV data in the cache
     const int32_t n_ctx_orig;
 
@@ -7396,6 +7397,7 @@ struct llm_build_context {
         n_tokens      (batch.n_tokens),
         n_kv          (worst_case ? kv_self.size : kv_self.n),
         n_outputs     (worst_case ? n_tokens : lctx.n_outputs),
+        n_enc_outputs (worst_case ? n_tokens : lctx.encoder_output.size() / hparams.n_embd),
         kv_head       (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head),
         n_ctx_orig    (cparams.n_ctx_orig_yarn),
         flash_attn    (cparams.flash_attn),
@@ -7660,14 +7662,14 @@ struct llm_build_context {
 
     struct ggml_tensor * llm_build_inp_enc_output() {
         const int64_t n_embd = hparams.n_embd;
-        lctx.inp_enc_output = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, lctx.encoder_output.size() == 0 ? 512 : lctx.encoder_output.size() / n_embd);
+        lctx.inp_enc_output = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_enc_outputs);
         ggml_set_input(lctx.inp_enc_output);
         cb(lctx.inp_enc_output, "enc_output", -1);
         return lctx.inp_enc_output;
     }
 
     struct ggml_tensor * llm_build_inp_cross_KQ_mask() {
-        lctx.inp_cross_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, lctx.encoder_output.size() == 0 ? 512 : lctx.encoder_output.size() / n_embd, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        lctx.inp_cross_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_enc_outputs, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         ggml_set_input(lctx.inp_cross_KQ_mask);
         cb(lctx.inp_cross_KQ_mask, "enc_mask", -1);
         return lctx.inp_cross_KQ_mask;
@@ -11717,7 +11719,6 @@ struct llm_build_context {
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-        const int32_t n_enc_output = lctx.encoder_output.size() == 0 ? 512 : lctx.encoder_output.size() / n_embd;
 
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
@@ -11926,7 +11927,7 @@ struct llm_build_context {
         cb(Vcur, "Vcur", il);
 
         Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
-        Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_enc_output);
+        Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_enc_outputs);
 
         struct ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
         struct ggml_tensor * k = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3));
@@ -11937,10 +11938,10 @@ struct llm_build_context {
         kq = ggml_soft_max_ext(ctx0, kq, enc_KQ_mask, 1.0f, hparams.f_max_alibi_bias);
         cb(kq, "kq_soft_max_ext", il);
 
-        struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_enc_output)));
+        struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_enc_outputs)));
         cb(v, "v", il);
 
-        struct ggml_tensor * kqv = ggml_mul_mat(ctx0, ggml_reshape_3d(ctx0, v, n_enc_output, n_embd_head, n_head_kv), kq);
+        struct ggml_tensor * kqv = ggml_mul_mat(ctx0, ggml_reshape_3d(ctx0, v, n_enc_outputs, n_embd_head, n_head_kv), kq);
         cb(kqv, "kqv", il);
 
         struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
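
For reference, the quantity the new field captures is simply the number of encoder output embeddings: the size of the flat encoder_output buffer divided by the embedding width, with the worst-case graph falling back to the token count because the encoder has not run yet. The following is a minimal standalone C++ sketch of that computation, assuming the same shapes as the diff above; the helper name count_enc_outputs is hypothetical and not part of llama.cpp.

#include <cstdint>
#include <vector>

// Sketch only (not llama.cpp code): derive the number of encoder output
// embeddings from a flat float buffer of n_enc_outputs * n_embd values,
// mirroring the n_enc_outputs initializer in the constructor above.
static int32_t count_enc_outputs(const std::vector<float> & encoder_output,
                                 int64_t n_embd,
                                 bool    worst_case,
                                 int32_t n_tokens) {
    // While building the worst-case graph the encoder output is not yet
    // available, so the token count serves as the upper bound.
    if (worst_case) {
        return n_tokens;
    }
    return (int32_t) (encoder_output.size() / n_embd);
}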
