
Commit 7447231

Bindings: Add prefix caching
By preserving a common prompt prefix, processing only needs to happen from the point where the new prompt diverges.

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
1 parent 63a8c94
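
The idea, as a minimal standalone sketch (the Token alias and the sample sequences below are made up for illustration; the real bindings work on llama_token vectors): re-processing starts at the first position where the new prompt stops matching the previous one, so only the diverging tail is decoded again.

#include <cstddef>
#include <iostream>
#include <vector>

using Token = int; // illustrative stand-in for llama_token

// Length of the prefix shared by the previous prompt and the new one.
static size_t commonPrefixLength(const std::vector<Token>& a, const std::vector<Token>& b) {
    size_t i = 0;
    while (i < a.size() && i < b.size() && a[i] == b[i]) {
        ++i;
    }
    return i;
}

int main() {
    const std::vector<Token> previous = {1, 2, 3, 4, 5};      // already decoded into the KV cache
    const std::vector<Token> current  = {1, 2, 3, 9, 10, 11}; // new request

    const size_t prefixEnd = commonPrefixLength(previous, current);
    std::cout << "reused from cache: " << prefixEnd << " tokens\n";                   // 3
    std::cout << "needs decoding:    " << current.size() - prefixEnd << " tokens\n";  // 3
    return 0;
}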

File tree

2 files changed: +23 -7 lines changed


bindings/binding.cpp (+23 -4)

@@ -9,6 +9,10 @@
 #include <string>
 #include <sstream>
 
+// Static vector to hold previous generation tokens
+// TODO: Remove in continuous batch implementation
+std::vector<llama_token> prevTokens;
+
 void TestPrint(const char* text)
 {
     std::cout << text << std::endl;
@@ -159,11 +163,13 @@ void FreeSampler(llama_sampler* sampler)
 
 void FreeCtx(llama_context* ctx)
 {
+    prevTokens.empty();
     llama_free(ctx);
 }
 
 void ClearContextKVCache(llama_context* ctx)
 {
+    prevTokens.empty();
     llama_kv_cache_clear(ctx);
 }
 
@@ -451,11 +457,9 @@ std::optional<std::vector<llama_token>> Tokenize(
         nullptr, 0, addSpecial, parseSpecial);
     std::vector<llama_token> tokenizedPrompt(n_prompt);
 
-    bool add_bos = (llama_get_kv_cache_used_cells(context) == 0) & addSpecial;
-
     if (llama_tokenize(&llamaModel->vocab, prompt.data(), prompt.size(),
                        tokenizedPrompt.data(), tokenizedPrompt.size(),
-                       add_bos, parseSpecial) < 0) {
+                       addSpecial, parseSpecial) < 0) {
         std::cerr << "error: failed to tokenize the prompt in TokenizePrompt()" << std::endl;
         return std::nullopt;
     }
@@ -516,6 +520,15 @@ std::string MakeJsonOutputString(const llama_context* context, const std::string
     return ss.str();
 }
 
+// From llama.cpp/common/common.cpp
+// Returns the point when prompt prefix starts to diverge
+size_t common_lcp(const std::vector<llama_token> &a, const std::vector<llama_token> &b) {
+    size_t i;
+    for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
+
+    return i;
+}
+
 const char* InferToReadbackBuffer(
     const llama_model* model,
     llama_sampler* sampler,
@@ -550,7 +563,11 @@ const char* InferToReadbackBuffer(
             return false;
         }
 
-        for (size_t i = 0; i < tokens.size(); i += batchSize) {
+        // Check when tokens diverge and remove everything after the common prefix
+        const size_t prefixEnd = common_lcp(tokens, prevTokens);
+        llama_kv_cache_seq_rm(context, 0, prefixEnd, -1);
+
+        for (size_t i = prefixEnd; i < tokens.size(); i += batchSize) {
             const size_t remaining = tokens.size() - i;
             const size_t currentBatchSize = std::min(remaining, static_cast<size_t>(batchSize));
 
@@ -564,6 +581,8 @@ const char* InferToReadbackBuffer(
                 return false;
             }
         }
+
+        prevTokens = tokens;
         return true;
     };
 
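Putting the hunks above together, this is roughly the per-request flow, sketched as a standalone program with the llama.cpp calls (llama_kv_cache_seq_rm, llama_decode) replaced by prints so it runs on its own; runRequest and decode-related names are placeholders for illustration, not bindings APIs:

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

using Token = int; // illustrative stand-in for llama_token

// Mirrors the static in binding.cpp: the prompt tokens from the previous request.
static std::vector<Token> prevTokens;

static size_t commonPrefixLength(const std::vector<Token>& a, const std::vector<Token>& b) {
    size_t i = 0;
    for (; i < a.size() && i < b.size() && a[i] == b[i]; ++i) {}
    return i;
}

// One generation request: trim the cache to the shared prefix, then decode only the tail.
static void runRequest(const std::vector<Token>& tokens, size_t batchSize) {
    const size_t prefixEnd = commonPrefixLength(tokens, prevTokens);

    // Stands in for llama_kv_cache_seq_rm(context, 0, prefixEnd, -1):
    // drop cached positions from prefixEnd to the end of sequence 0.
    std::cout << "evict KV cache from position " << prefixEnd << "\n";

    for (size_t i = prefixEnd; i < tokens.size(); i += batchSize) {
        const size_t currentBatchSize = std::min(tokens.size() - i, batchSize);
        // Stands in for building a batch and calling llama_decode on it.
        std::cout << "decode " << currentBatchSize << " tokens starting at " << i << "\n";
    }

    prevTokens = tokens; // remembered for the next request
}

int main() {
    runRequest({1, 2, 3, 4, 5}, 2);    // first request: nothing cached, decode everything
    runRequest({1, 2, 3, 7, 8, 9}, 2); // second request: the first three tokens are reused
    return 0;
}

In llama.cpp, a negative upper bound passed to llama_kv_cache_seq_rm means "to the end of the sequence", so the real call evicts everything after the shared prefix before decoding resumes. prevTokens is a single static vector holding only the last request, which is why the accompanying TODO flags it for removal once a continuous batching implementation lands.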
bindings/bindings.ts (-3)

@@ -619,9 +619,6 @@ export class Model {
         // Acquire the mutex
         using _lock = await this.generationLock.acquire();
 
-        // Clear generation cache
-        this.resetKVCache();
-
         const samplerBuilder = new SamplerBuilder(this.model);
         const seed = params.seed && params.seed > 0
             ? params.seed
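Since the native side now trims the KV cache to the shared prefix on every request (the llama_kv_cache_seq_rm call added above), the TypeScript wrapper no longer clears the entire cache before each generation, which is what the removed resetKVCache() call did.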