Skip to content

Commit c8b3cf9

Browse files
committed
fix.
1 parent 40cdcbc commit c8b3cf9

File tree

3 files changed

+6
-5
lines changed

3 files changed

+6
-5
lines changed

src/app.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ AppCliArgs::~AppCliArgs() {
 static NnDevice *createDevice(AppCliArgs *args, NnNetConfig *netConfig, NnNodeConfig *nodeConfig, NnNetExecution *netExecution) {
     if (args->gpuIndex >= 0) {
 #if defined(DLLAMA_VULKAN)
+        args->nBatches = 1; // TODO: this should be fixed
         return new NnVulkanDevice(args->gpuIndex, netConfig, nodeConfig, netExecution);
 #else
         throw std::runtime_error("This build does not support GPU");

src/nn/nn-vulkan.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -456,12 +456,12 @@ static std::vector<NnVulkanBatchInfo> buildBatchInfo(NnOpConfig *opConfig, NnVul
     return offset;
 }

-static void resolveShaderGroups(const NnOpCode opCode, const NnUint batchSize, NnUint *groupCount) {
+static void resolveShaderGroups(const NnOpConfig *opConfig, const NnUint batchSize, NnUint *groupCount) {
     groupCount[0] = 1;
     groupCount[1] = batchSize;
     groupCount[2] = 1;

-    if (opCode == OP_MATMUL)
+    if (opConfig->code == OP_MATMUL)
         groupCount[2] = 32;
 }

@@ -732,11 +732,11 @@ void NnVulkanDeviceSegment::forward(NnUint opIndex, NnUint nThreads, NnUint thre

     if (lastBatchSize != batchSize) {
         lastBatchSize = batchSize;
-        commandBuffer.begin({ vk::CommandBufferUsageFlags{ vk::CommandBufferUsageFlagBits::eSimultaneousUse } });
+        commandBuffer.begin({ vk::CommandBufferUsageFlags{} });

         NnUint opGroupCount[3];
         for (NnUint opIndex = 0; opIndex < segmentConfig->nOps; opIndex++) {
-            resolveShaderGroups(segmentConfig->ops[opIndex].code, batchSize, opGroupCount);
+            resolveShaderGroups(&segmentConfig->ops[opIndex], batchSize, opGroupCount);

             if (opIndex > 0) {
                 vk::MemoryBarrier memoryBarrier(

src/nn/vulkan/multi-head-att-forward-f32-f32.comp

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
 #version 450

-#define N_THREADS 32
+#define N_THREADS 256

 layout(local_size_x = N_THREADS, local_size_y = 1, local_size_z = 1) in;

0 commit comments

Comments
 (0)