Skip to content

Commit 31ff8f4

Browse files
authored
feat: vulkan. (#176)
1 parent ec2cb7f commit 31ff8f4

34 files changed

+2877
-267
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,5 @@ run*.sh
1515
server
1616
/dllama
1717
/dllama-*
18-
*.exe
18+
*.exe
19+
*.spv

Makefile

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
CXX = g++
2-
CXXFLAGS = -std=c++11 -Werror -Wformat -Werror=format-security
2+
CXXFLAGS = -std=c++11 -Werror -Wformat -Werror=format-security
33

44
ifndef TERMUX_VERSION
55
CXXFLAGS += -march=native -mtune=native
@@ -15,11 +15,25 @@ ifdef WVLA
1515
CXXFLAGS += -Wvla-extension
1616
endif
1717

18+
ifdef DLLAMA_VULKAN
19+
CGLSLC = glslc
20+
21+
ifeq ($(OS),Windows_NT)
22+
LIBS += -L$(VK_SDK_PATH)\lib -lvulkan-1
23+
CXXFLAGS += -DDLLAMA_VULKAN -I$(VK_SDK_PATH)\include
24+
else
25+
LIBS += -lvulkan
26+
CXXFLAGS += -DDLLAMA_VULKAN
27+
endif
28+
29+
DEPS += nn-vulkan.o
30+
endif
31+
1832
ifeq ($(OS),Windows_NT)
19-
LIBS = -lws2_32
33+
LIBS += -lws2_32
2034
DELETE_CMD = del /f
2135
else
22-
LIBS = -lpthread
36+
LIBS += -lpthread
2337
DELETE_CMD = rm -fv
2438
endif
2539

@@ -47,6 +61,19 @@ nn-cpu-test: src/nn/nn-cpu-test.cpp nn-quants.o nn-core.o nn-executor.o llamafil
4761
$(CXX) $(CXXFLAGS) $^ -o $@ $(LIBS)
4862
nn-cpu-ops-test: src/nn/nn-cpu-ops-test.cpp nn-quants.o nn-core.o nn-executor.o llamafile-sgemm.o nn-cpu.o
4963
$(CXX) $(CXXFLAGS) $^ -o $@ $(LIBS)
64+
nn-vulkan.o: src/nn/nn-vulkan.cpp
65+
$(CXX) $(CXXFLAGS) -c $^ -o $@
66+
67+
ifdef DLLAMA_VULKAN
68+
VULKAN_SHADER_SRCS := $(wildcard src/nn/vulkan/*.comp)
69+
VULKAN_SHADER_BINS := $(VULKAN_SHADER_SRCS:.comp=.spv)
70+
DEPS += $(VULKAN_SHADER_BINS)
71+
72+
%.spv: %.comp
73+
$(CGLSLC) -c $< -o $@
74+
nn-vulkan-test: src/nn/nn-vulkan-test.cpp nn-quants.o nn-core.o nn-executor.o nn-vulkan.o ${DEPS}
75+
$(CXX) $(CXXFLAGS) $(filter-out %.spv, $^) -o $@ $(LIBS)
76+
endif
5077

5178
# llm
5279
tokenizer.o: src/tokenizer.cpp
@@ -57,7 +84,7 @@ app.o: src/app.cpp
5784
$(CXX) $(CXXFLAGS) -c $^ -o $@
5885
tokenizer-test: src/tokenizer-test.cpp nn-quants.o nn-core.o llamafile-sgemm.o nn-cpu-ops.o tokenizer.o
5986
$(CXX) $(CXXFLAGS) $^ -o $@ $(LIBS)
60-
dllama: src/dllama.cpp nn-quants.o nn-core.o nn-executor.o nn-network.o llamafile-sgemm.o nn-cpu-ops.o nn-cpu.o tokenizer.o llm.o app.o
61-
$(CXX) $(CXXFLAGS) $^ -o $@ $(LIBS)
62-
dllama-api: src/dllama-api.cpp nn-quants.o nn-core.o nn-executor.o nn-network.o llamafile-sgemm.o nn-cpu-ops.o nn-cpu.o tokenizer.o llm.o app.o
63-
$(CXX) $(CXXFLAGS) $^ -o $@ $(LIBS)
87+
dllama: src/dllama.cpp nn-quants.o nn-core.o nn-executor.o nn-network.o llamafile-sgemm.o nn-cpu-ops.o nn-cpu.o tokenizer.o llm.o app.o ${DEPS}
88+
$(CXX) $(CXXFLAGS) $(filter-out %.spv, $^) -o $@ $(LIBS)
89+
dllama-api: src/dllama-api.cpp nn-quants.o nn-core.o nn-executor.o nn-network.o llamafile-sgemm.o nn-cpu-ops.o nn-cpu.o tokenizer.o llm.o app.o ${DEPS}
90+
$(CXX) $(CXXFLAGS) $(filter-out %.spv, $^) -o $@ $(LIBS)

src/app.cpp

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
#include <cassert>
33
#include <cstring>
44
#include <stdexcept>
5+
#if defined(DLLAMA_VULKAN)
6+
#include "nn/nn-vulkan.hpp"
7+
#endif
58

69
static NnFloatType parseFloatType(char *val) {
710
if (std::strcmp(val, "f32") == 0) return F_32;
@@ -38,6 +41,7 @@ AppCliArgs AppCliArgs::parse(int argc, char* *argv, bool requireMode) {
3841
args.seed = (unsigned long long)time(nullptr);
3942
args.chatTemplateType = TEMPLATE_UNKNOWN;
4043
args.maxSeqLen = 0;
44+
args.gpuIndex = -1;
4145
int i = 1;
4246
if (requireMode && argc > 1) {
4347
args.mode = argv[1];
@@ -102,6 +106,8 @@ AppCliArgs AppCliArgs::parse(int argc, char* *argv, bool requireMode) {
102106
args.chatTemplateType = parseChatTemplateType(value);
103107
} else if (std::strcmp(name, "--max-seq-len") == 0) {
104108
args.maxSeqLen = (unsigned int)atoi(value);
109+
} else if (std::strcmp(name, "--gpu-index") == 0) {
110+
args.gpuIndex = atoi(value);
105111
} else {
106112
throw std::runtime_error("Unknown option: " + std::string(name));
107113
}
@@ -119,6 +125,17 @@ AppCliArgs::~AppCliArgs() {
119125
delete[] workerPorts;
120126
}
121127

128+
static NnDevice *createDevice(AppCliArgs *args, NnNetConfig *netConfig, NnNodeConfig *nodeConfig, NnNetExecution *netExecution) {
129+
if (args->gpuIndex >= 0) {
130+
#if defined(DLLAMA_VULKAN)
131+
return new NnVulkanDevice(args->gpuIndex, netConfig, nodeConfig, netExecution);
132+
#else
133+
throw std::runtime_error("This build does not support GPU");
134+
#endif
135+
}
136+
return new NnCpuDevice(netConfig, nodeConfig, netExecution);
137+
}
138+
122139
RootLlmInference::RootLlmInference(LlmNet *net, NnDevice *device, NnNetExecution *execution, NnExecutor *executor, NnNetwork *network) {
123140
this->header = net->header;
124141
this->tokenPipe = (float *)execution->pipes[net->tokenPipeIndex];
@@ -152,7 +169,6 @@ void RootLlmInference::setToken(NnUint batchIndex, NnUint token) {
152169
void RootLlmInference::forward() {
153170
if (network != nullptr)
154171
network->writeAll(&controlPacket, sizeof(LlmControlPacket));
155-
device->syncPointers();
156172
executor->forward();
157173
}
158174

@@ -226,13 +242,13 @@ void runInferenceApp(AppCliArgs *args, void (*handler)(AppInferenceContext *cont
226242
configWriter.writeToWorkers(&net.netConfig, net.nodeConfigs);
227243
}
228244

229-
NnCpuDevice cpu(&net.netConfig, rootNodeConfig, &execution);
230-
NnExecutor executor(&net.netConfig, rootNodeConfig, &cpu, &execution, synchronizer.get(), args->benchmark);
245+
std::unique_ptr<NnDevice> device(createDevice(args, &net.netConfig, rootNodeConfig, &execution));
246+
NnExecutor executor(&net.netConfig, rootNodeConfig, device.get(), &execution, synchronizer.get(), args->benchmark);
231247

232248
NnRootWeightLoader weightLoader(&executor, network, nNodes);
233249
loadLlmNetWeight(args->modelPath, &net, &weightLoader);
234250

235-
RootLlmInference inference(&net, &cpu, &execution, &executor, network);
251+
RootLlmInference inference(&net, device.get(), &execution, &executor, network);
236252

237253
if (network != nullptr) {
238254
network->resetStats();
@@ -268,9 +284,10 @@ void runWorkerApp(AppCliArgs *args) {
268284

269285
NnNetExecution execution(args->nThreads, &netConfig);
270286

287+
std::unique_ptr<NnDevice> device(createDevice(args, &netConfig, &nodeConfig, &execution));
288+
271289
NnNetworkNodeSynchronizer synchronizer(network, &execution, &netConfig, &nodeConfig);
272-
NnCpuDevice cpu(&netConfig, &nodeConfig, &execution);
273-
NnExecutor executor(&netConfig, &nodeConfig, &cpu, &execution, &synchronizer, false);
290+
NnExecutor executor(&netConfig, &nodeConfig, device.get(), &execution, &synchronizer, false);
274291

275292
NnWorkerWeightReader weightReader(&executor, network);
276293
weightReader.read();

src/app.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ class AppCliArgs {
2929
unsigned long long seed;
3030
ChatTemplateType chatTemplateType;
3131
NnUint maxSeqLen;
32+
int gpuIndex;
3233

3334
// worker
3435
NnUint port;

src/dllama.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ static void inference(AppInferenceContext *context) {
5454
if (context->network != nullptr)
5555
context->network->getStats(&sentBytes, &recvBytes);
5656

57-
NnUint evalTime = context->executor->getTotalTime(STEP_EXECUTE_OP) + context->executor->getTotalTime(STEP_SYNC_POINTERS);
57+
NnUint evalTime = context->executor->getTotalTime(STEP_EXECUTE_OP);
5858
NnUint syncTime = context->executor->getTotalTime(STEP_SYNC_NODES);
5959
printf("🔷️ Eval%5u ms Sync%5u ms | Sent%6zu kB Recv%6zu kB | (%d tokens)\n",
6060
evalTime / 1000,
@@ -83,7 +83,7 @@ static void inference(AppInferenceContext *context) {
8383
if (context->network != nullptr)
8484
context->network->getStats(&sentBytes, &recvBytes);
8585

86-
NnUint predTime = context->executor->getTotalTime(STEP_EXECUTE_OP) + context->executor->getTotalTime(STEP_SYNC_POINTERS);
86+
NnUint predTime = context->executor->getTotalTime(STEP_EXECUTE_OP);
8787
NnUint syncTime = context->executor->getTotalTime(STEP_SYNC_NODES);
8888
printf("🔶 Pred%5u ms Sync%5u ms | Sent%6zu kB Recv%6zu kB | %s\n",
8989
predTime / 1000,

0 commit comments

Comments (0)