@@ -59,6 +59,7 @@ static std::pair<vk::Buffer, vk::DeviceMemory> createBuffer(const NnVulkanContext
 }
 
 NnVulkanStagingCopy::NnVulkanStagingCopy(const NnVulkanContext *context, vk::Buffer& deviceBuffer, const vk::DeviceSize bufferSize, const NnStagingVulkanCopyDirection direction) {
+    this->direction = direction;
     this->deviceBuffer = deviceBuffer;
     this->context = context;
     this->bufferSize = bufferSize;
@@ -89,7 +90,7 @@ void NnVulkanStagingCopy::copy(NnByte *data) {
     }
 }
 
-void NnVulkanStagingCopy::addCopyCommand(vk::CommandBuffer commandBuffer) {
+void NnVulkanStagingCopy::addCopyCommand(vk::CommandBuffer& commandBuffer) {
     VkBufferCopy copyRegion = { 0 };
     copyRegion.size = bufferSize;
     switch (direction) {
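Storing `direction` in the constructor is a real fix, not cleanup: the `switch (direction)` above reads a member the old constructor never initialized. The hunk cuts off before the switch body; below is a minimal sketch of what such a direction switch typically records, with the enum values and the staging-buffer handle assumed for illustration (neither is shown in this diff):

    #include <vulkan/vulkan.hpp>

    // Hedged sketch, not this repo's code: COPY_TO_DEVICE / COPY_FROM_DEVICE
    // and the hostBuffer handle are assumptions.
    enum Direction { COPY_TO_DEVICE, COPY_FROM_DEVICE };

    static void recordStagingCopy(vk::CommandBuffer& commandBuffer, Direction direction,
                                  vk::Buffer hostBuffer, vk::Buffer deviceBuffer, vk::DeviceSize bufferSize) {
        vk::BufferCopy copyRegion(0, 0, bufferSize); // srcOffset, dstOffset, size
        if (direction == COPY_TO_DEVICE)
            commandBuffer.copyBuffer(hostBuffer, deviceBuffer, { copyRegion }); // upload
        else
            commandBuffer.copyBuffer(deviceBuffer, hostBuffer, { copyRegion }); // readback
    }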
@@ -105,6 +106,7 @@ void NnVulkanStagingCopy::addCopyCommand(vk::CommandBuffer commandBuffer) {
 NnVulkanBuffer::NnVulkanBuffer(NnVulkanContext *context, const vk::DeviceSize bufferSize, vk::BufferUsageFlags usageFlags, bool fastAccess) {
     this->context = context;
     this->bufferSize = bufferSize;
+    this->usageFlags = usageFlags;
     this->hostPointer = nullptr;
 
     uint32_t memoryTypeIndex = MEMORY_TYPE_INDEX_NOT_FOUND;
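Below this hunk the constructor resolves `memoryTypeIndex`; the standard Vulkan search it almost certainly performs looks like the sketch below. Only the `MEMORY_TYPE_INDEX_NOT_FOUND` sentinel comes from the file, and its value is assumed here; `fastAccess` presumably maps to host-visible properties, device-local otherwise.

    #include <cstdint>
    #include <vulkan/vulkan.hpp>

    constexpr uint32_t MEMORY_TYPE_INDEX_NOT_FOUND = UINT32_MAX; // assumed sentinel value

    // Pick the first memory type the buffer permits (memoryTypeBits comes from
    // vkGetBufferMemoryRequirements) that also has the requested properties.
    static uint32_t findMemoryTypeIndex(vk::PhysicalDevice physicalDevice,
                                        uint32_t memoryTypeBits, vk::MemoryPropertyFlags wanted) {
        vk::PhysicalDeviceMemoryProperties props = physicalDevice.getMemoryProperties();
        for (uint32_t i = 0; i < props.memoryTypeCount; i++) {
            if ((memoryTypeBits & (1u << i)) &&
                (props.memoryTypes[i].propertyFlags & wanted) == wanted)
                return i;
        }
        return MEMORY_TYPE_INDEX_NOT_FOUND;
    }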
@@ -160,16 +162,23 @@ void NnVulkanBuffer::write(const NnByte *data) {
     VULKAN_TRACE("Wrote %lld bytes to buffer", bufferSize);
 }
 
+void NnVulkanBuffer::read(NnByte *data) {
+    // TODO: this function should be deleted
+    assert(isHostVisible && hostPointer != nullptr);
+    std::memcpy(data, hostPointer, bufferSize);
+    VULKAN_TRACE("Read %lld bytes from buffer", bufferSize);
+}
+
 NnVulkanData::NnVulkanData(NnVulkanContext *context, NnNetConfig *netConfig, NnNodeConfig *nodeConfig)
     : pipes(netConfig->nPipes), buffers(nodeConfig->nBuffers), internalBuffers()
 {
     this->netConfig = netConfig;
     this->nodeConfig = nodeConfig;
 
     for (NnUint i = 0; i < netConfig->nPipes; i++)
-        pipes[i].reset(new NnVulkanBuffer(context, netConfig->pipes[i].size.nBytes, vk::BufferUsageFlagBits::eUniformBuffer, true));
+        pipes[i].reset(new NnVulkanBuffer(context, netConfig->pipes[i].size.nBytes, vk::BufferUsageFlagBits::eStorageBuffer, true));
     for (NnUint i = 0; i < nodeConfig->nBuffers; i++)
-        buffers[i].reset(new NnVulkanBuffer(context, nodeConfig->buffers[i].size.nBytes, vk::BufferUsageFlagBits::eUniformBuffer, false));
+        buffers[i].reset(new NnVulkanBuffer(context, nodeConfig->buffers[i].size.nBytes, vk::BufferUsageFlagBits::eStorageBuffer, false));
 }
 
 NnVulkanData::~NnVulkanData() {
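The switch from `eUniformBuffer` to `eStorageBuffer` for pipes and buffers is the substantive change here: uniform buffers are read-only in shaders and bounded by `maxUniformBufferRange`, which the spec only guarantees to be 16 KiB, while storage buffers are shader-writable and guaranteed at least 128 MiB via `maxStorageBufferRange`. Tensor-sized pipes simply do not fit the uniform path. The limits are easy to inspect (standard vulkan.hpp calls; `physicalDevice` assumed already created):

    #include <cstdio>
    #include <vulkan/vulkan.hpp>

    void printBufferLimits(vk::PhysicalDevice physicalDevice) {
        vk::PhysicalDeviceLimits limits = physicalDevice.getProperties().limits;
        std::printf("maxUniformBufferRange: %u\n", limits.maxUniformBufferRange); // spec minimum: 16384
        std::printf("maxStorageBufferRange: %u\n", limits.maxStorageBufferRange); // spec minimum: 134217728
    }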
@@ -287,7 +296,7 @@ NnUint NnVulkanDevice::maxNThreads() {
 
 NnDeviceSegment *NnVulkanDevice::createSegment(NnUint segmentIndex) {
     NnSegmentConfig *segmentConfig = &nodeConfig->segments[segmentIndex];
-    return new NnVulkanDeviceSegment(&context, data, segmentConfig);
+    return new NnVulkanDeviceSegment(&context, data, segmentConfig, netExecution);
 };
 
 void NnVulkanDevice::syncPointers() {
@@ -321,18 +330,50 @@ NnVulkanShader::NnVulkanShader(const char *fileName) {
     fclose(file);
 }
 
-NnVulkanDeviceSegment::NnVulkanDeviceSegment(NnVulkanContext *context, NnVulkanData *data, NnSegmentConfig *segmentConfig)
-    :
-    weightBufferIndex(segmentConfig->nOps),
-    configBufferIndex(segmentConfig->nOps),
+static vk::DescriptorType toDescriptorType(NnVulkanBuffer *buffer) {
+    if (buffer->usageFlags & vk::BufferUsageFlagBits::eUniformBuffer)
+        return vk::DescriptorType::eUniformBuffer;
+    if (buffer->usageFlags & vk::BufferUsageFlagBits::eStorageBuffer)
+        return vk::DescriptorType::eStorageBuffer;
+    throw std::invalid_argument("Unsupported buffer usage");
+}
+
+static void modifySpvSet(std::vector<uint32_t>& binary, uint32_t new_set) {
+    if (binary.size() < 5)
+        throw std::runtime_error("Invalid SPIR-V binary: too short");
+
+    uint32_t magic = binary[0];
+    if (magic != 0x07230203)
+        throw std::runtime_error("Unsupported endianness or not a SPIR-V binary");
+
+    size_t index = 5;
+    while (index < binary.size()) {
+        uint32_t firstWord = binary[index];
+        uint16_t opcode = firstWord & 0xFFFF;
+        uint16_t wordCount = firstWord >> 16;
+        if (wordCount == 0) break;
+        if (opcode == 71 && wordCount >= 4) {
+            uint32_t decoration = binary[index + 2];
+            if (decoration == 34)
+                binary[index + 3] = new_set;
+        }
+        index += wordCount;
+    }
+}
+
+NnVulkanDeviceSegment::NnVulkanDeviceSegment(NnVulkanContext *context, NnVulkanData *data, NnSegmentConfig *segmentConfig, NnNetExecution *netExecution) :
+    weightBufferIndex(segmentConfig->nOps, UINT32_MAX),
+    configBufferIndex(segmentConfig->nOps, UINT32_MAX),
     shaderModules(segmentConfig->nOps),
     descriptorSets(segmentConfig->nOps),
     descriptorPools(segmentConfig->nOps),
-    descriptorSetLayouts(segmentConfig->nOps)
+    descriptorSetLayouts(segmentConfig->nOps),
+    groupCountX(segmentConfig->nOps)
 {
     this->context = context;
     this->data = data;
     this->segmentConfig = segmentConfig;
+    this->netExecution = netExecution;
 
     for (NnUint opIndex = 0; opIndex < segmentConfig->nOps; opIndex++) {
         NnOpConfig *opConfig = &segmentConfig->ops[opIndex];
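`modifySpvSet` patches the SPIR-V binary in place. After the five-word module header (magic `0x07230203`, version, generator, bound, schema), each instruction's first word packs the word count in the high 16 bits and the opcode in the low 16 bits; opcode 71 is `OpDecorate` and decoration 34 is `DescriptorSet`. Rewriting that operand to `opIndex` lets every op's shader, compiled against `set = 0`, be bound at its own set index in `bindDescriptorSets` later. A self-contained check of the decoding against a hand-assembled module (paste `modifySpvSet` from the hunk above before this `main`):

    #include <cassert>
    #include <cstdint>
    #include <vector>

    int main() {
        std::vector<uint32_t> spv = {
            0x07230203, 0x00010000, 0, 3, 0,  // magic, version 1.0, generator, bound, schema
            (4u << 16) | 71u, 2u, 34u, 0u     // OpDecorate %2 DescriptorSet 0 (4 words)
        };
        modifySpvSet(spv, 5);                 // rebind everything to set = 5
        assert(spv[8] == 5);                  // the set operand was rewritten
        return 0;
    }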
@@ -362,70 +403,66 @@ NnVulkanDeviceSegment::NnVulkanDeviceSegment(NnVulkanContext *context, NnVulkanData *data, NnSegmentConfig *segmentConfig, NnNetExecution *netExecution)
         const char *shaderFileName = getShaderFileName(opConfig->code, opQuant);
         assert(shaderFileName != nullptr);
         NnVulkanShader shader(shaderFileName);
+        modifySpvSet(shader.code, opIndex);
 
         vk::ShaderModuleCreateInfo shaderModuleCreateInfo(vk::ShaderModuleCreateFlags(), shader.code.size(), shader.code.data());
         vk::ShaderModule shaderModule = context->device.createShaderModule(shaderModuleCreateInfo);
 
         vk::PipelineShaderStageCreateInfo shaderCreateInfo(vk::PipelineShaderStageCreateFlags(), vk::ShaderStageFlagBits::eCompute, shaderModule, "main");
 
-        std::vector<vk::DescriptorType> descriptorTypes;
         std::vector<NnVulkanBuffer *> buffers;
-
         {
             // input
-            descriptorTypes.push_back(vk::DescriptorType::eUniformBuffer);
             buffers.push_back(data->resolveBuffer(&opConfig->input));
-
             // output
-            descriptorTypes.push_back(vk::DescriptorType::eUniformBuffer);
             buffers.push_back(data->resolveBuffer(&opConfig->output));
-
             // weight
-            /*if (opConfig->weightSize.nBytes > 0) {
-                descriptorTypes.push_back(vk::DescriptorType::eStorageBuffer);
+            if (opConfig->weightSize.nBytes > 0) {
+                assert(weightBufferIndex[opIndex] != UINT32_MAX);
                 buffers.push_back(data->internalBuffers[weightBufferIndex[opIndex]].get());
-            }*/
-
+            }
             // config
             if (opConfig->configSize > 0) {
-                descriptorTypes.push_back(vk::DescriptorType::eUniformBuffer);
+                assert(configBufferIndex[opIndex] != UINT32_MAX);
                 buffers.push_back(data->internalBuffers[configBufferIndex[opIndex]].get());
             }
         }
 
-        std::vector<vk::DescriptorSetLayoutBinding> descriptorSetLayoutBindings;
-        for (NnUint i = 0; i < descriptorTypes.size(); i++) {
-            vk::DescriptorSetLayoutBinding binding(i, descriptorTypes[i], 1, vk::ShaderStageFlagBits::eCompute);
-            descriptorSetLayoutBindings.push_back(binding);
-        }
+        std::vector<vk::DescriptorSetLayoutBinding> descriptorSetLayoutBindings(buffers.size());
+        for (NnUint i = 0; i < buffers.size(); i++)
+            descriptorSetLayoutBindings[i] = vk::DescriptorSetLayoutBinding(i, toDescriptorType(buffers[i]), 1, vk::ShaderStageFlagBits::eCompute);
 
         vk::DescriptorSetLayoutCreateInfo descriptorSetLayoutCreateInfo(vk::DescriptorSetLayoutCreateFlags(), descriptorSetLayoutBindings.size(), descriptorSetLayoutBindings.data());
         vk::DescriptorSetLayout descriptorSetLayout = context->device.createDescriptorSetLayout(descriptorSetLayoutCreateInfo);
 
         NnUint nUniformBuffers = 0;
         NnUint nStorageBuffers = 0;
-        for (NnUint i = 0; i < descriptorTypes.size(); i++) {
-            if (descriptorTypes[i] == vk::DescriptorType::eUniformBuffer)
+        for (NnUint i = 0; i < buffers.size(); i++) {
+            vk::DescriptorType descriptorType = toDescriptorType(buffers[i]);
+            if (descriptorType == vk::DescriptorType::eUniformBuffer)
                 nUniformBuffers++;
-            if (descriptorTypes[i] == vk::DescriptorType::eStorageBuffer)
+            if (descriptorType == vk::DescriptorType::eStorageBuffer)
                 nStorageBuffers++;
         }
+
         std::vector<vk::DescriptorPoolSize> descriptorPoolSizes;
-        if (nUniformBuffers > 0)
-            descriptorPoolSizes.push_back(vk::DescriptorPoolSize(vk::DescriptorType::eUniformBuffer, nUniformBuffers));
         if (nStorageBuffers > 0)
            descriptorPoolSizes.push_back(vk::DescriptorPoolSize(vk::DescriptorType::eStorageBuffer, nStorageBuffers));
+        if (nUniformBuffers > 0)
+            descriptorPoolSizes.push_back(vk::DescriptorPoolSize(vk::DescriptorType::eUniformBuffer, nUniformBuffers));
 
-        vk::DescriptorPoolCreateInfo descriptorPoolCreateInfo(vk::DescriptorPoolCreateFlags(), 1, descriptorPoolSizes.size(), descriptorPoolSizes.data());
+        vk::DescriptorPoolCreateInfo descriptorPoolCreateInfo(vk::DescriptorPoolCreateFlagBits::eFreeDescriptorSet, 1, descriptorPoolSizes.size(), descriptorPoolSizes.data());
         vk::DescriptorPool descriptorPool = context->device.createDescriptorPool(descriptorPoolCreateInfo);
         vk::DescriptorSetAllocateInfo descriptorSetAllocInfo(descriptorPool, 1, &descriptorSetLayout);
         const std::vector<vk::DescriptorSet> allocatedDescriptorSets = context->device.allocateDescriptorSets(descriptorSetAllocInfo);
+        assert(allocatedDescriptorSets.size() == 1);
+
         vk::DescriptorSet descriptorSet = allocatedDescriptorSets[0];
-        std::vector<vk::WriteDescriptorSet> writeDescriptorSets;
-        for (NnUint i = 0; i < descriptorTypes.size(); i++) {
-            vk::DescriptorBufferInfo bufferInfo(buffers[i]->deviceBuffer, 0, buffers[i]->bufferSize);
-            vk::WriteDescriptorSet writeDescriptorSet(descriptorSet, i, 0, 1, descriptorTypes[i], nullptr, &bufferInfo, nullptr);
-            writeDescriptorSets.push_back(writeDescriptorSet);
+        std::vector<vk::DescriptorBufferInfo> bufferInfos(buffers.size());
+        std::vector<vk::WriteDescriptorSet> writeDescriptorSets(buffers.size());
+        for (NnUint i = 0; i < buffers.size(); i++) {
+            bufferInfos[i] = vk::DescriptorBufferInfo(buffers[i]->deviceBuffer, 0, buffers[i]->bufferSize);
+            writeDescriptorSets[i] = vk::WriteDescriptorSet(descriptorSet, i, 0, 1, toDescriptorType(buffers[i]), nullptr, &bufferInfos[i], nullptr);
         }
 
         context->device.updateDescriptorSets(writeDescriptorSets, nullptr);
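The descriptor-write rework fixes a lifetime bug rather than style: the removed loop stored `&bufferInfo`, the address of a loop-local, in each `vk::WriteDescriptorSet`, so by the time `updateDescriptorSets` dereferenced `pBufferInfo` the pointed-to objects were already destroyed. Pre-sizing `bufferInfos` keeps every `vk::DescriptorBufferInfo` alive, at a stable address, until the update executes. Annotated, the removed pattern was:

    // Removed anti-pattern, annotated; do not use:
    for (NnUint i = 0; i < descriptorTypes.size(); i++) {
        vk::DescriptorBufferInfo bufferInfo(...);           // loop-local
        writeDescriptorSets.push_back(
            vk::WriteDescriptorSet(..., &bufferInfo, ...)); // pointer escapes the scope
    }   // bufferInfo is destroyed here; every stored pBufferInfo now dangles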
@@ -435,6 +472,8 @@ NnVulkanDeviceSegment::NnVulkanDeviceSegment(NnVulkanContext *context, NnVulkanData *data, NnSegmentConfig *segmentConfig, NnNetExecution *netExecution)
         descriptorSets[opIndex] = descriptorSet;
         descriptorPools[opIndex] = descriptorPool;
         descriptorSetLayouts[opIndex] = descriptorSetLayout;
+        groupCountX[opIndex] = inputSize.x / 32;
+        VULKAN_TRACE("Shader %d groupCountX=%d", opIndex, groupCountX[opIndex]);
     }
 
     vk::PipelineLayoutCreateInfo pipelineLayoutCreateInfo(vk::PipelineLayoutCreateFlags(), descriptorSetLayouts.size(), descriptorSetLayouts.data());
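`groupCountX[opIndex] = inputSize.x / 32` ties the dispatch in `forward()` below to shaders that presumably declare `local_size_x = 32`; note the integer division silently drops a remainder whenever `inputSize.x` is not a multiple of 32, leaving trailing elements unprocessed. The conventional guard is a ceiling division paired with a bounds check in the shader (a sketch under that assumption, not this repo's code):

    #include <cstdint>

    constexpr uint32_t WORKGROUP_SIZE = 32; // must match local_size_x in the shader

    // Round up so a final partial workgroup is still dispatched; the shader
    // then tests gl_GlobalInvocationID.x against the real element count.
    static uint32_t ceilDiv(uint32_t n, uint32_t d) {
        return (n + d - 1) / d;
    }
    // usage: groupCountX[opIndex] = ceilDiv(inputSize.x, WORKGROUP_SIZE);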
@@ -478,6 +517,15 @@ void NnVulkanDeviceSegment::forward(NnUint opIndex, NnUint nThreads, NnUint threadIndex
         return;
     }
 
+    {
+        // TODO
+        for (NnUint opIndex = 0; opIndex < segmentConfig->nOps; opIndex++) {
+            NnOpConfig *opConfig = &segmentConfig->ops[opIndex];
+            if (opConfig->input.pointerType == PNTR_PIPE)
+                data->pipes[opConfig->input.pointerIndex]->write(netExecution->pipes[opConfig->input.pointerIndex]);
+        }
+    }
+
     vk::CommandBufferAllocateInfo commandBufferAllocInfo(context->commandPool, vk::CommandBufferLevel::ePrimary, 1);
     const std::vector<vk::CommandBuffer> cmdBuffers = context->device.allocateCommandBuffers(commandBufferAllocInfo);
     vk::CommandBuffer commandBuffer = cmdBuffers.front();
@@ -487,15 +535,24 @@ void NnVulkanDeviceSegment::forward(NnUint opIndex, NnUint nThreads, NnUint threadIndex
     for (NnUint opIndex = 0; opIndex < segmentConfig->nOps; opIndex++) {
         commandBuffer.bindPipeline(vk::PipelineBindPoint::eCompute, pipelines[opIndex]);
         commandBuffer.bindDescriptorSets(vk::PipelineBindPoint::eCompute, pipelineLayout, opIndex, { descriptorSets[opIndex] }, {});
-        commandBuffer.dispatch(1, 1, 1);
+        commandBuffer.dispatch(groupCountX[opIndex], 1, 1);
     }
     commandBuffer.end();
 
     context->device.resetFences({ fence });
     vk::SubmitInfo submitInfo(0, nullptr, nullptr, 1, &commandBuffer);
     context->queue.submit({ submitInfo }, fence);
     assert(context->device.waitForFences({ fence }, true, uint64_t(-1)) == vk::Result::eSuccess);
+    context->device.freeCommandBuffers(context->commandPool, 1, &commandBuffer);
+
     VULKAN_TRACE("Forwarded");
 
-    context->device.freeCommandBuffers(context->commandPool, 1, &commandBuffer);
+    {
+        // TODO
+        for (NnUint opIndex = 0; opIndex < segmentConfig->nOps; opIndex++) {
+            NnOpConfig *opConfig = &segmentConfig->ops[opIndex];
+            if (opConfig->output.pointerType == PNTR_PIPE)
+                data->pipes[opConfig->output.pointerIndex]->read(netExecution->pipes[opConfig->output.pointerIndex]);
+        }
+    }
 }
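The two TODO blocks bracket each forward pass with a full upload of pipe inputs and a readback of pipe outputs, which only works because pipes are created with `fastAccess = true` and are therefore host-visible; `read()` asserts exactly that. If pipes ever move to device-local memory, the staging machinery at the top of this file is the natural replacement. A hedged sketch of the readback side (the direction value and the exact semantics of `copy()` are assumed, not confirmed by this diff):

    // Hedged sketch: COPY_FROM_DEVICE is an assumed NnStagingVulkanCopyDirection
    // value, and copy() is assumed to memcpy out of the mapped staging buffer.
    NnVulkanStagingCopy staging(context, buffer->deviceBuffer, buffer->bufferSize, COPY_FROM_DEVICE);
    staging.addCopyCommand(commandBuffer); // record device -> staging after the dispatches
    // ...submit the command buffer and wait on the fence...
    staging.copy(dstPointer);              // then pull the bytes out on the host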