Skip to content

Commit 5479d35

Browse files
committed
forward test.
1 parent 61b50f3 commit 5479d35

File tree

5 files changed

+122
-47
lines changed

5 files changed

+122
-47
lines changed

src/nn/nn-vulkan-test.cpp

+6-1
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,17 @@ int main() {
5050

5151
float rmsNormWeight[DIM];
5252
for (NnUint i = 0; i < DIM; i++)
53-
rmsNormWeight[i] = 0.5 + i / (float)DIM;
53+
rmsNormWeight[i] = i / (float)DIM;
5454

5555
NnExecutor executor(&netConfig, &nodeConfig, &device, &execution, &synchronizer);
5656
executor.loadWeight("rms_norm", 0, sizeof(rmsNormWeight), (NnByte *)rmsNormWeight);
5757

5858
execution.setBatchSize(N_BATCHES);
5959
executor.forward();
60+
61+
printf("output: ");
62+
for (NnUint i = 0; i < DIM; i++)
63+
printf("%.2f ", x[i]);
64+
printf("\n");
6065
return 0;
6166
}

src/nn/nn-vulkan.cpp

+95-38
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ static std::pair<vk::Buffer, vk::DeviceMemory> createBuffer(const NnVulkanContex
5959
}
6060

6161
NnVulkanStagingCopy::NnVulkanStagingCopy(const NnVulkanContext *context, vk::Buffer& deviceBuffer, const vk::DeviceSize bufferSize, const NnStagingVulkanCopyDirection direction) {
62+
this->direction = direction;
6263
this->deviceBuffer = deviceBuffer;
6364
this->context = context;
6465
this->bufferSize = bufferSize;
@@ -89,7 +90,7 @@ void NnVulkanStagingCopy::copy(NnByte *data) {
8990
}
9091
}
9192

