Skip to content

Commit f7eb0a8

Browse files
committed
Bindings: Map banning into a switch
Switches can compile to jump tables, which may make them faster than long if/else chains (compilers often optimize both forms equivalently, but the switch expresses the dispatch intent directly). The code also reads much more cleanly this way. Use inner scopes to encapsulate case-local variables. Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
1 parent 83bfb70 commit f7eb0a8

File tree

1 file changed

+64
-54
lines changed

1 file changed

+64
-54
lines changed

bindings/binding.cpp

+64-54
Original file line number | Diff line number | Diff line change
@@ -926,65 +926,75 @@ const char* InferToReadbackBuffer(
926926
if (!buffer.empty()) {
927927
const MatchTrie::MatchInfo matchInfo = matchingTrie.CheckBuffer(buffer);
928928

929-
if (matchInfo.result == MatchTrie::MatchResult::NO) {
930-
WriteToReadbackBuffer(readbackBufferPtr, strdup(buffer.c_str()), newTokenId);
931-
response += buffer;
932-
buffer = "";
933-
934-
// Save last known accept point in case we have to rewind back to the last accept.
935-
rewindPos = llama_get_kv_cache_used_cells(context);
936-
rewindTokenId = newTokenId;
937-
rewindTokenCount = tokenCount;
938-
939-
// If we had a rewind state built, tear it down as we've accepted a sequence.
940-
if (banSampler != nullptr) {
941-
llama_sampler_free(banSampler);
942-
banSampler = nullptr;
943-
biases.clear();
944-
}
945-
} else if (matchInfo.result == MatchTrie::MatchResult::MATCHED_STOP) {
946-
// Matched a stop, return the partial substring and break
947-
std::string partialBuffer = buffer.substr(0, matchInfo.matchPos);
948-
949-
WriteToReadbackBuffer(readbackBufferPtr, strdup(partialBuffer.c_str()), newTokenId);
950-
response += partialBuffer;
951-
952-
stoppedAt = TokenToPiece(model, newTokenId, decodeSpecial).value_or("");
953-
finishReason = "StopString";
954-
break;
955-
} else if (matchInfo.result == MatchTrie::MatchResult::MATCHED_REWIND) {
956-
llama_kv_cache_seq_rm(context, 0, rewindPos, -1);
957-
958-
// Reset the detokenizer too when rewinding
959-
if (readbackBufferPtr->detokenizer) {
960-
readbackBufferPtr->detokenizer->reset();
961-
}
962-
963-
const auto tokens = Tokenize(model, buffer, false, false);
964-
if (tokens) {
965-
for (const llama_token token : tokens.value()) {
966-
biases.push_back({token, -50000.0f});
929+
switch (matchInfo.result) {
930+
case MatchTrie::MatchResult::NO: {
931+
WriteToReadbackBuffer(readbackBufferPtr, strdup(buffer.c_str()), newTokenId);
932+
response += buffer;
933+
buffer = "";
934+
935+
// Save last known accept point in case we have to rewind back to the last accept.
936+
rewindPos = llama_get_kv_cache_used_cells(context);
937+
rewindTokenId = newTokenId;
938+
rewindTokenCount = tokenCount;
939+
940+
// If we had a rewind state built, tear it down as we've accepted a sequence.
941+
if (banSampler != nullptr) {
942+
llama_sampler_free(banSampler);
943+
banSampler = nullptr;
944+
biases.clear();
967945
}
946+
break;
968947
}
948+
case MatchTrie::MatchResult::MATCHED_STOP: {
949+
// Matched a stop, return the partial substring and break
950+
std::string partialBuffer = buffer.substr(0, matchInfo.matchPos);
969951

970-
if (banSampler == nullptr) {
971-
banSampler = MakeSampler();
972-
LogitBiasSampler(banSampler, model, static_cast<int32_t>(biases.size()), biases.data());
973-
DistSampler(banSampler, seed);
974-
} else {
975-
llama_sampler_chain_remove(banSampler, 1);
976-
llama_sampler_chain_remove(banSampler, 0);
977-
LogitBiasSampler(banSampler, model, static_cast<int32_t>(biases.size()), biases.data());
978-
DistSampler(banSampler, seed);
979-
}
952+
WriteToReadbackBuffer(readbackBufferPtr, strdup(partialBuffer.c_str()), newTokenId);
953+
response += partialBuffer;
980954

981-
buffer = "";
982-
newTokenId = rewindTokenId;
955+
stoppedAt = TokenToPiece(model, newTokenId, decodeSpecial).value_or("");
956+
finishReason = "StopString";
957+
958+
// Signals a break
959+
isEnd = true;
960+
continue;
961+
}
962+
case MatchTrie::MatchResult::MATCHED_REWIND: {
963+
llama_kv_cache_seq_rm(context, 0, rewindPos, -1);
983964

984-
batch = llama_batch_get_one(&newTokenId, 1);
985-
std::tie(newTokenId, isEnd) = gen(batch, banSampler);
986-
tokenCount = rewindTokenCount;
987-
continue;
965+
// Reset the detokenizer too when rewinding
966+
if (readbackBufferPtr->detokenizer) {
967+
readbackBufferPtr->detokenizer->reset();
968+
}
969+
970+
const auto tokens = Tokenize(model, buffer, false, false);
971+
if (tokens) {
972+
for (const llama_token token : tokens.value()) {
973+
biases.push_back({token, -50000.0f});
974+
}
975+
}
976+
977+
if (banSampler == nullptr) {
978+
banSampler = MakeSampler();
979+
LogitBiasSampler(banSampler, model, static_cast<int32_t>(biases.size()), biases.data());
980+
DistSampler(banSampler, seed);
981+
} else {
982+
llama_sampler_chain_remove(banSampler, 1);
983+
llama_sampler_chain_remove(banSampler, 0);
984+
LogitBiasSampler(banSampler, model, static_cast<int32_t>(biases.size()), biases.data());
985+
DistSampler(banSampler, seed);
986+
}
987+
988+
buffer = "";
989+
newTokenId = rewindTokenId;
990+
991+
batch = llama_batch_get_one(&newTokenId, 1);
992+
std::tie(newTokenId, isEnd) = gen(batch, banSampler);
993+
tokenCount = rewindTokenCount;
994+
continue;
995+
}
996+
case MatchTrie::MatchResult::MAYBE:
997+
break;
988998
}
989999
}
9901000

0 commit comments

Comments (0)