
Commit a3c4265

Bindings: Add guards to generate and finish if prompt encoding fails

These guards and fallbacks should prevent segfaults during generation. In addition, convert the finish reason check to a switch/case for better readability.

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>

1 parent 647d15d · commit a3c4265
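
For readers skimming the diff, here is a minimal sketch of the guard pattern this commit applies, assuming (as the diff suggests) that Tokenize returns a std::optional that is empty when prompt encoding fails. The stub Tokenize and the main driver below are hypothetical stand-ins for illustration, not the binding's actual API:

    #include <cstdio>
    #include <cstdint>
    #include <optional>
    #include <string>
    #include <vector>

    using llama_token = std::int32_t;

    // Hypothetical stand-in for the binding's Tokenize(): assumed to yield
    // std::nullopt when the prompt cannot be encoded (here, when empty).
    static std::optional<std::vector<llama_token>> Tokenize(const std::string& prompt) {
        if (prompt.empty()) return std::nullopt;
        return std::vector<llama_token>{1, 2, 3}; // placeholder token ids
    }

    int main() {
        std::string finishReason;

        // The guard this commit adds: test the optional before dereferencing.
        // The old code called Tokenize(...).value() unconditionally, which
        // throws std::bad_optional_access on failure and, uncaught, takes the
        // whole process down instead of reporting a finish reason.
        const auto promptTokenResult = Tokenize("");
        if (!promptTokenResult) {
            finishReason = "TokenEncode";
            std::printf("finishReason: %s\n", finishReason.c_str());
            return 1;
        }

        const auto& promptTokens = promptTokenResult.value();
        std::printf("encoded %zu tokens\n", promptTokens.size());
        return 0;
    }

The "TokenEncode" finish reason set on this path is the same string the new TypeScript case matches on below.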

3 files changed: +43 −22 lines changed

bindings/binding.cpp (+25 −9)

@@ -592,7 +592,15 @@ const char* InferToReadbackBuffer(
     // Tokenize and determine the amount of tokens to generate
     // addSpecial - Special tokens in this case are BOS tokens
     // parseSpecial is always true since special tokens should be parsed
-    auto promptTokens = Tokenize(model, prompt, addSpecial, true).value();
+    auto promptTokenResult = Tokenize(model, prompt, addSpecial, true);
+    if (!promptTokenResult) {
+        finishReason = "TokenEncode";
+        readbackBufferPtr->jsonOutputBuffer = strdup(MakeJsonOutputString(context, finishReason, stoppedAt).c_str());
+        readbackBufferPtr->done = true;
+        return nullptr;
+    }
+
+    auto promptTokens = promptTokenResult.value();
 
     // Process the prompt in chunked batches
     if (!processPromptBatches(promptTokens)) {
@@ -602,8 +610,14 @@ const char* InferToReadbackBuffer(
     }
 
     MatchTrie::MatchTrie matchingTrie;
-    matchingTrie.AddMatchableWords(rewindStrings, numRewindStrings, MatchTrie::MatchType::REWIND);
-    matchingTrie.AddMatchableWords(stoppingStrings, numStoppingStrings, MatchTrie::MatchType::STOP);
+
+    if (rewindStrings != nullptr && numRewindStrings > 0) {
+        matchingTrie.AddMatchableWords(rewindStrings, numRewindStrings, MatchTrie::MatchType::REWIND);
+    }
+
+    if (stoppingStrings != nullptr && numStoppingStrings > 0) {
+        matchingTrie.AddMatchableWords(stoppingStrings, numStoppingStrings, MatchTrie::MatchType::STOP);
+    }
 
     std::string response;
     std::string buffer;
@@ -643,18 +657,18 @@ const char* InferToReadbackBuffer(
         // Abort if callback is fired
         if (isEnd) {
             finishReason = "StopToken";
-            stoppedAt = TokenToPiece(model, newTokenId, decodeSpecial).value();
+            stoppedAt = TokenToPiece(model, newTokenId, decodeSpecial).value_or("");
             break;
         }
 
         // End on length if max tokens is exceeded
         if (tokenCount + batch.n_tokens > numberTokensToPredict) {
             finishReason = "MaxNewTokens";
-            stoppedAt = TokenToPiece(model, newTokenId, decodeSpecial).value();
+            stoppedAt = TokenToPiece(model, newTokenId, decodeSpecial).value_or("");
             break;
         }
 
-        const std::string piece = TokenToPiece(model, newTokenId, decodeSpecial).value();
+        auto piece = TokenToPiece(model, newTokenId, decodeSpecial).value_or("");
 
         buffer += piece;
         tokenCount += batch.n_tokens;
@@ -685,15 +699,17 @@ const char* InferToReadbackBuffer(
             WriteToReadbackBuffer(readbackBufferPtr, strdup(partialBuffer.c_str()), newTokenId);
             response += buffer;
 
-            stoppedAt = TokenToPiece(model, newTokenId, decodeSpecial).value();
+            stoppedAt = TokenToPiece(model, newTokenId, decodeSpecial).value_or("");
             finishReason = "StopString";
             break;
         } else if (matchInfo.result == MatchTrie::MatchResult::MATCHED_REWIND) {
             llama_kv_cache_seq_rm(context, 0, rewindPos, -1);
 
             const auto tokens = Tokenize(model, buffer, false, false);
-            for (const llama_token token : tokens.value()) {
-                biases.push_back({token, -50000.0f});
+            if (tokens) {
+                for (const llama_token token : tokens.value()) {
+                    biases.push_back({token, -50000.0f});
+                }
             }
 
             if (banSampler == nullptr) {
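
The value_or("") fallbacks in the later hunks apply the same idea to decoding: when TokenToPiece cannot produce text for a token, an empty stoppedAt string is substituted so generation can still finish cleanly rather than dying on an unchecked .value(). Likewise, the rewind path now only builds token biases when re-tokenizing the buffer actually succeeded.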

bindings/bindings.ts (+17 −13)

@@ -763,19 +763,23 @@ export class Model {
 
         const finishResponse = await this.readbackBuffer.readJsonStatus();
         if (finishResponse) {
-            if (
-                finishResponse.finishReason == BindingFinishReason.CtxExceeded
-            ) {
-                throw new Error(
-                    `Prompt exceeds max context length of ${this.maxSeqLen}`,
-                );
-            } else if (
-                finishResponse.finishReason == BindingFinishReason.BatchDecode
-            ) {
-                throw new Error(
-                    "Internal generation state is broken due to llama_decode error. " +
-                        "Please restart the server.",
-                );
+            switch (finishResponse.finishReason) {
+                case BindingFinishReason.CtxExceeded:
+                    throw new Error(
+                        `Prompt exceeds max context length of ${this.maxSeqLen}`,
+                    );
+
+                case BindingFinishReason.BatchDecode:
+                    throw new Error(
+                        "Internal generation state is broken due to llama_decode error. " +
+                            "Please restart the server.",
+                    );
+
+                case BindingFinishReason.TokenEncode:
+                    throw new Error(
+                        "Could not tokenize the provided prompt. " +
+                            "Please make sure your prompt is formatted correctly.",
+                    );
             }
 
             const totalTime = finishResponse.promptSec + finishResponse.genSec;
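
Note that the switch only throws for failure reasons: ordinary terminal reasons such as StopToken, MaxNewTokens, and StopString (see the enum in bindings/types.ts below) match no case, so execution falls through to the timing summary that starts at the totalTime line. The new TokenEncode case is what surfaces the C++ guard above as a user-facing error.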

bindings/types.ts (+1 −0)

@@ -17,6 +17,7 @@ export enum BindingFinishReason {
     StopToken = "StopToken",
     MaxNewTokens = "MaxNewTokens",
     StopString = "StopString",
+    TokenEncode = "TokenEncode",
 }
 
 export type GenerationChunk = StreamChunk | FinishChunk;
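
Because BindingFinishReason is a string-valued enum, the new TokenEncode member matches the literal "TokenEncode" string the C++ side writes into the JSON status buffer, which is presumably why no extra mapping layer is needed between the two languages.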
