Commit 1e73dcb

Authored Feb 15, 2025
feat: use softmax_F32 for sampler. (#163)
1 parent 4cda910 · commit 1e73dcb

13 files changed: +106 -120 lines
 

Makefile (+1 -1)
@@ -55,7 +55,7 @@ llm.o: src/llm.cpp
 	$(CXX) $(CXXFLAGS) -c $^ -o $@
 app.o: src/app.cpp
 	$(CXX) $(CXXFLAGS) -c $^ -o $@
-tokenizer-test: src/tokenizer-test.cpp tokenizer.o
+tokenizer-test: src/tokenizer-test.cpp nn-quants.o nn-core.o llamafile-sgemm.o nn-cpu-ops.o tokenizer.o
 	$(CXX) $(CXXFLAGS) $^ -o $@ $(LIBS)
 dllama: src/dllama.cpp nn-quants.o nn-core.o nn-executor.o nn-network.o llamafile-sgemm.o nn-cpu-ops.o nn-cpu.o tokenizer.o llm.o app.o
 	$(CXX) $(CXXFLAGS) $^ -o $@ $(LIBS)
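Note: tokenizer-test now links nn-quants.o, nn-core.o, llamafile-sgemm.o and nn-cpu-ops.o, presumably because code linked by the test now references softmax_F32, which this commit exports from nn-cpu-ops (see src/nn/nn-cpu-ops.cpp below), pulling in that object file and its dependencies.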

src/app.cpp (-9)
@@ -119,15 +119,6 @@ AppCliArgs::~AppCliArgs() {
     delete[] workerPorts;
 }
 
-Timer::Timer() {
-    startTime = std::chrono::high_resolution_clock::now();
-}
-
-NnSize Timer::elapsed() {
-    auto endTime = std::chrono::high_resolution_clock::now();
-    return (NnSize)std::chrono::duration_cast<std::chrono::milliseconds>(endTime - startTime).count();
-}
-
 RootLlmInference::RootLlmInference(LlmNet *net, NnDevice *device, NnNetExecution *execution, NnExecutor *executor, NnNetwork *network) {
     this->header = net->header;
     this->tokenPipe = (float *)execution->pipes[net->tokenPipeIndex];

src/app.hpp (-10)
@@ -37,16 +37,6 @@ class AppCliArgs {
     ~AppCliArgs();
 };
 
-
-class Timer {
-private:
-    std::chrono::time_point<std::chrono::high_resolution_clock> startTime;
-public:
-    Timer();
-    NnSize elapsed();
-};
-
-
 typedef struct {
     NnSize position;
     NnSize batchSize; // 0 = stop signal
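Note: the Timer removed from app.cpp/app.hpp is not deleted outright; it reappears in src/nn/nn-core.{hpp,cpp} below, with elapsed() renamed to elapsedMiliseconds() and a new elapsedMicroseconds() method added.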

src/dllama-api.cpp (+15 -15)
@@ -187,7 +187,7 @@ class HttpRequest {
     }
 
     void writeNotFound() {
-        const char* data = "HTTP/1.1 404 Not Found\r\n";
+        const char *data = "HTTP/1.1 404 Not Found\r\n";
         writeSocket(serverSocket, data, strlen(data));
     }
 
@@ -218,7 +218,7 @@ class HttpRequest {
     }
 
     void writeStreamEndChunk() {
-        const char* endChunk = "0000\r\n\r\n";
+        const char *endChunk = "0000\r\n\r\n";
         writeSocket(serverSocket, endChunk, strlen(endChunk));
     }
 };
@@ -310,24 +310,24 @@ class NaiveCache {
 
 class ApiServer {
 private:
-    RootLlmInference* inference;
-    Tokenizer* tokenizer;
-    Sampler* sampler;
-    AppCliArgs* args;
-    LlmHeader* header;
-    EosDetector* eosDetector;
-    ChatTemplate* chatTemplate;
+    RootLlmInference *inference;
+    Tokenizer *tokenizer;
+    Sampler *sampler;
+    AppCliArgs *args;
+    LlmHeader *header;
+    EosDetector *eosDetector;
+    ChatTemplateGenerator *templateGenerator;
     NaiveCache naiveCache;
 
 public:
-    ApiServer( RootLlmInference* inference, Tokenizer* tokenizer, Sampler* sampler, AppCliArgs* args, LlmHeader* header, EosDetector* eosDetector, ChatTemplate* chatTemplate) {
+    ApiServer(RootLlmInference *inference, Tokenizer *tokenizer, Sampler *sampler, AppCliArgs *args, LlmHeader *header, EosDetector *eosDetector, ChatTemplateGenerator *templateGenerator) {
         this->inference = inference;
        this->tokenizer = tokenizer;
         this->sampler = sampler;
         this->args = args;
         this->header = header;
         this->eosDetector = eosDetector;
-        this->chatTemplate = chatTemplate;
+        this->templateGenerator = templateGenerator;
     }
 
     void complete(HttpRequest& request) {
@@ -345,7 +345,7 @@ class ApiServer {
             inputItems[i].message = deltaPrompt[i].content;
         }
 
-        GeneratedChat inputPrompt = chatTemplate->generate(nInputItems, inputItems, true);
+        GeneratedChat inputPrompt = templateGenerator->generate(nInputItems, inputItems, true);
         printf("🔹%s🔸", inputPrompt.content);
 
         int nPromptTokens;
@@ -484,7 +484,7 @@ class ApiServer {
     }
 };
 
-void handleCompletionsRequest(HttpRequest& request, ApiServer* api) {
+void handleCompletionsRequest(HttpRequest& request, ApiServer *api) {
     api->complete(request);
 }
 
@@ -500,9 +500,9 @@ static void server(AppInferenceContext *context) {
     int serverSocket = createServerSocket(context->args->port);
 
     TokenizerChatStops stops(context->tokenizer);
-    ChatTemplate chatTemplate(context->args->chatTemplateType, context->tokenizer->chatTemplate, stops.stops[0]);
+    ChatTemplateGenerator templateGenerator(context->args->chatTemplateType, context->tokenizer->chatTemplate, stops.stops[0]);
     EosDetector eosDetector(stops.nStops, context->tokenizer->eosTokenIds.data(), stops.stops, stops.maxStopLength, stops.maxStopLength);
-    ApiServer api(context->inference, context->tokenizer, context->sampler, context->args, context->header, &eosDetector, &chatTemplate);
+    ApiServer api(context->inference, context->tokenizer, context->sampler, context->args, context->header, &eosDetector, &templateGenerator);
 
     printf("Server URL: http://127.0.0.1:%d/v1/\n", context->args->port);
 
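Note: the rename from ChatTemplate to ChatTemplateGenerator is mechanical; the generate(count, items, addGenerationPrompt) call shape is unchanged. A condensed sketch of the construct-and-generate pattern used above; the ChatItem field names are inferred from this diff, and the surrounding objects (tokenizer, args) come from the app context:

// tokenizer: Tokenizer*, args: AppCliArgs*, both from the app context
TokenizerChatStops stops(tokenizer);
ChatTemplateGenerator templateGenerator(args->chatTemplateType, tokenizer->chatTemplate, stops.stops[0]);

ChatItem items[1];
items[0].role = "user";      // field name assumed from ChatItem{"user", prompt} below
items[0].message = "Hello!"; // field name taken from inputItems[i].message above
GeneratedChat inputPrompt = templateGenerator.generate(1, items, true);
printf("%s", inputPrompt.content);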

src/dllama.cpp (+6 -6)
@@ -52,12 +52,12 @@ static void inference(AppInferenceContext *context) {
             if (context->network != nullptr)
                 context->network->getStats(&sentBytes, &recvBytes);
             printf("🔷️ E%5u ms S%6zu kB R%6zu kB (%d tokens)\n",
-                batchTimer.elapsed(),
+                batchTimer.elapsedMiliseconds(),
                 sentBytes / 1024,
                 recvBytes / 1024,
                 batchSize);
         }
-        NnSize evalTime = evalTimer.elapsed();
+        NnSize evalTime = evalTimer.elapsedMiliseconds();
 
         fflush(stdout);
 
@@ -80,13 +80,13 @@ static void inference(AppInferenceContext *context) {
             context->network->getStats(&sentBytes, &recvBytes);
 
         printf("🔶 P%5u ms S%6zu kB R%6zu kB %s\n",
-            tokenTimer.elapsed(),
+            tokenTimer.elapsedMiliseconds(),
             sentBytes / 1024,
             recvBytes / 1024,
             piece == nullptr ? "~" : piece);
         fflush(stdout);
     }
-    NnSize predTime = predTimer.elapsed();
+    NnSize predTime = predTimer.elapsedMiliseconds();
 
     NnSize nEvalTokens = nInputTokens - 1;
     NnSize nPredTokens = pos - nEvalTokens;
@@ -123,7 +123,7 @@ static void chat(AppInferenceContext *context) {
     char prompt[2048];
 
     TokenizerChatStops stops(context->tokenizer);
-    ChatTemplate chatTemplate(context->args->chatTemplateType, context->tokenizer->chatTemplate, stops.stops[0]);
+    ChatTemplateGenerator templateGenerator(context->args->chatTemplateType, context->tokenizer->chatTemplate, stops.stops[0]);
     EosDetector eosDetector(stops.nStops, context->tokenizer->eosTokenIds.data(), stops.stops, stops.maxStopLength, stops.maxStopLength);
 
     const size_t sysPromptLength = readStdin("💻 System prompt (optional): ", prompt, sizeof(prompt));
@@ -142,7 +142,7 @@ static void chat(AppInferenceContext *context) {
 
         deltaItems.push_back(ChatItem{"user", prompt});
 
-        GeneratedChat inputPrompt = chatTemplate.generate(deltaItems.size(), deltaItems.data(), true);
+        GeneratedChat inputPrompt = templateGenerator.generate(deltaItems.size(), deltaItems.data(), true);
         std::unique_ptr<int[]> inputTokensPtr(new int[inputPrompt.length + 2]);
         int *inputTokens = inputTokensPtr.get();
 

src/nn/nn-core.cpp (+14)
@@ -165,6 +165,20 @@ void printNodeRequiredMemory(NnNetConfig *netConfig, NnNodeConfig *nodeConfig) {
     printf("📀 RequiredMemory: %lu kB\n", total / 1024);
 }
 
+Timer::Timer() {
+    startTime = std::chrono::high_resolution_clock::now();
+}
+
+NnSize Timer::elapsedMiliseconds() {
+    auto endTime = std::chrono::high_resolution_clock::now();
+    return (NnSize)std::chrono::duration_cast<std::chrono::milliseconds>(endTime - startTime).count();
+}
+
+NnSize Timer::elapsedMicroseconds() {
+    auto endTime = std::chrono::high_resolution_clock::now();
+    return (NnSize)std::chrono::duration_cast<std::chrono::microseconds>(endTime - startTime).count();
+}
+
 // slicers
 
 NnKvCacheSlice sliceKvCache(NnSize kvDim, NnSize seqLen, NnSize nNodes) {

src/nn/nn-core.hpp (+10)
@@ -1,6 +1,7 @@
 #ifndef NN_CORE_H
 #define NN_CORE_H
 
+#include <chrono>
 #include <list>
 #include <memory>
 #include <cstdint>
@@ -262,6 +263,15 @@ void releaseNodeConfig(NnNodeConfig *nodeConfig);
 
 void printNodeRequiredMemory(NnNetConfig *netConfig, NnNodeConfig *nodeConfig);
 
+class Timer {
+private:
+    std::chrono::time_point<std::chrono::high_resolution_clock> startTime;
+public:
+    Timer();
+    NnSize elapsedMiliseconds();
+    NnSize elapsedMicroseconds();
+};
+
 // slicers
 
 NnKvCacheSlice sliceKvCache(NnSize kvDim, NnSize seqLen, NnSize nNodes);
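Note: a minimal usage sketch of the relocated Timer, mirroring the call sites in src/dllama.cpp above. The include path and the printf formatting are assumptions, not taken from this diff:

#include <cstdio>
#include "nn-core.hpp" // assumed include path; provides Timer and NnSize

int main() {
    Timer evalTimer; // the constructor records the start time
    // ... do some work here ...
    NnSize ms = evalTimer.elapsedMiliseconds();  // note: the repo spells it "Miliseconds"
    NnSize us = evalTimer.elapsedMicroseconds(); // microsecond resolution, new in this commit
    printf("elapsed: %u ms (%u us)\n", (unsigned)ms, (unsigned)us);
    return 0;
}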

src/nn/nn-cpu-ops.cpp (+1 -1)
@@ -586,7 +586,7 @@ static void add_Q80_F32(float *y, const NnBlockQ80 *x, const NnSize n, const NnS
 #endif
 }
 
-static void softmax_F32(float *x, const NnSize size) {
+void softmax_F32(float *x, const NnSize size) {
     if (size == 0)
         return;
 
src/nn/nn-cpu-ops.hpp (+2)
@@ -38,4 +38,6 @@ void printCpuInstructionSet();
 NnCpuOpForwardInit getCpuOpForwardInit(NnOpCode code, NnOpQuantType quantType);
 NnCpuOpForward getCpuOpForward(NnOpCode code, NnOpQuantType quantType);
 
+void softmax_F32(float *x, const NnSize size);
+
 #endif
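Note: the sampler-side change that motivates this new declaration (the commit title is "feat: use softmax_F32 for sampler") lives in files not shown in this excerpt. A hypothetical sketch of the pattern it enables; applyTemperatureSketch and its parameters are illustrative, not the repo's API:

#include "nn-cpu-ops.hpp" // declares the newly exported softmax_F32

// Hypothetical helper: temperature-scale the logits, then reuse the shared
// softmax instead of keeping a private normalization routine in the sampler.
static void applyTemperatureSketch(float *logits, const NnSize vocabSize, const float temperature) {
    for (NnSize i = 0; i < vocabSize; i++)
        logits[i] /= temperature;
    softmax_F32(logits, vocabSize); // in-place: logits become probabilities
}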

src/nn/nn-executor.cpp (+7 -9)
@@ -1,10 +1,9 @@
 #include <cassert>
-#include <chrono>
 #include <cstring>
 #include <stdexcept>
 #include "nn-executor.hpp"
 
-#define DEBUG_BENCHMARK false
+#define DEBUG_EXECUTOR_BENCHMARK false
 
 void NnFakeNodeSynchronizer::sync(NnSize segmentIndex, NnSize nThreads, NnSize threadIndex) {
     // Nothing
@@ -98,10 +97,10 @@ void NnExecutor::loadWeight(const char *name, NnSize index, NnSize nBytes, NnByt
 }
 
 inline void executeStep(NnExecutorStep *step, NnSize nThreads, NnExecutorThread *thread, NnExecutorContext *context) {
-#if DEBUG_BENCHMARK
+#if DEBUG_EXECUTOR_BENCHMARK
     assert(nThreads == 1);
-    auto startTime = std::chrono::high_resolution_clock::now();
-#endif
+    Timer startTime;
+#endif
 
     if (step->type == STEP_EXECUTE_OP) {
         step->segment->forward(step->arg0, nThreads, thread->threadIndex, context->batchSize);
@@ -114,14 +113,13 @@ inline void executeStep(NnExecutorStep *step, NnSize nThreads, NnExecutorThread
         throw std::invalid_argument("Unsupported step type");
     }
 
-#if DEBUG_BENCHMARK
-    auto endTime = std::chrono::high_resolution_clock::now();
-    NnSize duration = (NnSize)std::chrono::duration_cast<std::chrono::microseconds>(endTime - startTime).count();
+#if DEBUG_EXECUTOR_BENCHMARK
+    NnSize duration = startTime.elapsedMicroseconds();
     if (step->type == STEP_EXECUTE_OP)
         printf("🕒 [OP %16s %2d] %u μs\n", opCodeToString(step->opConfig->code), step->opConfig->index, duration);
     else if (step->type == STEP_SYNC_NODES)
         printf("🕒 [SYNC %17d] %u μs\n", step->arg0, duration);
-#endif
+#endif
 }
 
 static inline void *executorThreadHandler(void *arg) {
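Note: with the renamed flag set to true, executeStep prints a per-step duration in microseconds via the shared Timer; the assert in the diff restricts this mode to single-threaded runs. The flag is compile-time, so enabling it means editing the define and rebuilding:

// In src/nn/nn-executor.cpp, enable per-step benchmark output:
#define DEBUG_EXECUTOR_BENCHMARK true
// Rebuild, then run with one thread (the thread-count flag name is assumed):
//   ./dllama inference ... --nthreads 1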

src/tokenizer-test.cpp (+1 -1)
@@ -101,7 +101,7 @@ void dev_testDecoderEmojiWithEos(Tokenizer *tokenizer) {
 }
 
 void testChatTemplateDetection() {
-    ChatTemplate t0(TEMPLATE_UNKNOWN, "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}", "<eos>");
+    ChatTemplateGenerator t0(TEMPLATE_UNKNOWN, "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}", "<eos>");
     assert(t0.type == TEMPLATE_LLAMA3);
 
     printOk("chatTemplateDetection");