92-
void NnVulkanStagingCopy::addCopyCommand(vk::CommandBuffer commandBuffer) {
93+
void NnVulkanStagingCopy::addCopyCommand(vk::CommandBuffer& commandBuffer) {
9394
VkBufferCopy copyRegion = { 0 };
9495
copyRegion.size = bufferSize;
9596
switch (direction) {
@@ -105,6 +106,7 @@ void NnVulkanStagingCopy::addCopyCommand(vk::CommandBuffer commandBuffer) {
105106
NnVulkanBuffer::NnVulkanBuffer(NnVulkanContext *context, const vk::DeviceSize bufferSize, vk::BufferUsageFlags usageFlags, bool fastAccess) {
106107
this->context = context;
107108
this->bufferSize = bufferSize;
109+
this->usageFlags = usageFlags;
108110
this->hostPointer = nullptr;
109111

110112
uint32_t memoryTypeIndex = MEMORY_TYPE_INDEX_NOT_FOUND;
@@ -160,16 +162,23 @@ void NnVulkanBuffer::write(const NnByte *data) {
160162
VULKAN_TRACE("Wrote %lld bytes to buffer", bufferSize);
161163
}
162164

165+
void NnVulkanBuffer::read(NnByte *data) {
166+
// TODO: this function should be deleted
167+
assert(isHostVisible && hostPointer != nullptr);
168+
std::memcpy(data, hostPointer, bufferSize);
169+
VULKAN_TRACE("Read %lld bytes from buffer", bufferSize);
170+
}
171+
163172
NnVulkanData::NnVulkanData(NnVulkanContext *context, NnNetConfig *netConfig, NnNodeConfig *nodeConfig)
164173
: pipes(netConfig->nPipes), buffers(nodeConfig->nBuffers), internalBuffers()
165174
{
166175
this->netConfig = netConfig;
167176
this->nodeConfig = nodeConfig;
168177

169178
for (NnUint i = 0; i < netConfig->nPipes; i++)
170-
pipes[i].reset(new NnVulkanBuffer(context, netConfig->pipes[i].size.nBytes, vk::BufferUsageFlagBits::eUniformBuffer, true));
179+
pipes[i].reset(new NnVulkanBuffer(context, netConfig->pipes[i].size.nBytes, vk::BufferUsageFlagBits::eStorageBuffer, true));
171180
for (NnUint i = 0; i < nodeConfig->nBuffers; i++)
172-
buffers[i].reset(new NnVulkanBuffer(context, nodeConfig->buffers[i].size.nBytes, vk::BufferUsageFlagBits::eUniformBuffer, false));
181+
buffers[i].reset(new NnVulkanBuffer(context, nodeConfig->buffers[i].size.nBytes, vk::BufferUsageFlagBits::eStorageBuffer, false));
173182
}
174183

175184
NnVulkanData::~NnVulkanData() {
@@ -287,7 +296,7 @@ NnUint NnVulkanDevice::maxNThreads() {
287296

288297
NnDeviceSegment *NnVulkanDevice::createSegment(NnUint segmentIndex) {
289298
NnSegmentConfig *segmentConfig = &nodeConfig->segments[segmentIndex];
290-
return new NnVulkanDeviceSegment(&context, data, segmentConfig);
299+
return new NnVulkanDeviceSegment(&context, data, segmentConfig, netExecution);
291300
};
292301

293302
void NnVulkanDevice::syncPointers() {
@@ -321,18 +330,50 @@ NnVulkanShader::NnVulkanShader(const char *fileName) {
321330
fclose(file);
322331
}
323332

324-
NnVulkanDeviceSegment::NnVulkanDeviceSegment(NnVulkanContext *context, NnVulkanData *data, NnSegmentConfig *segmentConfig)
325-
:
326-
weightBufferIndex(segmentConfig->nOps),
327-
configBufferIndex(segmentConfig->nOps),
333+
static vk::DescriptorType toDescriptorType(NnVulkanBuffer *buffer) {
334+
if (buffer->usageFlags & vk::BufferUsageFlagBits::eUniformBuffer)
335+
return vk::DescriptorType::eUniformBuffer;
336+
if (buffer->usageFlags & vk::BufferUsageFlagBits::eStorageBuffer)
337+
return vk::DescriptorType::eStorageBuffer;
338+
throw std::invalid_argument("Unsupported buffer usage");
339+
}
340+
341+
static void modifySpvSet(std::vector<uint32_t>& binary, uint32_t new_set) {
342+
if (binary.size() < 5)
343+
throw std::runtime_error("Invalid SPIR-V binary: too short");
344+
345+
uint32_t magic = binary[0];
346+
if (magic != 0x07230203)
347+
throw std::runtime_error("Unsupported endianness or not a SPIR-V binary");
348+
349+
size_t index = 5;
350+
while (index < binary.size()) {
351+
uint32_t firstWord = binary[index];
352+
uint16_t opcode = firstWord & 0xFFFF;
353+
uint16_t wordCount = firstWord >> 16;
354+
if (wordCount == 0) break;
355+
if (opcode == 71 && wordCount >= 4) {
356+
uint32_t decoration = binary[index + 2];
357+
if (decoration == 34)
358+
binary[index + 3] = new_set;
359+
}
360+
index += wordCount;
361+
}
362+
}
363+
364+
NnVulkanDeviceSegment::NnVulkanDeviceSegment(NnVulkanContext *context, NnVulkanData *data, NnSegmentConfig *segmentConfig, NnNetExecution *netExecution) :
365+
weightBufferIndex(segmentConfig->nOps, UINT32_MAX),
366+
configBufferIndex(segmentConfig->nOps, UINT32_MAX),
328367
shaderModules(segmentConfig->nOps),
329368
descriptorSets(segmentConfig->nOps),
330369
descriptorPools(segmentConfig->nOps),
331-
descriptorSetLayouts(segmentConfig->nOps)
370+
descriptorSetLayouts(segmentConfig->nOps),
371+
groupCountX(segmentConfig->nOps)
332372
{
333373
this->context = context;
334374
this->data = data;
335375
this->segmentConfig = segmentConfig;
376+
this->netExecution = netExecution;
336377

337378
for (NnUint opIndex = 0; opIndex < segmentConfig->nOps; opIndex++) {
338379
NnOpConfig *opConfig = &segmentConfig->ops[opIndex];
@@ -362,70 +403,66 @@ NnVulkanDeviceSegment::NnVulkanDeviceSegment(NnVulkanContext *context, NnVulkanD
362403
const char *shaderFileName = getShaderFileName(opConfig->code, opQuant);
363404
assert(shaderFileName != nullptr);
364405
NnVulkanShader shader(shaderFileName);
406+
modifySpvSet(shader.code, opIndex);
365407

366408
vk::ShaderModuleCreateInfo shaderModuleCreateInfo(vk::ShaderModuleCreateFlags(), shader.code.size(), shader.code.data());
367409
vk::ShaderModule shaderModule = context->device.createShaderModule(shaderModuleCreateInfo);
368410

369411
vk::PipelineShaderStageCreateInfo shaderCreateInfo(vk::PipelineShaderStageCreateFlags(), vk::ShaderStageFlagBits::eCompute, shaderModule, "main");
370412

371-
std::vector<vk::DescriptorType> descriptorTypes;
372413
std::vector<NnVulkanBuffer *> buffers;
373-
374414
{
375415
// input
376-
descriptorTypes.push_back(vk::DescriptorType::eUniformBuffer);
377416
buffers.push_back(data->resolveBuffer(&opConfig->input));
378-
379417
// output
380-
descriptorTypes.push_back(vk::DescriptorType::eUniformBuffer);
381418
buffers.push_back(data->resolveBuffer(&opConfig->output));
382-
383419
// weight
384-
/*if (opConfig->weightSize.nBytes > 0) {
385-
descriptorTypes.push_back(vk::DescriptorType::eStorageBuffer);
420+
if (opConfig->weightSize.nBytes > 0) {
421+
assert(weightBufferIndex[opIndex] != UINT32_MAX);
386422
buffers.push_back(data->internalBuffers[weightBufferIndex[opIndex]].get());
387-
}*/
388-
423+
}
389424
// config
390425
if (opConfig->configSize > 0) {
391-
descriptorTypes.push_back(vk::DescriptorType::eUniformBuffer);
426+
assert(configBufferIndex[opIndex] != UINT32_MAX);
392427
buffers.push_back(data->internalBuffers[configBufferIndex[opIndex]].get());
393428
}
394429
}
395430

396-
std::vector<vk::DescriptorSetLayoutBinding> descriptorSetLayoutBindings;
397-
for (NnUint i = 0; i < descriptorTypes.size(); i++) {
398-
vk::DescriptorSetLayoutBinding binding(i, descriptorTypes[i], 1, vk::ShaderStageFlagBits::eCompute);
399-
descriptorSetLayoutBindings.push_back(binding);
400-
}
431+
std::vector<vk::DescriptorSetLayoutBinding> descriptorSetLayoutBindings(buffers.size());
432+
for (NnUint i = 0; i < buffers.size(); i++)
433+
descriptorSetLayoutBindings[i] = vk::DescriptorSetLayoutBinding(i, toDescriptorType(buffers[i]), 1, vk::ShaderStageFlagBits::eCompute);
401434

402435
vk::DescriptorSetLayoutCreateInfo descriptorSetLayoutCreateInfo(vk::DescriptorSetLayoutCreateFlags(), descriptorSetLayoutBindings.size(), descriptorSetLayoutBindings.data());
403436
vk::DescriptorSetLayout descriptorSetLayout = context->device.createDescriptorSetLayout(descriptorSetLayoutCreateInfo);
404437

405438
NnUint nUniformBuffers = 0;
406439
NnUint nStorageBuffers = 0;
407-
for (NnUint i = 0; i < descriptorTypes.size(); i++) {
408-
if (descriptorTypes[i] == vk::DescriptorType::eUniformBuffer)
440+
for (NnUint i = 0; i < buffers.size(); i++) {
441+
vk::DescriptorType descriptorType = toDescriptorType(buffers[i]);
442+
if (descriptorType == vk::DescriptorType::eUniformBuffer)
409443
nUniformBuffers++;
410-
if (descriptorTypes[i] == vk::DescriptorType::eStorageBuffer)
444+
if (descriptorType == vk::DescriptorType::eStorageBuffer)
411445
nStorageBuffers++;
412446
}
447+
413448
std::vector<vk::DescriptorPoolSize> descriptorPoolSizes;
414-
if (nUniformBuffers > 0)
415-
descriptorPoolSizes.push_back(vk::DescriptorPoolSize(vk::DescriptorType::eUniformBuffer, nUniformBuffers));
416449
if (nStorageBuffers > 0)
417450
descriptorPoolSizes.push_back(vk::DescriptorPoolSize(vk::DescriptorType::eStorageBuffer, nStorageBuffers));
451+
if (nUniformBuffers > 0)
452+
descriptorPoolSizes.push_back(vk::DescriptorPoolSize(vk::DescriptorType::eUniformBuffer, nUniformBuffers));
418453

419-
vk::DescriptorPoolCreateInfo descriptorPoolCreateInfo(vk::DescriptorPoolCreateFlags(), 1, descriptorPoolSizes.size(), descriptorPoolSizes.data());
454+
vk::DescriptorPoolCreateInfo descriptorPoolCreateInfo(vk::DescriptorPoolCreateFlagBits::eFreeDescriptorSet, 1, descriptorPoolSizes.size(), descriptorPoolSizes.data());
420455
vk::DescriptorPool descriptorPool = context->device.createDescriptorPool(descriptorPoolCreateInfo);
421456
vk::DescriptorSetAllocateInfo descriptorSetAllocInfo(descriptorPool, 1, &descriptorSetLayout);
422457
const std::vector<vk::DescriptorSet> allocatedDescriptorSets = context->device.allocateDescriptorSets(descriptorSetAllocInfo);
458+
assert(allocatedDescriptorSets.size() == 1);
459+
423460
vk::DescriptorSet descriptorSet = allocatedDescriptorSets[0];
424-
std::vector<vk::WriteDescriptorSet> writeDescriptorSets;
425-
for (NnUint i = 0; i < descriptorTypes.size(); i++) {
426-
vk::DescriptorBufferInfo bufferInfo(buffers[i]->deviceBuffer, 0, buffers[i]->bufferSize);
427-
vk::WriteDescriptorSet writeDescriptorSet(descriptorSet, i, 0, 1, descriptorTypes[i], nullptr, &bufferInfo, nullptr);
428-
writeDescriptorSets.push_back(writeDescriptorSet);
461+
std::vector<vk::DescriptorBufferInfo> bufferInfos(buffers.size());
462+
std::vector<vk::WriteDescriptorSet> writeDescriptorSets(buffers.size());
463+
for (NnUint i = 0; i < buffers.size(); i++) {
464+
bufferInfos[i] = vk::DescriptorBufferInfo(buffers[i]->deviceBuffer, 0, buffers[i]->bufferSize);
465+
writeDescriptorSets[i] = vk::WriteDescriptorSet(descriptorSet, i, 0, 1, toDescriptorType(buffers[i]), nullptr, &bufferInfos[i], nullptr);
429466
}
430467

431468
context->device.updateDescriptorSets(writeDescriptorSets, nullptr);
@@ -435,6 +472,8 @@ NnVulkanDeviceSegment::NnVulkanDeviceSegment(NnVulkanContext *context, NnVulkanD
435472
descriptorSets[opIndex] = descriptorSet;
436473
descriptorPools[opIndex] = descriptorPool;
437474
descriptorSetLayouts[opIndex] = descriptorSetLayout;
475+
groupCountX[opIndex] = inputSize.x / 32;
476+
VULKAN_TRACE("Shader %d groupCountX=%d", opIndex, groupCountX[opIndex]);
438477
}
439478

440479
vk::PipelineLayoutCreateInfo pipelineLayoutCreateInfo(vk::PipelineLayoutCreateFlags(), descriptorSetLayouts.size(), descriptorSetLayouts.data());
@@ -478,6 +517,15 @@ void NnVulkanDeviceSegment::forward(NnUint opIndex, NnUint nThreads, NnUint thre
478517
return;
479518
}
480519

520+
{
521+
// TODO
522+
for (NnUint opIndex = 0; opIndex < segmentConfig->nOps; opIndex++) {
523+
NnOpConfig *opConfig = &segmentConfig->ops[opIndex];
524+
if (opConfig->input.pointerType == PNTR_PIPE)
525+
data->pipes[opConfig->input.pointerIndex]->write(netExecution->pipes[opConfig->input.pointerIndex]);
526+
}
527+
}
528+
481529
vk::CommandBufferAllocateInfo commandBufferAllocInfo(context->commandPool, vk::CommandBufferLevel::ePrimary, 1);
482530
const std::vector<vk::CommandBuffer> cmdBuffers = context->device.allocateCommandBuffers(commandBufferAllocInfo);
483531
vk::CommandBuffer commandBuffer = cmdBuffers.front();
@@ -487,15 +535,24 @@ void NnVulkanDeviceSegment::forward(NnUint opIndex, NnUint nThreads, NnUint thre
487535
for (NnUint opIndex = 0; opIndex < segmentConfig->nOps; opIndex++) {
488536
commandBuffer.bindPipeline(vk::PipelineBindPoint::eCompute, pipelines[opIndex]);
489537
commandBuffer.bindDescriptorSets(vk::PipelineBindPoint::eCompute, pipelineLayout, opIndex, { descriptorSets[opIndex] }, {});
490-
commandBuffer.dispatch(1, 1, 1);
538+
commandBuffer.dispatch(groupCountX[opIndex], 1, 1);
491539
}
492540
commandBuffer.end();
493541

494542
context->device.resetFences({ fence });
495543
vk::SubmitInfo submitInfo(0, nullptr, nullptr, 1, &commandBuffer);
496544
context->queue.submit({ submitInfo }, fence);
497545
assert(context->device.waitForFences({ fence }, true, uint64_t(-1)) == vk::Result::eSuccess);
546+
context->device.freeCommandBuffers(context->commandPool, 1, &commandBuffer);
547+
498548
VULKAN_TRACE("Forwarded");
499549

500-
context->device.freeCommandBuffers(context->commandPool, 1, &commandBuffer);
550+
{
551+
// TODO
552+
for (NnUint opIndex = 0; opIndex < segmentConfig->nOps; opIndex++) {
553+
NnOpConfig *opConfig = &segmentConfig->ops[opIndex];
554+
if (opConfig->output.pointerType == PNTR_PIPE)
555+
data->pipes[opConfig->output.pointerIndex]->read(netExecution->pipes[opConfig->output.pointerIndex]);
556+
}
557+
}
501558
}

src/nn/nn-vulkan.hpp

+6-2
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ class NnVulkanStagingCopy {
3535
NnVulkanStagingCopy(const NnVulkanContext *context, vk::Buffer& deviceBuffer, const vk::DeviceSize bufferSize, const NnStagingVulkanCopyDirection direction);
3636
~NnVulkanStagingCopy();
3737
void copy(NnByte *data);
38-
void addCopyCommand(vk::CommandBuffer commandBuffer);
38+
void addCopyCommand(vk::CommandBuffer& commandBuffer);
3939
};
4040

4141
class NnVulkanBuffer {
@@ -47,9 +47,11 @@ class NnVulkanBuffer {
4747
public:
4848
vk::DeviceSize bufferSize;
4949
vk::Buffer deviceBuffer;
50+
vk::BufferUsageFlags usageFlags;
5051
NnVulkanBuffer(NnVulkanContext *context, const vk::DeviceSize bufferSize, vk::BufferUsageFlags usageFlags, bool fastAccess);
5152
~NnVulkanBuffer();
5253
void write(const NnByte *data);
54+
void read(NnByte *data);
5355
};
5456

5557
class NnVulkanShader {
@@ -92,6 +94,7 @@ class NnVulkanDeviceSegment : public NnDeviceSegment {
9294
NnVulkanContext *context;
9395
NnVulkanData *data;
9496
NnSegmentConfig *segmentConfig;
97+
NnNetExecution *netExecution;
9598
std::vector<NnUint> weightBufferIndex;
9699
std::vector<NnUint> configBufferIndex;
97100

@@ -103,8 +106,9 @@ class NnVulkanDeviceSegment : public NnDeviceSegment {
103106
std::vector<vk::Pipeline> pipelines;
104107
vk::PipelineCache pipelineCache;
105108
vk::PipelineLayout pipelineLayout;
109+
std::vector<NnUint> groupCountX;
106110
public:
107-
NnVulkanDeviceSegment(NnVulkanContext *context, NnVulkanData *data, NnSegmentConfig *segmentConfig);
111+
NnVulkanDeviceSegment(NnVulkanContext *context, NnVulkanData *data, NnSegmentConfig *segmentConfig, NnNetExecution *netExecution);
108112
~NnVulkanDeviceSegment() override;
109113
void loadWeight(NnUint opIndex, NnSize nBytes, NnByte *weight) override;
110114
void forward(NnUint opIndex, NnUint nThreads, NnUint threadIndex, NnUint batchSize) override;

src/nn/vulkan/inv-rms-f32-f32.comp

+4-3
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@
66

77
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
88

9-
layout(binding = 0) readonly uniform inputBuffer { float inpt[]; };
10-
layout(binding = 1) writeonly uniform outputBuffer { float outp[]; };
11-
layout(binding = 2) readonly uniform configBuffer { float config[]; };
9+
layout(binding = 0) readonly buffer inputBuffer { float x[]; };
10+
layout(binding = 1) writeonly buffer outputBuffer { float y[]; };
11+
layout(binding = 2) readonly uniform configBuffer { float epsilon; };
1212

1313
void main() {
14+
y[0] = epsilon;
1415
}

src/nn/vulkan/rms-norm-f32-f32-f32.comp

+11-3
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,17 @@
44
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
55
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
66

7-
layout(binding = 0) readonly uniform inputBuffer { float inpt[]; };
8-
layout(binding = 1) writeonly uniform outputBuffer { float outp[]; };
9-
layout(binding = 2) readonly uniform configBuffer { float config[]; };
7+
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
8+
9+
layout(binding = 0) readonly buffer inputBuffer { float inpt[]; };
10+
layout(binding = 1) writeonly buffer outputBuffer { float outp[]; };
11+
layout(binding = 2) readonly buffer weightBuffer { float weight[]; };
12+
layout(binding = 3) readonly uniform configBuffer { uint invRmsBufferIndex; };
1013

1114
void main() {
15+
const uint x = uint(gl_LocalInvocationID.x);
16+
17+
for (uint i = 0; i < 32; i++) {
18+
outp[x * 32 + i] = inpt[x * 32 + i] * weight[x * 32 + i];
19+
}
1220
}

0 commit comments

Comments
 (0)