@@ -2,6 +2,9 @@
 #include <cassert>
 #include <cstring>
 #include <stdexcept>
+#if defined(DLLAMA_VULKAN)
+#include "nn/nn-vulkan.hpp"
+#endif
 
 static NnFloatType parseFloatType(char *val) {
     if (std::strcmp(val, "f32") == 0) return F_32;
@@ -38,6 +41,7 @@ AppCliArgs AppCliArgs::parse(int argc, char **argv, bool requireMode) {
     args.seed = (unsigned long long)time(nullptr);
     args.chatTemplateType = TEMPLATE_UNKNOWN;
     args.maxSeqLen = 0;
+    args.gpuIndex = -1;
     int i = 1;
     if (requireMode && argc > 1) {
         args.mode = argv[1];
@@ -102,6 +106,8 @@ AppCliArgs AppCliArgs::parse(int argc, char **argv, bool requireMode) {
             args.chatTemplateType = parseChatTemplateType(value);
         } else if (std::strcmp(name, "--max-seq-len") == 0) {
             args.maxSeqLen = (unsigned int)atoi(value);
+        } else if (std::strcmp(name, "--gpu-index") == 0) {
+            args.gpuIndex = atoi(value);
         } else {
             throw std::runtime_error("Unknown option: " + std::string(name));
         }
@@ -119,6 +125,17 @@ AppCliArgs::~AppCliArgs() {
     delete[] workerPorts;
 }
 
+static NnDevice *createDevice(AppCliArgs *args, NnNetConfig *netConfig, NnNodeConfig *nodeConfig, NnNetExecution *netExecution) {
+    if (args->gpuIndex >= 0) {
+#if defined(DLLAMA_VULKAN)
+        return new NnVulkanDevice(args->gpuIndex, netConfig, nodeConfig, netExecution);
+#else
+        throw std::runtime_error("This build does not support GPU");
+#endif
+    }
+    return new NnCpuDevice(netConfig, nodeConfig, netExecution);
+}
+
 RootLlmInference::RootLlmInference(LlmNet *net, NnDevice *device, NnNetExecution *execution, NnExecutor *executor, NnNetwork *network) {
     this->header = net->header;
     this->tokenPipe = (float *)execution->pipes[net->tokenPipeIndex];
@@ -152,7 +169,6 @@ void RootLlmInference::setToken(NnUint batchIndex, NnUint token) {
 void RootLlmInference::forward() {
     if (network != nullptr)
         network->writeAll(&controlPacket, sizeof(LlmControlPacket));
-    device->syncPointers();
     executor->forward();
 }
 
@@ -226,13 +242,13 @@ void runInferenceApp(AppCliArgs *args, void (*handler)(AppInferenceContext *cont
         configWriter.writeToWorkers(&net.netConfig, net.nodeConfigs);
     }
 
-    NnCpuDevice cpu(&net.netConfig, rootNodeConfig, &execution);
-    NnExecutor executor(&net.netConfig, rootNodeConfig, &cpu, &execution, synchronizer.get(), args->benchmark);
+    std::unique_ptr<NnDevice> device(createDevice(args, &net.netConfig, rootNodeConfig, &execution));
+    NnExecutor executor(&net.netConfig, rootNodeConfig, device.get(), &execution, synchronizer.get(), args->benchmark);
 
     NnRootWeightLoader weightLoader(&executor, network, nNodes);
     loadLlmNetWeight(args->modelPath, &net, &weightLoader);
 
-    RootLlmInference inference(&net, &cpu, &execution, &executor, network);
+    RootLlmInference inference(&net, device.get(), &execution, &executor, network);
 
     if (network != nullptr) {
         network->resetStats();
@@ -268,9 +284,10 @@ void runWorkerApp(AppCliArgs *args) {
 
     NnNetExecution execution(args->nThreads, &netConfig);
 
+    std::unique_ptr<NnDevice> device(createDevice(args, &netConfig, &nodeConfig, &execution));
+
     NnNetworkNodeSynchronizer synchronizer(network, &execution, &netConfig, &nodeConfig);
-    NnCpuDevice cpu(&netConfig, &nodeConfig, &execution);
-    NnExecutor executor(&netConfig, &nodeConfig, &cpu, &execution, &synchronizer, false);
+    NnExecutor executor(&netConfig, &nodeConfig, device.get(), &execution, &synchronizer, false);
 
     NnWorkerWeightReader weightReader(&executor, network);
     weightReader.read();
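
The core of this change is the `createDevice` factory: a compile-time feature gate (`DLLAMA_VULKAN`) combined with a runtime fallback on `gpuIndex`. Below is a minimal, self-contained sketch of that pattern; the `Device`, `CpuDevice`, and `VulkanDevice` types are hypothetical stand-ins for `NnDevice`, `NnCpuDevice`, and `NnVulkanDevice`, and only the gate/fallback shape mirrors the diff.

```cpp
#include <iostream>
#include <memory>
#include <stdexcept>

struct Device {                       // stand-in for NnDevice
    virtual ~Device() = default;
    virtual const char *name() const = 0;
};

struct CpuDevice : Device {           // stand-in for NnCpuDevice
    const char *name() const override { return "cpu"; }
};

#if defined(DLLAMA_VULKAN)
struct VulkanDevice : Device {        // stand-in for NnVulkanDevice
    explicit VulkanDevice(int /*gpuIndex*/) {}
    const char *name() const override { return "vulkan"; }
};
#endif

// Mirrors createDevice(): prefer the GPU when one was requested, otherwise
// fall back to the CPU; a GPU request in a CPU-only build is a hard error.
static Device *createDevice(int gpuIndex) {
    if (gpuIndex >= 0) {
#if defined(DLLAMA_VULKAN)
        return new VulkanDevice(gpuIndex);
#else
        throw std::runtime_error("This build does not support GPU");
#endif
    }
    return new CpuDevice();
}

int main() {
    // Owned via unique_ptr, matching how the diff holds the returned device.
    std::unique_ptr<Device> device(createDevice(-1));
    std::cout << device->name() << "\n";   // prints "cpu"
}
```

Building with the macro defined (e.g. `-DDLLAMA_VULKAN` on the compiler command line) enables the GPU branch; per the diff, `--gpu-index` sets `gpuIndex` at runtime, and its default of `-1` keeps the CPU path in both root and worker apps.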