|
8 | 8 | #include <vector>
|
9 | 9 | #include <string>
|
10 | 10 | #include <sstream>
|
| 11 | +#include <unordered_set> |
11 | 12 |
|
12 | 13 | // Static vector to hold previous generation tokens
|
13 | 14 | // TODO: Remove in continuous batch implementation
|
@@ -749,7 +750,9 @@ const char* InferToReadbackBuffer(
|
749 | 750 | const char** rewindStrings,
|
750 | 751 | const unsigned numRewindStrings,
|
751 | 752 | const char** stoppingStrings,
|
752 |
| - const unsigned numStoppingStrings) |
| 753 | + const unsigned numStoppingStrings, |
| 754 | + const unsigned* stoppingTokens, |
| 755 | + const unsigned numStoppingTokens) |
753 | 756 | {
|
754 | 757 | if (abortCallback != nullptr) {
|
755 | 758 | llama_set_abort_callback(context, abortCallback, nullptr);
|
@@ -832,6 +835,13 @@ const char* InferToReadbackBuffer(
|
832 | 835 | return nullptr;
|
833 | 836 | }
|
834 | 837 |
|
| 838 | + // Create stop token set |
| 839 | + std::unordered_set<unsigned> stopTokenSet; |
| 840 | + if (stoppingTokens != nullptr && numStoppingTokens > 0) { |
| 841 | + stopTokenSet.insert(stoppingTokens, stoppingTokens + numStoppingTokens); |
| 842 | + } |
| 843 | + |
| 844 | + // Populate string ban trie |
835 | 845 | MatchTrie::MatchTrie matchingTrie;
|
836 | 846 |
|
837 | 847 | if (rewindStrings != nullptr && numRewindStrings > 0) {
|
@@ -885,7 +895,7 @@ const char* InferToReadbackBuffer(
|
885 | 895 |
|
886 | 896 | while (true) {
|
887 | 897 | // Abort if callback is fired
|
888 |
| - if (isEnd) { |
| 898 | + if (isEnd || (!stopTokenSet.empty() && stopTokenSet.find(newTokenId) != stopTokenSet.end())) { |
889 | 899 | finishReason = "StopToken";
|
890 | 900 | stoppedAt = TokenToPiece(model, newTokenId, decodeSpecial).value_or("");
|
891 | 901 | break;
|
|
0 commit comments