Skip to content

Commit a8c7de8

Browse files
committed
Sampling: Add stop tokens
Stop tokens were overlooked in the initial release. Stop tokens can now be provided as integer token IDs, alongside stop strings. Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
1 parent 7392310 commit a8c7de8

File tree

5 files changed

+32
-7
lines changed

5 files changed

+32
-7
lines changed

bindings/binding.cpp

+12-2
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@
88
#include <vector>
99
#include <string>
1010
#include <sstream>
11+
#include <unordered_set>
1112

1213
// Static vector to hold previous generation tokens
1314
// TODO: Remove in continuous batch implementation
@@ -749,7 +750,9 @@ const char* InferToReadbackBuffer(
749750
const char** rewindStrings,
750751
const unsigned numRewindStrings,
751752
const char** stoppingStrings,
752-
const unsigned numStoppingStrings)
753+
const unsigned numStoppingStrings,
754+
const unsigned* stoppingTokens,
755+
const unsigned numStoppingTokens)
753756
{
754757
if (abortCallback != nullptr) {
755758
llama_set_abort_callback(context, abortCallback, nullptr);
@@ -832,6 +835,13 @@ const char* InferToReadbackBuffer(
832835
return nullptr;
833836
}
834837

838+
// Create stop token set
839+
std::unordered_set<unsigned> stopTokenSet;
840+
if (stoppingTokens != nullptr && numStoppingTokens > 0) {
841+
stopTokenSet.insert(stoppingTokens, stoppingTokens + numStoppingTokens);
842+
}
843+
844+
// Populate string ban trie
835845
MatchTrie::MatchTrie matchingTrie;
836846

837847
if (rewindStrings != nullptr && numRewindStrings > 0) {
@@ -885,7 +895,7 @@ const char* InferToReadbackBuffer(
885895

886896
while (true) {
887897
// Abort if callback is fired
888-
if (isEnd) {
898+
if (isEnd || (!stopTokenSet.empty() && stopTokenSet.find(newTokenId) != stopTokenSet.end())) {
889899
finishReason = "StopToken";
890900
stoppedAt = TokenToPiece(model, newTokenId, decodeSpecial).value_or("");
891901
break;

bindings/binding.h

+3-1
Original file line number | Diff line number | Diff line change
@@ -110,7 +110,9 @@ extern "C" {
110110
const char** rewindStrings,
111111
unsigned numRewindStrings,
112112
const char** stoppingStrings,
113-
unsigned numStoppingStrings);
113+
unsigned numStoppingStrings,
114+
const unsigned* stoppingTokens,
115+
unsigned numStoppingTokens);
114116

115117
#ifdef __cplusplus
116118
}

bindings/bindings.ts

+14-3
Original file line number | Diff line number | Diff line change
@@ -719,9 +719,18 @@ export class Model {
719719

720720
const promptPtr = new TextEncoder().encode(prompt + "\0");
721721

722+
// These are numbers and strings, TS doesn't understand for some reason
723+
const stopTokens = params.stop.filter((e) =>
724+
typeof e === "number"
725+
) as number[];
726+
const stopStrings = params.stop.filter((e) =>
727+
typeof e === "string"
728+
) as string[];
729+
722730
// Use the helper function for both arrays
723731
const rewindPtrArray = pointerArrayFromStrings(params.banned_strings);
724-
const stopPtrArray = pointerArrayFromStrings(params.stop);
732+
const stopTokensPtr = new Uint32Array(stopTokens);
733+
const stopStringsPtr = pointerArrayFromStrings(stopStrings);
725734

726735
const promptBosToken = params.add_bos_token
727736
? this.tokenizer.bosToken?.piece
@@ -754,8 +763,10 @@ export class Model {
754763
seed,
755764
rewindPtrArray.inner,
756765
params.banned_strings.length,
757-
stopPtrArray.inner,
758-
params.stop.length,
766+
stopStringsPtr.inner,
767+
stopStrings.length,
768+
stopTokensPtr,
769+
stopTokens.length,
759770
);
760771

761772
// Read from the read buffer

bindings/symbols.ts

+2
Original file line number | Diff line number | Diff line change
@@ -196,6 +196,8 @@ export default {
196196
"u32", // count of rewind strings
197197
"buffer", // const char** stoppingStrings
198198
"u32", // count of stop strings
199+
"buffer", // const unsigned* stoppingTokens
200+
"u32", // count of stop tokens
199201
],
200202
result: "pointer" as const,
201203
nonblocking: true,

common/sampling.ts

+1-1
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,7 @@ const GenerationOptionsSchema = z.aliasedObject(
88
}),
99
stop: z.union([
1010
z.string().transform((str) => [str]),
11-
z.array(z.string()),
11+
z.array(z.union([z.string(), z.number()])),
1212
])
1313
.nullish().coalesce([])
1414
.openapi({

0 commit comments

Comments
 (0)