@@ -926,65 +926,75 @@ const char* InferToReadbackBuffer(
         if (!buffer.empty()) {
             const MatchTrie::MatchInfo matchInfo = matchingTrie.CheckBuffer(buffer);
 
-            if (matchInfo.result == MatchTrie::MatchResult::NO) {
-                WriteToReadbackBuffer(readbackBufferPtr, strdup(buffer.c_str()), newTokenId);
-                response += buffer;
-                buffer = "";
-
-                // Save last known accept point in case we have to rewind back to the last accept.
-                rewindPos = llama_get_kv_cache_used_cells(context);
-                rewindTokenId = newTokenId;
-                rewindTokenCount = tokenCount;
-
-                // If we had a rewind state built, tear it down as we've accepted a sequence.
-                if (banSampler != nullptr) {
-                    llama_sampler_free(banSampler);
-                    banSampler = nullptr;
-                    biases.clear();
-                }
-            } else if (matchInfo.result == MatchTrie::MatchResult::MATCHED_STOP) {
-                // Matched a stop, return the partial substring and break
-                std::string partialBuffer = buffer.substr(0, matchInfo.matchPos);
-
-                WriteToReadbackBuffer(readbackBufferPtr, strdup(partialBuffer.c_str()), newTokenId);
-                response += partialBuffer;
-
-                stoppedAt = TokenToPiece(model, newTokenId, decodeSpecial).value_or("");
-                finishReason = "StopString";
-                break;
-            } else if (matchInfo.result == MatchTrie::MatchResult::MATCHED_REWIND) {
-                llama_kv_cache_seq_rm(context, 0, rewindPos, -1);
-
-                // Reset the detokenizer too when rewinding
-                if (readbackBufferPtr->detokenizer) {
-                    readbackBufferPtr->detokenizer->reset();
-                }
-
-                const auto tokens = Tokenize(model, buffer, false, false);
-                if (tokens) {
-                    for (const llama_token token : tokens.value()) {
-                        biases.push_back({token, -50000.0f});
+            switch (matchInfo.result) {
+                case MatchTrie::MatchResult::NO: {
+                    WriteToReadbackBuffer(readbackBufferPtr, strdup(buffer.c_str()), newTokenId);
+                    response += buffer;
+                    buffer = "";
+
+                    // Save last known accept point in case we have to rewind back to the last accept.
+                    rewindPos = llama_get_kv_cache_used_cells(context);
+                    rewindTokenId = newTokenId;
+                    rewindTokenCount = tokenCount;
+
+                    // If we had a rewind state built, tear it down as we've accepted a sequence.
+                    if (banSampler != nullptr) {
+                        llama_sampler_free(banSampler);
+                        banSampler = nullptr;
+                        biases.clear();
                     }
+                    break;
                 }
+                case MatchTrie::MatchResult::MATCHED_STOP: {
+                    // Matched a stop, return the partial substring and break
+                    std::string partialBuffer = buffer.substr(0, matchInfo.matchPos);
 
-                if (banSampler == nullptr) {
-                    banSampler = MakeSampler();
-                    LogitBiasSampler(banSampler, model, static_cast<int32_t>(biases.size()), biases.data());
-                    DistSampler(banSampler, seed);
-                } else {
-                    llama_sampler_chain_remove(banSampler, 1);
-                    llama_sampler_chain_remove(banSampler, 0);
-                    LogitBiasSampler(banSampler, model, static_cast<int32_t>(biases.size()), biases.data());
-                    DistSampler(banSampler, seed);
-                }
+                    WriteToReadbackBuffer(readbackBufferPtr, strdup(partialBuffer.c_str()), newTokenId);
+                    response += partialBuffer;
 
-                buffer = "";
-                newTokenId = rewindTokenId;
+                    stoppedAt = TokenToPiece(model, newTokenId, decodeSpecial).value_or("");
+                    finishReason = "StopString";
+
+                    // Signals a break
+                    isEnd = true;
+                    continue;
+                }
+                case MatchTrie::MatchResult::MATCHED_REWIND: {
+                    llama_kv_cache_seq_rm(context, 0, rewindPos, -1);
 
-                batch = llama_batch_get_one(&newTokenId, 1);
-                std::tie(newTokenId, isEnd) = gen(batch, banSampler);
-                tokenCount = rewindTokenCount;
-                continue;
+                    // Reset the detokenizer too when rewinding
+                    if (readbackBufferPtr->detokenizer) {
+                        readbackBufferPtr->detokenizer->reset();
+                    }
+
+                    const auto tokens = Tokenize(model, buffer, false, false);
+                    if (tokens) {
+                        for (const llama_token token : tokens.value()) {
+                            biases.push_back({token, -50000.0f});
+                        }
+                    }
+
+                    if (banSampler == nullptr) {
+                        banSampler = MakeSampler();
+                        LogitBiasSampler(banSampler, model, static_cast<int32_t>(biases.size()), biases.data());
+                        DistSampler(banSampler, seed);
+                    } else {
+                        llama_sampler_chain_remove(banSampler, 1);
+                        llama_sampler_chain_remove(banSampler, 0);
+                        LogitBiasSampler(banSampler, model, static_cast<int32_t>(biases.size()), biases.data());
+                        DistSampler(banSampler, seed);
+                    }
+
+                    buffer = "";
+                    newTokenId = rewindTokenId;
+
+                    batch = llama_batch_get_one(&newTokenId, 1);
+                    std::tie(newTokenId, isEnd) = gen(batch, banSampler);
+                    tokenCount = rewindTokenCount;
+                    continue;
+                }
+                case MatchTrie::MatchResult::MAYBE:
+                    break;
             }
         }
 
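One subtlety in this conversion is worth noting: inside a `switch`, `break` only exits the switch, not the enclosing generation loop, so the `MATCHED_STOP` case now signals loop exit via `isEnd = true; continue;` where the old else-if chain could `break;` directly. The `MAYBE` case, which the old chain handled implicitly by matching no branch (keep buffering), is now spelled out. A condensed sketch of the resulting four-way decision, using hypothetical stand-ins (`MatchResult`, `Action`, `HandleMatch`) for the loop state the real function threads through:

```cpp
// Hypothetical mirror of MatchTrie::MatchResult as handled in the hunk above.
enum class MatchResult { NO, MATCHED_STOP, MATCHED_REWIND, MAYBE };

// What the generation loop does with the buffered text for each result.
enum class Action { FlushAndAccept, Stop, RewindAndBan, KeepBuffering };

// Condensed decision logic: the real code additionally saves/clears the
// rewind state, rebuilds the ban sampler, and writes to the readback buffer.
Action HandleMatch(MatchResult result) {
    switch (result) {
        case MatchResult::NO:             return Action::FlushAndAccept; // emit buffer, record accept point
        case MatchResult::MATCHED_STOP:   return Action::Stop;           // emit text up to matchPos, finish
        case MatchResult::MATCHED_REWIND: return Action::RewindAndBan;   // roll back KV cache, bias tokens away
        case MatchResult::MAYBE:          return Action::KeepBuffering;  // could still become a stop string
    }
    return Action::KeepBuffering; // unreachable with a well-formed enum value
}
```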