 #include <string>
 #include <sstream>
 
+// Static vector to hold previous generation tokens
+// TODO: Remove in continuous batch implementation
+std::vector<llama_token> prevTokens;
+
 void TestPrint(const char* text)
 {
     std::cout << text << std::endl;
@@ -159,11 +163,13 @@ void FreeSampler(llama_sampler* sampler)
 
 void FreeCtx(llama_context* ctx)
 {
+    prevTokens.clear();
     llama_free(ctx);
 }
 
 void ClearContextKVCache(llama_context* ctx)
 {
+    prevTokens.clear();
     llama_kv_cache_clear(ctx);
 }
 
@@ -451,11 +457,9 @@ std::optional<std::vector<llama_token>> Tokenize(
         nullptr, 0, addSpecial, parseSpecial);
     std::vector<llama_token> tokenizedPrompt(n_prompt);
 
-    bool add_bos = (llama_get_kv_cache_used_cells(context) == 0) & addSpecial;
-
     if (llama_tokenize(&llamaModel->vocab, prompt.data(), prompt.size(),
                        tokenizedPrompt.data(), tokenizedPrompt.size(),
-                       add_bos, parseSpecial) < 0) {
+                       addSpecial, parseSpecial) < 0) {
         std::cerr << "error: failed to tokenize the prompt in TokenizePrompt()" << std::endl;
         return std::nullopt;
     }
@@ -516,6 +520,15 @@ std::string MakeJsonOutputString(const llama_context* context, const std::string
     return ss.str();
 }
 
+// From llama.cpp/common/common.cpp
+// Returns the point where the prompt prefix starts to diverge
+size_t common_lcp(const std::vector<llama_token> &a, const std::vector<llama_token> &b) {
+    size_t i;
+    for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
+
+    return i;
+}
+
 const char* InferToReadbackBuffer(
     const llama_model* model,
     llama_sampler* sampler,
@@ -550,7 +563,11 @@ const char* InferToReadbackBuffer(
         return false;
     }
 
-    for (size_t i = 0; i < tokens.size(); i += batchSize) {
+    // Check when tokens diverge and remove everything after the common prefix
+    const size_t prefixEnd = common_lcp(tokens, prevTokens);
+    llama_kv_cache_seq_rm(context, 0, prefixEnd, -1);
+
+    for (size_t i = prefixEnd; i < tokens.size(); i += batchSize) {
         const size_t remaining = tokens.size() - i;
         const size_t currentBatchSize = std::min(remaining, static_cast<size_t>(batchSize));
 
@@ -564,6 +581,8 @@ const char* InferToReadbackBuffer(
             return false;
         }
     }
+
+    prevTokens = tokens;
     return true;
 };
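The hunks above implement simple prompt-prefix reuse: compare the new prompt's tokens to the previous request's tokens, keep the shared prefix in the KV cache, evict everything past the divergence point, and decode only the new suffix. A minimal, self-contained sketch of that bookkeeping, with plain ints standing in for llama_token and a print standing in for the llama_kv_cache_seq_rm call (names and values here are illustrative only, not part of the patch):

#include <cstddef>
#include <iostream>
#include <vector>

// Same idea as common_lcp above: length of the shared prefix of two token sequences.
static size_t lcp(const std::vector<int>& a, const std::vector<int>& b) {
    size_t i = 0;
    while (i < a.size() && i < b.size() && a[i] == b[i]) { ++i; }
    return i;
}

int main() {
    const std::vector<int> prev = {1, 5, 7, 9, 2};    // tokens cached from the last request
    const std::vector<int> next = {1, 5, 7, 4, 8, 6}; // tokens of the new prompt

    const size_t prefixEnd = lcp(next, prev);
    // In the patch this is where llama_kv_cache_seq_rm(context, 0, prefixEnd, -1)
    // evicts the stale tail, and the decode loop starts at i = prefixEnd.
    std::cout << "reuse " << prefixEnd << " cached tokens, decode "
              << (next.size() - prefixEnd) << " new ones\n";
    // Prints: reuse 3 cached tokens, decode 3 new ones
    return 0;
}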