
Commit 65192c0

Merge pull request #91 from ggerganov/master
b2266
2 parents: 8c4c1d0 + e3965cf

38 files changed: +1659 -916 lines

Diff for: .github/workflows/build.yml

+1 -2

@@ -669,8 +669,7 @@ jobs:
       run: |
         cd examples/llama.android
 
-        # Skip armeabi-v7a for now (https://github.com/llvm/llvm-project/issues/65820).
-        ./gradlew build --no-daemon -Pskip-armeabi-v7a
+        ./gradlew build --no-daemon
 
 # freeBSD-latest:
 #  runs-on: macos-12

Diff for: CMakeLists.txt

+8 -2

@@ -936,10 +936,16 @@ if (CMAKE_OSX_ARCHITECTURES STREQUAL "arm64" OR CMAKE_GENERATOR_PLATFORM_LWR STR
         list(APPEND ARCH_FLAGS -mfpu=neon-fp-armv8 -mno-unaligned-access)
     endif()
     if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv7")
-        # Raspberry Pi 2
-        list(APPEND ARCH_FLAGS -mfpu=neon-fp-armv8 -mno-unaligned-access -funsafe-math-optimizations)
+        if ("${CMAKE_SYSTEM_NAME}" STREQUAL "Android")
+            # Android armeabi-v7a
+            list(APPEND ARCH_FLAGS -mfpu=neon-vfpv4 -mno-unaligned-access -funsafe-math-optimizations)
+        else()
+            # Raspberry Pi 2
+            list(APPEND ARCH_FLAGS -mfpu=neon-fp-armv8 -mno-unaligned-access -funsafe-math-optimizations)
+        endif()
     endif()
     if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv8")
+        # Android arm64-v8a
         # Raspberry Pi 3, 4, Zero 2 (32-bit)
         list(APPEND ARCH_FLAGS -mno-unaligned-access)
     endif()

Diff for: Makefile

+1 -1

@@ -597,7 +597,7 @@ $(info I CC: $(shell $(CC) --version | head -n 1))
 $(info I CXX: $(shell $(CXX) --version | head -n 1))
 ifdef LLAMA_CUBLAS
 $(info I NVCC: $(shell $(NVCC) --version | tail -n 1))
-CUDA_VERSION := $(shell nvcc --version | grep -oP 'release (\K[0-9]+\.[0-9])')
+CUDA_VERSION := $(shell $(NVCC) --version | grep -oP 'release (\K[0-9]+\.[0-9])')
 ifeq ($(shell awk -v "v=$(CUDA_VERSION)" 'BEGIN { print (v < 11.7) }'),1)
 ifndef CUDA_DOCKER_ARCH
 ifndef CUDA_POWER_ARCH

Diff for: README.md

+4

@@ -114,6 +114,9 @@ Typically finetunes of the base models below are supported as well.
 - [x] [MobileVLM 1.7B/3B models](https://huggingface.co/models?search=mobileVLM)
 - [x] [Yi-VL](https://huggingface.co/models?search=Yi-VL)
 
+**HTTP server**
+
+[llama.cpp web server](./examples/server) is a lightweight [OpenAI API](https://github.com/openai/openai-openapi) compatible HTTP server that can be used to serve local models and easily connect them to existing clients.
 
 **Bindings:**
 
@@ -155,6 +158,7 @@ Unless otherwise noted these projects are open-source with permissive licensing:
 - [semperai/amica](https://github.com/semperai/amica)
 - [withcatai/catai](https://github.com/withcatai/catai)
 - [Mobile-Artificial-Intelligence/maid](https://github.com/Mobile-Artificial-Intelligence/maid) (MIT)
+- [Msty](https://msty.app) (proprietary)
 
 ---
 
Diff for: common/common.cpp

+9 -9

@@ -295,9 +295,9 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             std::string value(argv[i]);
-            /**/ if (value == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_NONE; }
-            else if (value == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_LINEAR; }
-            else if (value == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_YARN; }
+            /**/ if (value == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_NONE; }
+            else if (value == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_LINEAR; }
+            else if (value == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_YARN; }
             else { invalid_param = true; break; }
         } else if (arg == "--rope-scale") {
             if (++i >= argc) {
@@ -630,11 +630,11 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
             }
             std::string arg_next = argv[i];
             if (arg_next == "none") {
-                params.split_mode = LLAMA_SPLIT_NONE;
+                params.split_mode = LLAMA_SPLIT_MODE_NONE;
             } else if (arg_next == "layer") {
-                params.split_mode = LLAMA_SPLIT_LAYER;
+                params.split_mode = LLAMA_SPLIT_MODE_LAYER;
             } else if (arg_next == "row") {
-                params.split_mode = LLAMA_SPLIT_ROW;
+                params.split_mode = LLAMA_SPLIT_MODE_ROW;
             } else {
                 invalid_param = true;
                 break;
@@ -837,15 +837,15 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
             sep++;
             if (strncmp(sep, "int:", 4) == 0) {
                 sep += 4;
-                kvo.tag = LLAMA_KV_OVERRIDE_INT;
+                kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT;
                 kvo.int_value = std::atol(sep);
             } else if (strncmp(sep, "float:", 6) == 0) {
                 sep += 6;
-                kvo.tag = LLAMA_KV_OVERRIDE_FLOAT;
+                kvo.tag = LLAMA_KV_OVERRIDE_TYPE_FLOAT;
                 kvo.float_value = std::atof(sep);
             } else if (strncmp(sep, "bool:", 5) == 0) {
                 sep += 5;
-                kvo.tag = LLAMA_KV_OVERRIDE_BOOL;
+                kvo.tag = LLAMA_KV_OVERRIDE_TYPE_BOOL;
                 if (std::strcmp(sep, "true") == 0) {
                     kvo.bool_value = true;
                 } else if (std::strcmp(sep, "false") == 0) {
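
Note: the changes in this file are a mechanical rename of llama.h enum constants (LLAMA_ROPE_SCALING_*, LLAMA_SPLIT_*, LLAMA_KV_OVERRIDE_* gain a _TYPE_/_MODE_ infix). A minimal sketch of parsing a --split-mode string against the renamed constants, assuming llama.h from this revision; the helper name parse_split_mode is illustrative only and not part of the patch:

    #include <string>
    #include "llama.h"

    // Map the --split-mode argument string to the renamed constants.
    static bool parse_split_mode(const std::string & value, llama_split_mode & mode) {
        if      (value == "none")  { mode = LLAMA_SPLIT_MODE_NONE;  }
        else if (value == "layer") { mode = LLAMA_SPLIT_MODE_LAYER; }
        else if (value == "row")   { mode = LLAMA_SPLIT_MODE_ROW;   }
        else                       { return false; } // unknown mode string
        return true;
    }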

Diff for: common/common.h

+2 -2

@@ -61,7 +61,7 @@ struct gpt_params {
     float p_split = 0.1f; // speculative decoding split probability
     int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
     int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
-    llama_split_mode split_mode = LLAMA_SPLIT_LAYER; // how to split the model across GPUs
+    llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs
     int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
     float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs
     int32_t n_beams = 0; // if non-zero then use beam search of given width.
@@ -75,7 +75,7 @@ struct gpt_params {
     float yarn_beta_fast = 32.0f; // YaRN low correction dim
     float yarn_beta_slow = 1.0f; // YaRN high correction dim
     int32_t yarn_orig_ctx = 0; // YaRN original context length
-    int32_t rope_scaling_type = LLAMA_ROPE_SCALING_UNSPECIFIED;
+    int32_t rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
     ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED;
 
     // // sampling parameters

Diff for: common/sampling.cpp

+1 -1

@@ -266,7 +266,7 @@ static llama_token llama_sampling_sample_impl(
         // }
         //}
 
-        LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx_main, id).c_str());
+        //LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx_main, id).c_str());
     }
 }
 
Diff for: common/train.cpp

+5 -5

@@ -31,7 +31,7 @@ struct train_state * init_train_state() {
 
     state->opt = new struct ggml_opt_context;
     state->opt->ctx = NULL;
-    state->opt->params = ggml_opt_default_params(GGML_OPT_ADAM);
+    state->opt->params = ggml_opt_default_params(GGML_OPT_TYPE_ADAM);
     state->opt->params.graph_size = LLAMA_TRAIN_MAX_NODES;
     state->opt->loss_after = 0.0f;
 
@@ -556,7 +556,7 @@ void load_opt_context_gguf(struct gguf_context * fctx, struct ggml_context * f_g
     std::string opt_type;
     GGUF_GET_KEY(fctx, opt_type, gguf_get_val_str, GGUF_TYPE_STRING, true, LLM_KV_OPTIMIZER_TYPE);
     if (opt_type == LLM_KV_OPTIMIZER_TYPE_ADAM) {
-        opt->params.type = GGML_OPT_ADAM;
+        opt->params.type = GGML_OPT_TYPE_ADAM;
 
         GGUF_GET_KEY(fctx, opt->adam.fx_best, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_ADAM_BEST_LOSS);
         GGUF_GET_KEY(fctx, opt->adam.fx_prev, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_ADAM_PREVIOUS_LOSS);
@@ -568,7 +568,7 @@ void load_opt_context_gguf(struct gguf_context * fctx, struct ggml_context * f_g
         copy_tensor_by_name(opt->adam.v, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_ADAM_SECOND_MOMENTS);
         copy_tensor_by_name(opt->adam.pf, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_ADAM_PAST_LOSS_VALUES);
     } else if (opt_type == LLM_KV_OPTIMIZER_TYPE_LBFGS) {
-        opt->params.type = GGML_OPT_LBFGS;
+        opt->params.type = GGML_OPT_TYPE_LBFGS;
 
         GGUF_GET_KEY(fctx, opt->params.lbfgs.m, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_LBFGS_APPROX_HESSIAN_COUNT);
         GGUF_GET_KEY(fctx, opt->lbfgs.fx_best, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_LBFGS_BEST_LOSS);
@@ -603,7 +603,7 @@ void save_opt_context_gguf(struct gguf_context * fctx, struct ggml_opt_context *
     gguf_set_val_bool(fctx, LLM_KV_OPTIMIZER_JUST_INITIALIZED, opt->just_initialized);
 
     switch (opt->params.type) {
-        case GGML_OPT_ADAM:
+        case GGML_OPT_TYPE_ADAM:
             {
                 gguf_set_val_str(fctx, LLM_KV_OPTIMIZER_TYPE, LLM_KV_OPTIMIZER_TYPE_ADAM);
                 gguf_set_val_f32(fctx, LLM_KV_OPTIMIZER_ADAM_BEST_LOSS, opt->adam.fx_best);
@@ -622,7 +622,7 @@ void save_opt_context_gguf(struct gguf_context * fctx, struct ggml_opt_context *
                     gguf_add_tensor(fctx, opt->adam.pf);
                 }
             } break;
-        case GGML_OPT_LBFGS:
+        case GGML_OPT_TYPE_LBFGS:
             {
                 gguf_set_val_str(fctx, LLM_KV_OPTIMIZER_TYPE, LLM_KV_OPTIMIZER_TYPE_LBFGS);
                 gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_LBFGS_APPROX_HESSIAN_COUNT, opt->params.lbfgs.m);
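
Note: the optimizer-side rename follows the same pattern (GGML_OPT_ADAM / GGML_OPT_LBFGS become GGML_OPT_TYPE_ADAM / GGML_OPT_TYPE_LBFGS). A minimal sketch, assuming ggml.h from this revision; the helper opt_type_name is illustrative only:

    #include "ggml.h"

    // Default Adam parameters are requested with the renamed constant:
    //   struct ggml_opt_params params = ggml_opt_default_params(GGML_OPT_TYPE_ADAM);
    // and switches over the optimizer type use the new names as well.
    static const char * opt_type_name(enum ggml_opt_type type) {
        switch (type) {
            case GGML_OPT_TYPE_ADAM:  return "adam";
            case GGML_OPT_TYPE_LBFGS: return "lbfgs";
            default:                  return "unknown";
        }
    }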

Diff for: convert-hf-to-gguf.py

+5 -4

@@ -192,7 +192,7 @@ def from_model_architecture(model_architecture):
             return RefactModel
         if model_architecture == "PersimmonForCausalLM":
             return PersimmonModel
-        if model_architecture in ("StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
+        if model_architecture in ("StableLmForCausalLM", "StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
             return StableLMModel
         if model_architecture == "QWenLMHeadModel":
             return QwenModel
@@ -253,7 +253,7 @@ def _get_model_architecture(self) -> gguf.MODEL_ARCH:
             return gguf.MODEL_ARCH.REFACT
         if arch == "PersimmonForCausalLM":
             return gguf.MODEL_ARCH.PERSIMMON
-        if arch in ("StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
+        if arch in ("StableLmForCausalLM", "StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
             return gguf.MODEL_ARCH.STABLELM
         if arch == "QWenLMHeadModel":
             return gguf.MODEL_ARCH.QWEN
@@ -1074,10 +1074,11 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_embedding_length(hparams["hidden_size"])
         self.gguf_writer.add_block_count(block_count)
         self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
-        self.gguf_writer.add_rope_dimension_count(int(hparams["rope_pct"] * (hparams["hidden_size"] // hparams["num_attention_heads"])))
+        rotary_factor = self.find_hparam(["partial_rotary_factor", "rope_pct"])
+        self.gguf_writer.add_rope_dimension_count(int(rotary_factor * (hparams["hidden_size"] // hparams["num_attention_heads"])))
         self.gguf_writer.add_head_count(hparams["num_attention_heads"])
         self.gguf_writer.add_parallel_residual(hparams["use_parallel_residual"] if "use_parallel_residual" in hparams else True)
-        self.gguf_writer.add_layer_norm_eps(1e-5)
+        self.gguf_writer.add_layer_norm_eps(self.find_hparam(["layer_norm_eps", "norm_eps"]))
 
 
 class MixtralModel(Model):
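
Note: for the StableLM converter, the rope dimension count is now derived from whichever of partial_rotary_factor or rope_pct is present, and the layer-norm epsilon is read from the model config instead of being hard-coded to 1e-5. A small worked example of the rope formula with hypothetical hyperparameters (not taken from any specific model config):

    #include <cstdio>

    int main() {
        // hypothetical values, only to illustrate the formula used above
        const int    hidden_size         = 2048;
        const int    num_attention_heads = 32;
        const double rotary_factor       = 0.25;  // partial_rotary_factor / rope_pct

        const int head_dim  = hidden_size / num_attention_heads;  // 64
        const int rope_dims = (int)(rotary_factor * head_dim);    // 16

        printf("rope_dimension_count = %d\n", rope_dims);
        return 0;
    }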

Diff for: examples/baby-llama/baby-llama.cpp

+1 -1

@@ -1547,7 +1547,7 @@ int main(int argc, char ** argv) {
 
     float error_before_opt = ggml_get_f32_1d(e, 0);
 
-    struct ggml_opt_params opt_params_lbfgs = ggml_opt_default_params(GGML_OPT_LBFGS);
+    struct ggml_opt_params opt_params_lbfgs = ggml_opt_default_params(GGML_OPT_TYPE_LBFGS);
     opt_params_lbfgs.print_forward_graph = false;
     opt_params_lbfgs.print_backward_graph = false;
     opt_params_lbfgs.lbfgs.n_iter = 16;

Diff for: examples/finetune/finetune.cpp

+1 -1

@@ -1531,7 +1531,7 @@ int main(int argc, char ** argv) {
     lora.hparams.n_rank_output = n_rank_output;
 
     // set opt params from command line
-    opt->params = ggml_opt_default_params(GGML_OPT_ADAM);
+    opt->params = ggml_opt_default_params(GGML_OPT_TYPE_ADAM);
     opt->params.print_forward_graph = false;
     opt->params.print_backward_graph = false;
     opt->params.graph_size = LLAMA_TRAIN_MAX_NODES;

Diff for: examples/infill/infill.cpp

+2 -2

@@ -447,8 +447,8 @@ int main(int argc, char ** argv) {
             LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
                 n_past, n_left, n_ctx, params.n_keep, n_discard);
 
-            llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1);
-            llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
+            llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1);
+            llama_kv_cache_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
 
             n_past -= n_discard;

Diff for: examples/llama-bench/llama-bench.cpp

+7 -7

@@ -157,9 +157,9 @@ static const char * output_format_str(output_formats format) {
 
 static const char * split_mode_str(llama_split_mode mode) {
     switch (mode) {
-        case LLAMA_SPLIT_NONE: return "none";
-        case LLAMA_SPLIT_LAYER: return "layer";
-        case LLAMA_SPLIT_ROW: return "row";
+        case LLAMA_SPLIT_MODE_NONE: return "none";
+        case LLAMA_SPLIT_MODE_LAYER: return "layer";
+        case LLAMA_SPLIT_MODE_ROW: return "row";
         default: GGML_ASSERT(!"invalid split mode");
     }
 }
@@ -193,7 +193,7 @@ static const cmd_params cmd_params_defaults = {
     /* type_v */ {GGML_TYPE_F16},
     /* n_threads */ {get_num_physical_cores()},
     /* n_gpu_layers */ {99},
-    /* split_mode */ {LLAMA_SPLIT_LAYER},
+    /* split_mode */ {LLAMA_SPLIT_MODE_LAYER},
     /* main_gpu */ {0},
     /* no_kv_offload */ {false},
     /* mul_mat_q */ {true},
@@ -358,11 +358,11 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
         for (const auto & m : p) {
             llama_split_mode mode;
             if (m == "none") {
-                mode = LLAMA_SPLIT_NONE;
+                mode = LLAMA_SPLIT_MODE_NONE;
             } else if (m == "layer") {
-                mode = LLAMA_SPLIT_LAYER;
+                mode = LLAMA_SPLIT_MODE_LAYER;
             } else if (m == "row") {
-                mode = LLAMA_SPLIT_ROW;
+                mode = LLAMA_SPLIT_MODE_ROW;
             } else {
                 invalid_param = true;
                 break;

Diff for: examples/llama.android/app/build.gradle.kts

+2 -6

@@ -21,12 +21,8 @@ android {
             useSupportLibrary = true
         }
         ndk {
-            // Workaround for https://github.com/llvm/llvm-project/issues/65820
-            // affecting armeabi-v7a. Skip armeabi-v7a when invoked with
-            // -Pskip-armeabi-v7a (e.g., ./gradlew build -Pskip-armeabi-v7a).
-            if (project.hasProperty("skip-armeabi-v7a")) {
-                abiFilters += listOf("arm64-v8a", "x86_64", "x86")
-            }
+            // Add NDK properties if wanted, e.g.
+            // abiFilters += listOf("arm64-v8a")
         }
         externalNativeBuild {
             cmake {

Diff for: examples/llava/llava.cpp

+1 -1

@@ -152,7 +152,7 @@ static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector<float *>
 
     ggml_tensor * newline_tmp = clip_get_newline_tensor(ctx_clip);
     model.newline = ggml_new_tensor_1d(model.ctx, GGML_TYPE_F32, newline_tmp->ne[0]);
-    if (newline_tmp->backend != GGML_BACKEND_CPU) {
+    if (newline_tmp->backend != GGML_BACKEND_TYPE_CPU) {
         if (newline_tmp->buffer == NULL) {
             printf("newline_tmp tensor buffer is NULL\n");
         }
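
Note: GGML_BACKEND_CPU is renamed to GGML_BACKEND_TYPE_CPU, matching the other enum renames in this commit. A minimal sketch of the check above, assuming ggml.h from this revision; tensor_is_offloaded is an illustrative helper, not part of the patch:

    #include "ggml.h"

    // True when the tensor's data lives on a non-CPU backend.
    static bool tensor_is_offloaded(const struct ggml_tensor * t) {
        return t->backend != GGML_BACKEND_TYPE_CPU;
    }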

Diff for: examples/main/main.cpp

+5 -5

@@ -548,8 +548,8 @@ int main(int argc, char ** argv) {
             LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
                 n_past, n_left, n_ctx, params.n_keep, n_discard);
 
-            llama_kv_cache_seq_rm (ctx, 0, params.n_keep , params.n_keep + n_discard);
-            llama_kv_cache_seq_shift(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
+            llama_kv_cache_seq_rm (ctx, 0, params.n_keep , params.n_keep + n_discard);
+            llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
 
             n_past -= n_discard;
 
@@ -576,9 +576,9 @@ int main(int argc, char ** argv) {
             LOG("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n, (ga_i + ib*bd)/ga_n, (ga_i + ib*bd + ga_w)/ga_n);
             LOG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i + ib*bd + ga_w, n_past + ib*bd, dd, ga_i + ib*bd + ga_w + dd, n_past + ib*bd + dd);
 
-            llama_kv_cache_seq_shift(ctx, 0, ga_i, n_past, ib*bd);
-            llama_kv_cache_seq_div (ctx, 0, ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n);
-            llama_kv_cache_seq_shift(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd, dd);
+            llama_kv_cache_seq_add(ctx, 0, ga_i, n_past, ib*bd);
+            llama_kv_cache_seq_div(ctx, 0, ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n);
+            llama_kv_cache_seq_add(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd, dd);
 
             n_past -= bd;
 
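Note: llama_kv_cache_seq_shift is renamed to llama_kv_cache_seq_add here; the arguments are unchanged. A minimal sketch of the context-shift step after this change, assuming a valid llama_context * ctx; the wrapper function itself is illustrative, not part of the patch:

    #include "llama.h"

    // Drop n_discard tokens after the kept prefix, then slide the rest back.
    static void context_shift(struct llama_context * ctx, int n_keep, int n_discard, int & n_past) {
        llama_kv_cache_seq_rm (ctx, 0, n_keep,             n_keep + n_discard);
        llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_past, -n_discard);
        n_past -= n_discard;
    }
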
Diff for: examples/passkey/passkey.cpp

+15 -10

@@ -126,7 +126,7 @@ int main(int argc, char ** argv) {
     const int n_batch = ctx_params.n_batch;
     const int n_batch_grp = ctx_params.n_batch/n_grp;
 
-    LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_kv_req = %d, n_grp = %d, n_batch = %d\n", __func__, n_len, n_ctx, n_kv_req, n_grp, n_batch);
+    LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_kv_req = %d, n_grp = %d, n_batch = %d, n_junk = %d, i_pos = %d\n", __func__, n_len, n_ctx, n_kv_req, n_grp, n_batch, n_junk, i_pos);
 
     // print the prompt token-by-token
 
@@ -146,10 +146,11 @@ int main(int argc, char ** argv) {
             const int ib = i/n_batch - 1;
             const int bd = n_batch_grp*(n_grp - 1);
 
-            llama_kv_cache_seq_shift(ctx, 0, n_past - n_batch, n_past, ib*bd);
-            llama_kv_cache_seq_div (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp);
+            llama_kv_cache_seq_add (ctx, 0, n_past - n_batch, n_past, ib*bd);
+            llama_kv_cache_seq_div (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp);
+            llama_kv_cache_update (ctx);
 
-            n_past -= bd;
+            n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;
         }
 
         llama_batch_clear(batch);
@@ -179,10 +180,12 @@ int main(int argc, char ** argv) {
 
            LOG_TEE("%s: shifting KV cache with %d\n", __func__, n_discard);
 
-            llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
-            llama_kv_cache_seq_shift(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
+            llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
+            llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
+            llama_kv_cache_defrag (ctx);
+            llama_kv_cache_update (ctx);
 
-            n_past -= n_discard;
+            n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;
 
            llama_batch_clear(batch);
 
@@ -208,10 +211,12 @@ int main(int argc, char ** argv) {
        if (n_discard > 0) {
            LOG_TEE("%s: shifting KV cache with %d to free space for the answer\n", __func__, n_discard);
 
-            llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
-            llama_kv_cache_seq_shift(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
+            llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
+            llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
+            llama_kv_cache_defrag (ctx);
+            llama_kv_cache_update (ctx);
 
-            n_past -= n_discard;
+            n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;
        }
    }
 
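Note: besides the seq_shift to seq_add rename, the passkey example now defragments the cache, applies the pending operations with llama_kv_cache_update, and re-derives n_past from the cache state instead of tracking it by hand. A minimal sketch of that sequence, assuming a valid llama_context * ctx; the wrapper is illustrative only:

    #include "llama.h"

    // Discard a block of positions, compact the cache, and recompute n_past.
    static int shift_and_defrag(struct llama_context * ctx, int n_keep, int n_discard, int n_ctx) {
        llama_kv_cache_seq_rm (ctx, 0, n_keep,             n_keep + n_discard);
        llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
        llama_kv_cache_defrag (ctx);  // schedule defragmentation of the K/V cells
        llama_kv_cache_update (ctx);  // apply the pending shift/defrag operations
        return llama_kv_cache_seq_pos_max(ctx, 0) + 1;  // new n_past
    }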