
Commit f8113c1: fix: nnuint. (#174)
Authored Feb 18, 2025 · 1 parent 24156d8

22 files changed: +622 −592 lines
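In substance, the commit retires NnSize as the general-purpose count/index type and introduces NnUint in its place; NnSize is kept for byte counts (note sentBytes/recvBytes in src/dllama.cpp switching from size_t to NnSize). A minimal sketch of the type split this implies follows; the actual typedefs live in the nn headers, which are outside this diff, so the exact underlying types are an assumption:

// Sketch of the type split implied by this commit. The underlying
// types are assumed; the real typedefs are not part of this diff.
#include <cstddef>

typedef unsigned int NnUint;  // counts, indices, positions, ports
typedef std::size_t  NnSize;  // byte sizes, e.g. sentBytes/recvBytes below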
 

src/app.cpp (+8 −8)
@@ -70,7 +70,7 @@ AppCliArgs AppCliArgs::parse(int argc, char **argv, bool requireMode) {
 
             args.nWorkers = count;
             args.workerHosts = new char*[count];
-            args.workerPorts = new NnSize[count];
+            args.workerPorts = new NnUint[count];
 
             for (int s = 0; s < count; s++) {
                 char *v = argv[i + 1 + s];
@@ -111,7 +111,7 @@ AppCliArgs AppCliArgs::parse(int argc, char **argv, bool requireMode) {
 
 AppCliArgs::~AppCliArgs() {
     if (workerHosts != nullptr) {
-        for (NnSize i = 0; i < nWorkers; i++)
+        for (NnUint i = 0; i < nWorkers; i++)
             delete[] workerHosts[i];
         delete[] workerHosts;
     }
@@ -130,21 +130,21 @@ RootLlmInference::RootLlmInference(LlmNet *net, NnDevice *device, NnNetExecution *execution, NnExecutor *executor, NnNetwork *network) {
     this->network = network; // May be nullptr!
 }
 
-void RootLlmInference::setBatchSize(NnSize batchSize) {
+void RootLlmInference::setBatchSize(NnUint batchSize) {
     execution->setBatchSize(batchSize);
     controlPacket.batchSize = batchSize;
 }
 
-void RootLlmInference::setPosition(NnSize position) {
+void RootLlmInference::setPosition(NnUint position) {
     assert(position >= 0);
     assert(position + execution->batchSize - 1 < header->seqLen);
 
     controlPacket.position = position;
-    for (NnSize i = 0; i < execution->batchSize; i++)
+    for (NnUint i = 0; i < execution->batchSize; i++)
         positionPipe[i] = (float)(position + i);
 }
 
-void RootLlmInference::setToken(NnSize batchIndex, NnSize token) {
+void RootLlmInference::setToken(NnUint batchIndex, NnUint token) {
     assert(batchIndex >= 0 && batchIndex < execution->batchSize);
     tokenPipe[batchIndex] = (float)token;
 }
@@ -179,14 +179,14 @@ bool WorkerLlmInference::tryReadControlPacket() {
         isFinished = true;
         return true;
     }
-    for (NnSize i = 0; i < controlPacket.batchSize; i++)
+    for (NnUint i = 0; i < controlPacket.batchSize; i++)
        positionPipe[i] = (float)(controlPacket.position + i);
     execution->setBatchSize(controlPacket.batchSize);
     return true;
 }
 
 void runInferenceApp(AppCliArgs *args, void (*handler)(AppInferenceContext *context)) {
-    NnSize nNodes = args->nWorkers + 1;
+    NnUint nNodes = args->nWorkers + 1;
 
     LlmHeader header = loadLlmHeader(args->modelPath, args->maxSeqLen, args->syncType);
     if (nNodes > header.nKvHeads)
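One side effect of the new unsigned signatures, worth noting: assert(position >= 0) in setPosition, and the batchIndex >= 0 half of the check in setToken, are now vacuously true, since an NnUint can never be negative (compilers typically flag this with -Wtype-limits). A sketch of setPosition with only the check that still carries weight, otherwise identical to the code in this hunk:

// Sketch: with an NnUint parameter, only the upper-bound check is meaningful.
void RootLlmInference::setPosition(NnUint position) {
    // position >= 0 always holds for an unsigned type; the range check suffices.
    assert(position + execution->batchSize - 1 < header->seqLen);

    controlPacket.position = position;
    for (NnUint i = 0; i < execution->batchSize; i++)
        positionPipe[i] = (float)(position + i);
}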

src/app.hpp (+12 −12)
@@ -10,36 +10,36 @@
 class AppCliArgs {
 public:
     char *mode;
-    NnSize nThreads;
-    NnSize nBatches;
+    NnUint nThreads;
+    NnUint nBatches;
     bool help;
 
     // inference
     char *modelPath;
     char *tokenizerPath;
     char *prompt;
     NnFloatType syncType;
-    NnSize nWorkers;
+    NnUint nWorkers;
     char **workerHosts;
-    NnSize *workerPorts;
+    NnUint *workerPorts;
     float temperature;
     float topp;
-    NnSize steps;
+    NnUint steps;
     bool benchmark;
     unsigned long long seed;
     ChatTemplateType chatTemplateType;
-    NnSize maxSeqLen;
+    NnUint maxSeqLen;
 
     // worker
-    NnSize port;
+    NnUint port;
 
     static AppCliArgs parse(int argc, char **argv, bool hasMode);
     ~AppCliArgs();
 };
 
 typedef struct {
-    NnSize position;
-    NnSize batchSize; // 0 = stop signal
+    NnUint position;
+    NnUint batchSize; // 0 = stop signal
 } LlmControlPacket;
 
 class RootLlmInference {
@@ -56,9 +56,9 @@ class RootLlmInference {
     LlmControlPacket controlPacket;
 public:
     RootLlmInference(LlmNet *net, NnDevice *device, NnNetExecution *execution, NnExecutor *executor, NnNetwork *network);
-    void setBatchSize(NnSize batchSize);
-    void setPosition(NnSize position);
-    void setToken(NnSize batchIndex, NnSize token);
+    void setBatchSize(NnUint batchSize);
+    void setPosition(NnUint position);
+    void setToken(NnUint batchIndex, NnUint token);
     void forward();
     void finish();
 };
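With these declarations, a caller drives a batch through RootLlmInference entirely in NnUint. A minimal usage sketch, assuming a fully constructed inference object and a token buffer prepared as in src/dllama.cpp below; the helper name feedBatch is hypothetical, not part of the repository:

// Hypothetical helper illustrating the NnUint-based API; it mirrors the
// batch loop in src/dllama.cpp further down in this commit.
static void feedBatch(RootLlmInference *inference, const int *tokens,
                      NnUint pos, NnUint batchSize) {
    inference->setBatchSize(batchSize);
    inference->setPosition(pos);
    for (NnUint i = 0; i < batchSize; i++)
        inference->setToken(i, (NnUint)tokens[pos + i]);
    inference->forward();
}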

src/dllama-api.cpp (+4 −4)
@@ -380,20 +380,20 @@ class ApiServer {
            buffer += inputPrompt.publicPrompt;
        }
 
-        NnSize pos = startPos;
+        NnUint pos = startPos;
        int token;
-        for (NnSize i = 0; ;) {
+        for (NnUint i = 0; ;) {
            long remainingTokens = promptEndPos - pos;
            if (remainingTokens <= 0)
                break;
 
-            NnSize batchSize = remainingTokens < args->nBatches
+            NnUint batchSize = remainingTokens < args->nBatches
                ? remainingTokens
                : args->nBatches;
 
            inference->setBatchSize(batchSize);
            inference->setPosition(pos);
-            for (NnSize j = 0; j < batchSize; j++)
+            for (NnUint j = 0; j < batchSize; j++)
                inference->setToken(j, promptTokens[i + j]);
 
            inference->forward();
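The clamp on batchSize compares a signed long against the unsigned args->nBatches, which is safe here only because the remainingTokens <= 0 guard runs first. An equivalent, slightly more explicit formulation, sketched for clarity rather than taken from the commit:

// Sketch: explicit-cast equivalent of the ternary clamp above.
// Valid only after the remainingTokens <= 0 guard has excluded negatives.
// std::min comes from <algorithm>; args->nBatches is NnUint per app.hpp.
NnUint batchSize = std::min((NnUint)remainingTokens, args->nBatches);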

src/dllama.cpp (+20 −20)
@@ -16,7 +16,7 @@ static void inference(AppInferenceContext *context) {
     std::vector<int> inputTokensVec(std::strlen(context->args->prompt) + 3);
     int *inputTokens = inputTokensVec.data();
 
-    NnSize pos = 0;
+    NnUint pos = 0;
     int token;
     int nInputTokens;
     context->tokenizer->encode(context->args->prompt, inputTokens, &nInputTokens, true, false);
@@ -27,21 +27,21 @@ static void inference(AppInferenceContext *context) {
        throw std::runtime_error("The number of prompt tokens is greater than the number of steps");
 
     Timer evalTimer;
-    size_t sentBytes = 0;
-    size_t recvBytes = 0;
+    NnSize sentBytes = 0;
+    NnSize recvBytes = 0;
     printf("%s\n", context->args->prompt);
     for (;;) {
        Timer batchTimer;
        long remainingTokens = nInputTokens - 1 - (long)pos;
        if (remainingTokens <= 0)
            break;
-        NnSize batchSize = remainingTokens < context->args->nBatches
+        NnUint batchSize = remainingTokens < context->args->nBatches
            ? remainingTokens
            : context->args->nBatches;
 
        context->inference->setBatchSize(batchSize);
        context->inference->setPosition(pos);
-        for (NnSize i = 0; i < batchSize; i++)
+        for (NnUint i = 0; i < batchSize; i++)
            context->inference->setToken(i, inputTokens[pos + i]);
 
        context->inference->forward();
@@ -57,15 +57,15 @@ static void inference(AppInferenceContext *context) {
            recvBytes / 1024,
            batchSize);
     }
-    NnSize evalTime = evalTimer.elapsedMiliseconds();
+    NnUint evalTime = evalTimer.elapsedMiliseconds();
 
     fflush(stdout);
 
     context->inference->setBatchSize(1);
     context->tokenizer->resetDecoder();
 
     Timer predTimer;
-    const NnSize maxPos = std::min(context->header->seqLen, context->args->steps);
+    const NnUint maxPos = std::min(context->header->seqLen, context->args->steps);
     for (; pos < maxPos; pos++) {
        Timer tokenTimer;
        context->inference->setPosition(pos);
@@ -86,10 +86,10 @@ static void inference(AppInferenceContext *context) {
            piece == nullptr ? "~" : piece);
        fflush(stdout);
     }
-    NnSize predTime = predTimer.elapsedMiliseconds();
+    NnUint predTime = predTimer.elapsedMiliseconds();
 
-    NnSize nEvalTokens = nInputTokens - 1;
-    NnSize nPredTokens = pos - nEvalTokens;
+    NnUint nEvalTokens = nInputTokens - 1;
+    NnUint nPredTokens = pos - nEvalTokens;
     printf("\n");
     printf("Evaluation\n");
     printf(" nBatches: %d\n", context->args->nBatches);
@@ -104,11 +104,11 @@ static void inference(AppInferenceContext *context) {
        predTime / ((float) nPredTokens));
 }
 
-static size_t readStdin(const char *guide, char *buffer, size_t size) {
+static NnUint readStdin(const char *guide, char *buffer, NnUint size) {
     std::fflush(stdin);
     std::printf("%s", guide);
     if (std::fgets(buffer, size, stdin) != NULL) {
-        size_t length = std::strlen(buffer);
+        NnUint length = std::strlen(buffer);
        if (length > 0 && buffer[length - 1] == '\n') {
            buffer[length - 1] = '\0';
            length--;
@@ -119,20 +119,20 @@ static size_t readStdin(const char *guide, char *buffer, size_t size) {
 }
 
 static void chat(AppInferenceContext *context) {
-    const NnSize seqLen = context->header->seqLen;
+    const NnUint seqLen = context->header->seqLen;
     char prompt[2048];
 
     TokenizerChatStops stops(context->tokenizer);
     ChatTemplateGenerator templateGenerator(context->args->chatTemplateType, context->tokenizer->chatTemplate, stops.stops[0]);
     EosDetector eosDetector(stops.nStops, context->tokenizer->eosTokenIds.data(), stops.stops, stops.maxStopLength, stops.maxStopLength);
 
-    const size_t sysPromptLength = readStdin("💻 System prompt (optional): ", prompt, sizeof(prompt));
+    const NnUint sysPromptLength = readStdin("💻 System prompt (optional): ", prompt, sizeof(prompt));
     std::vector<ChatItem> deltaItems;
     if (sysPromptLength > 0)
        deltaItems.push_back(ChatItem{"system", prompt});
 
-    NnSize pos = 0;
-    size_t userPromptLength;
+    NnUint pos = 0;
+    NnUint userPromptLength;
     int token;
     int nInputTokens;
     do {
@@ -149,18 +149,18 @@ static void chat(AppInferenceContext *context) {
        bool addBos = pos == 0;
        context->tokenizer->encode((char*)inputPrompt.content, inputTokens, &nInputTokens, addBos, true);
 
-        NnSize userPromptEndPos = (NnSize)std::min<unsigned int>(seqLen, pos + nInputTokens - 1);
-        for (NnSize i = 0; ;) {
+        NnUint userPromptEndPos = (NnUint)std::min<unsigned int>(seqLen, pos + nInputTokens - 1);
+        for (NnUint i = 0; ;) {
            int remainingTokens = userPromptEndPos - pos;
            if (remainingTokens <= 0)
                break;
-            NnSize batchSize = remainingTokens < context->args->nBatches
+            NnUint batchSize = remainingTokens < context->args->nBatches
                ? remainingTokens
                : context->args->nBatches;
 
            context->inference->setBatchSize(batchSize);
            context->inference->setPosition(pos);
-            for (NnSize j = 0; j < batchSize; j++)
+            for (NnUint j = 0; j < batchSize; j++)
                context->inference->setToken(j, inputTokens[i + j]);
 
            context->inference->forward();
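A small caveat in the readStdin change: std::strlen returns size_t, so the assignment to NnUint narrows on 64-bit platforms. For the 2048-byte prompt buffer used by chat this cannot truncate; an explicit cast, sketched below, would merely make the narrowing visible and silence conversion-style warnings:

// Sketch: explicit narrowing makes the intent visible to -Wconversion.
// Safe for any buffer smaller than UINT_MAX, such as the 2048-byte prompt.
NnUint length = (NnUint)std::strlen(buffer);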
