@@ -16,7 +16,7 @@ static void inference(AppInferenceContext *context) {
     std::vector<int> inputTokensVec(std::strlen(context->args->prompt) + 3);
     int *inputTokens = inputTokensVec.data();
 
-    NnSize pos = 0;
+    NnUint pos = 0;
     int token;
     int nInputTokens;
     context->tokenizer->encode(context->args->prompt, inputTokens, &nInputTokens, true, false);
@@ -27,21 +27,21 @@ static void inference(AppInferenceContext *context) {
         throw std::runtime_error("The number of prompt tokens is greater than the number of steps");
 
     Timer evalTimer;
-    size_t sentBytes = 0;
-    size_t recvBytes = 0;
+    NnSize sentBytes = 0;
+    NnSize recvBytes = 0;
     printf("%s\n", context->args->prompt);
     for (;;) {
         Timer batchTimer;
         long remainingTokens = nInputTokens - 1 - (long)pos;
         if (remainingTokens <= 0)
             break;
-        NnSize batchSize = remainingTokens < context->args->nBatches
+        NnUint batchSize = remainingTokens < context->args->nBatches
             ? remainingTokens
             : context->args->nBatches;
 
         context->inference->setBatchSize(batchSize);
         context->inference->setPosition(pos);
-        for (NnSize i = 0; i < batchSize; i++)
+        for (NnUint i = 0; i < batchSize; i++)
             context->inference->setToken(i, inputTokens[pos + i]);
 
         context->inference->forward();
@@ -57,15 +57,15 @@ static void inference(AppInferenceContext *context) {
             recvBytes / 1024,
             batchSize);
     }
-    NnSize evalTime = evalTimer.elapsedMiliseconds();
+    NnUint evalTime = evalTimer.elapsedMiliseconds();
 
     fflush(stdout);
 
     context->inference->setBatchSize(1);
     context->tokenizer->resetDecoder();
 
     Timer predTimer;
-    const NnSize maxPos = std::min(context->header->seqLen, context->args->steps);
+    const NnUint maxPos = std::min(context->header->seqLen, context->args->steps);
     for (; pos < maxPos; pos++) {
         Timer tokenTimer;
         context->inference->setPosition(pos);
@@ -86,10 +86,10 @@ static void inference(AppInferenceContext *context) {
             piece == nullptr ? "~" : piece);
         fflush(stdout);
     }
-    NnSize predTime = predTimer.elapsedMiliseconds();
+    NnUint predTime = predTimer.elapsedMiliseconds();
 
-    NnSize nEvalTokens = nInputTokens - 1;
-    NnSize nPredTokens = pos - nEvalTokens;
+    NnUint nEvalTokens = nInputTokens - 1;
+    NnUint nPredTokens = pos - nEvalTokens;
     printf("\n");
     printf("Evaluation\n");
     printf("nBatches: %d\n", context->args->nBatches);
@@ -104,11 +104,11 @@ static void inference(AppInferenceContext *context) {
         predTime / ((float) nPredTokens));
 }
 
-static size_t readStdin(const char *guide, char *buffer, size_t size) {
+static NnUint readStdin(const char *guide, char *buffer, NnUint size) {
     std::fflush(stdin);
     std::printf("%s", guide);
     if (std::fgets(buffer, size, stdin) != NULL) {
-        size_t length = std::strlen(buffer);
+        NnUint length = std::strlen(buffer);
         if (length > 0 && buffer[length - 1] == '\n') {
             buffer[length - 1] = '\0';
             length--;
@@ -119,20 +119,20 @@ static size_t readStdin(const char *guide, char *buffer, size_t size) {
 }
 
 static void chat(AppInferenceContext *context) {
-    const NnSize seqLen = context->header->seqLen;
+    const NnUint seqLen = context->header->seqLen;
     char prompt[2048];
 
     TokenizerChatStops stops(context->tokenizer);
     ChatTemplateGenerator templateGenerator(context->args->chatTemplateType, context->tokenizer->chatTemplate, stops.stops[0]);
     EosDetector eosDetector(stops.nStops, context->tokenizer->eosTokenIds.data(), stops.stops, stops.maxStopLength, stops.maxStopLength);
 
-    const size_t sysPromptLength = readStdin("💻 System prompt (optional): ", prompt, sizeof(prompt));
+    const NnUint sysPromptLength = readStdin("💻 System prompt (optional): ", prompt, sizeof(prompt));
     std::vector<ChatItem> deltaItems;
     if (sysPromptLength > 0)
         deltaItems.push_back(ChatItem{"system", prompt});
 
-    NnSize pos = 0;
-    size_t userPromptLength;
+    NnUint pos = 0;
+    NnUint userPromptLength;
     int token;
     int nInputTokens;
     do {
@@ -149,18 +149,18 @@ static void chat(AppInferenceContext *context) {
         bool addBos = pos == 0;
         context->tokenizer->encode((char *)inputPrompt.content, inputTokens, &nInputTokens, addBos, true);
 
-        NnSize userPromptEndPos = (NnSize)std::min<unsigned int>(seqLen, pos + nInputTokens - 1);
-        for (NnSize i = 0; ;) {
+        NnUint userPromptEndPos = (NnUint)std::min<unsigned int>(seqLen, pos + nInputTokens - 1);
+        for (NnUint i = 0; ;) {
             int remainingTokens = userPromptEndPos - pos;
             if (remainingTokens <= 0)
                 break;
-            NnSize batchSize = remainingTokens < context->args->nBatches
+            NnUint batchSize = remainingTokens < context->args->nBatches
                 ? remainingTokens
                 : context->args->nBatches;
 
             context->inference->setBatchSize(batchSize);
             context->inference->setPosition(pos);
-            for (NnSize j = 0; j < batchSize; j++)
+            for (NnUint j = 0; j < batchSize; j++)
                 context->inference->setToken(j, inputTokens[i + j]);
 
             context->inference->forward();
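
For reference, a minimal standalone sketch of how the two aliases appear to be split after this change: NnUint for positions, token counts, and batch sizes, and NnSize for byte counters such as sentBytes/recvBytes. The typedefs below are assumptions made for illustration (the real definitions live elsewhere in the repository and may differ); the loop only mirrors the batching pattern visible in the hunks above.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstdio>

// Assumed aliases; only the names come from the diff, not the definitions.
typedef std::uint32_t NnUint; // positions, token counts, batch sizes
typedef std::size_t NnSize;   // byte counters such as sentBytes/recvBytes

int main() {
    const NnUint nInputTokens = 5;
    const NnUint nBatches = 2;
    NnSize sentBytes = 0;

    // Walk the prompt tokens in batches, as the inference loop above does.
    for (NnUint pos = 0; pos + 1 < nInputTokens;) {
        NnUint remaining = nInputTokens - 1 - pos;
        NnUint batchSize = remaining < nBatches ? remaining : nBatches;
        sentBytes += (NnSize)batchSize * 4; // pretend each token costs 4 bytes on the wire
        pos += batchSize;
        std::printf("batch=%u sent=%zu B\n", (unsigned)batchSize, sentBytes);
    }
    return 0;
}
```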