Commit e7bd870

llama : add llama_model_decoder_start_token() API call that returns decoder_start_token_id
llama : add llama_model_has_encoder() API call
llama-cli : use llama_model_has_encoder() and llama_model_decoder_start_token() API calls

1 parent: cd9a969

3 files changed, +39 -8 lines

examples/main/main.cpp

Lines changed: 14 additions & 8 deletions
@@ -501,16 +501,22 @@ int main(int argc, char ** argv) {
         exit(1);
     }
 
-    int enc_input_size = embd_inp.size();
-    llama_token * enc_input_buf = embd_inp.data();
+    if (llama_model_has_encoder(model)) {
+        int enc_input_size = embd_inp.size();
+        llama_token * enc_input_buf = embd_inp.data();
 
-    if (llama_encode(ctx, llama_batch_get_one(enc_input_buf, enc_input_size, 0, 0))) {
-        LOG_TEE("%s : failed to eval\n", __func__);
-        return 1;
-    }
+        if (llama_encode(ctx, llama_batch_get_one(enc_input_buf, enc_input_size, 0, 0))) {
+            LOG_TEE("%s : failed to eval\n", __func__);
+            return 1;
+        }
 
-    embd_inp.clear();
-    embd_inp.push_back(llama_token_pad(model));
+        llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
+        if (decoder_start_token_id == -1) {
+            decoder_start_token_id = llama_token_bos(model);
+        }
+        embd_inp.clear();
+        embd_inp.push_back(decoder_start_token_id);
+    }
 
     while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
         // predict
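
The change gates the encoder pass on the new llama_model_has_encoder() call and replaces the hard-coded llama_token_pad() seed with the model's declared decoder start token, falling back to BOS when none is recorded. As a minimal standalone sketch of the same pattern (the helper name prime_decoder and the surrounding scaffolding are illustrative, not part of the commit; the llama_* calls are the ones used in the diff above):

    #include "llama.h"
    #include <cstdio>
    #include <vector>

    // Sketch only: mirrors the encoder-then-decoder-seed flow added to main.cpp.
    static bool prime_decoder(llama_context * ctx, llama_model * model,
                              std::vector<llama_token> & tokens) {
        if (!llama_model_has_encoder(model)) {
            return true; // decoder-only model: the prompt feeds the decoder directly
        }
        // 1. Run the whole prompt through the encoder once.
        if (llama_encode(ctx, llama_batch_get_one(tokens.data(), (int) tokens.size(), 0, 0))) {
            fprintf(stderr, "%s: llama_encode() failed\n", __func__);
            return false;
        }
        // 2. Seed the decoder with the model's start token; the API returns -1
        //    when the GGUF file does not record one, so fall back to BOS.
        llama_token start = llama_model_decoder_start_token(model);
        if (start == -1) {
            start = llama_token_bos(model);
        }
        tokens.clear();
        tokens.push_back(start);
        return true;
    }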

llama.cpp

Lines changed: 18 additions & 0 deletions
@@ -296,6 +296,7 @@ enum llm_kv {
     LLM_KV_EXPERT_WEIGHTS_SCALE,
     LLM_KV_POOLING_TYPE,
     LLM_KV_LOGIT_SCALE,
+    LLM_KV_DECODER_START_TOKEN_ID,
 
     LLM_KV_ATTENTION_HEAD_COUNT,
     LLM_KV_ATTENTION_HEAD_COUNT_KV,
@@ -384,6 +385,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_EXPERT_WEIGHTS_SCALE, "%s.expert_weights_scale" },
     { LLM_KV_POOLING_TYPE , "%s.pooling_type" },
     { LLM_KV_LOGIT_SCALE, "%s.logit_scale" },
+    { LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" },
 
     { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
     { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
@@ -1908,6 +1910,7 @@ struct llama_hparams {
     uint32_t n_expert_used = 0;
     uint32_t n_vocab_type = 0; // for BERT-style token types
     uint32_t n_rel_attn_bkts = 0;
+    int32_t decoder_start_token_id = -1;
 
     uint32_t n_layer_dense_lead = 0;
     uint32_t n_lora_q = 0;
@@ -4606,6 +4609,10 @@ static void llm_load_hparams(
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
                 ml.get_key(LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, hparams.n_rel_attn_bkts);
+                uint32_t decoder_start_token_id;
+                if (ml.get_key(LLM_KV_DECODER_START_TOKEN_ID, decoder_start_token_id, false)) {
+                    hparams.decoder_start_token_id = decoder_start_token_id;
+                }
                 model.type = e_model::MODEL_UNKNOWN;
             } break;
         default: (void)0;
@@ -17872,6 +17879,17 @@ struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const ch
     return it->second;
 }
 
+bool llama_model_has_encoder(const struct llama_model * model) {
+    switch (model->arch) {
+        case LLM_ARCH_T5: return true;
+        default: return false;
+    }
+}
+
+llama_token llama_model_decoder_start_token(const struct llama_model * model) {
+    return model->hparams.decoder_start_token_id;
+}
+
 uint32_t llama_model_quantize(
         const char * fname_inp,
         const char * fname_out,
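
Two implementation details stand out: llama_model_has_encoder() is a hard-coded architecture switch (only LLM_ARCH_T5 qualifies in this commit), and the decoder start token is an optional GGUF key; passing false as the third argument to ml.get_key() makes the lookup non-fatal, so models without the key keep the -1 default declared in llama_hparams. For reference, a rough sketch of inspecting that key on disk with the public gguf_* API from ggml.h (the file name and the "t5." prefix are assumptions; the %s in LLM_KV_NAMES expands to the model's architecture name):

    #include "ggml.h"
    #include <cstdio>

    int main(void) {
        struct gguf_init_params params = { /*no_alloc =*/ true, /*ctx =*/ NULL };
        struct gguf_context * gctx = gguf_init_from_file("model.gguf", params);
        if (!gctx) { return 1; }
        // Optional key: files without it simply yield a negative key id.
        const int kid = gguf_find_key(gctx, "t5.decoder_start_token_id");
        if (kid >= 0) {
            printf("decoder_start_token_id = %u\n", gguf_get_val_u32(gctx, kid));
        } else {
            printf("key not present; the loader keeps the -1 default\n");
        }
        gguf_free(gctx);
        return 0;
    }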

llama.h

Lines changed: 7 additions & 0 deletions
@@ -482,6 +482,13 @@ extern "C" {
     // Get a llama model tensor
     LLAMA_API struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const char * name);
 
+    // Returns true if the model contains an encoder that requires llama_encode() call
+    LLAMA_API bool llama_model_has_encoder(const struct llama_model * model);
+
+    // For encoder-decoder models, this function returns id of the token that must be provided
+    // to the decoder to start generating output sequence. For other models, it returns -1.
+    LLAMA_API llama_token llama_model_decoder_start_token(const struct llama_model * model);
+
     // Returns 0 on success
     LLAMA_API uint32_t llama_model_quantize(
         const char * fname_inp,
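
Caller-side, the contract is simple: check llama_model_has_encoder() first, and treat -1 from llama_model_decoder_start_token() as "no start token recorded". A minimal sketch (assumes model is an already-loaded llama_model *):

    // Sketch: querying the new API after llama_load_model_from_file().
    if (llama_model_has_encoder(model)) {
        llama_token start = llama_model_decoder_start_token(model);
        if (start == -1) {
            start = llama_token_bos(model); // no start token in the GGUF metadata
        }
        printf("encoder-decoder model, decoder starts with token %d\n", start);
    } else {
        printf("decoder-only model\n"); // llama_model_decoder_start_token() returns -1 here
    }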
