Skip to content

Commit df5765d

Browse files
committed
optimization.
1 parent 0a9e927 commit df5765d

7 files changed

+90
-42
lines changed

src/nn/nn-vulkan-test.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ void testShift_F32_F32() {
309309
}
310310

311311
void testCast_F32_F32() {
312-
#define CAST_DIM 48
312+
#define CAST_DIM 64
313313
execute(
314314
[](NnNetConfigBuilder *netBuilder, NnNodeConfigBuilder *nodeBuilder, NnSegmentConfigBuilder *segmentBuilder) {
315315
NnUint xPipeIndex = netBuilder->addPipe("X", size2D(F_32, N_BATCHES, CAST_DIM));

src/nn/nn-vulkan.cpp

+5-1
Original file line numberDiff line numberDiff line change
@@ -458,7 +458,11 @@ static void resolveShaderGroups(const NnOpConfig *opConfig, const NnUint batchSi
458458
groupCount[1] = batchSize;
459459
groupCount[2] = 1;
460460

461-
if (opConfig->code == OP_MATMUL)
461+
if (opConfig->code == OP_CAST ||
462+
opConfig->code == OP_MUL ||
463+
opConfig->code == OP_SILU ||
464+
opConfig->code == OP_MERGE_ADD ||
465+
opConfig->code == OP_MATMUL)
462466
groupCount[2] = 32;
463467
}
464468

src/nn/vulkan/cast-forward-f32-f32.comp

+16-7
Original file line numberDiff line numberDiff line change
@@ -15,24 +15,33 @@ layout(binding = 0) readonly buffer inputBuffer { float x[]; };
1515
layout(binding = 1) writeonly buffer outputBuffer { float y[]; };
1616
layout(binding = 2) readonly buffer batchInfosBuffer { BatchInfo infos[]; };
1717

18-
shared BatchInfo sharedInfo;
18+
shared uint sharedDim;
19+
shared uint sharedXOffset;
20+
shared uint sharedYOffset;
1921

2022
void main() {
2123
const uint threadIndex = gl_LocalInvocationID.x;
22-
const uint batchIndex = gl_GlobalInvocationID.y;
2324

2425
if (threadIndex == 0) {
25-
sharedInfo = infos[batchIndex];
26+
const uint nWorkGroups = gl_NumWorkGroups.z;
27+
const uint batchIndex = gl_WorkGroupID.y;
28+
const uint workGroupIndex = gl_WorkGroupID.z;
29+
30+
const BatchInfo info = infos[batchIndex];
31+
sharedDim = info.inputSizeX / nWorkGroups;
32+
const uint dimOffset = sharedDim * workGroupIndex;
33+
sharedXOffset = info.inputOffset + dimOffset;
34+
sharedYOffset = info.outputOffset + dimOffset;
2635
}
2736

2837
barrier();
2938
memoryBarrierShared();
3039

31-
const uint inputSizeX = sharedInfo.inputSizeX;
32-
const uint xOffset = sharedInfo.inputOffset;
33-
const uint yOffset = sharedInfo.outputOffset;
40+
const uint dim = sharedDim;
41+
const uint xOffset = sharedXOffset;
42+
const uint yOffset = sharedYOffset;
3443

35-
for (uint i = threadIndex; i < inputSizeX; i += N_THREADS) {
44+
for (uint i = threadIndex; i < dim; i += N_THREADS) {
3645
y[yOffset + i] = x[xOffset + i];
3746
}
3847
}

src/nn/vulkan/matmul-forward-f32-f32-f32.comp

+7-4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
#version 450
22

3-
#define N_WORK_GROUPS 32
43
#define N_THREADS 256
54

65
layout(local_size_x = N_THREADS, local_size_y = 1, local_size_z = 1) in;
@@ -18,23 +17,27 @@ layout(binding = 2) readonly buffer batchInfosBuffer { BatchInfo infos[]; };
1817
layout(binding = 3) readonly buffer weightBuffer { float weight[]; };
1918

2019
shared BatchInfo sharedInfo;
20+
shared uint sharedDim;
2121

2222
void main() {
2323
const uint threadIndex = gl_LocalInvocationID.x;
24-
const uint workGroupIndex = gl_GlobalInvocationID.z;
25-
const uint batchIndex = gl_GlobalInvocationID.y;
24+
const uint workGroupIndex = gl_WorkGroupID.z;
2625

2726
if (threadIndex == 0) {
27+
const uint batchIndex = gl_WorkGroupID.y;
28+
const uint nWorkGroups = gl_NumWorkGroups.z;
29+
2830
sharedInfo = infos[batchIndex];
31+
sharedDim = sharedInfo.outputSizeX / nWorkGroups;
2932
}
3033

3134
barrier();
3235
memoryBarrierShared();
3336

3437
const uint inputSizeX = sharedInfo.inputSizeX;
35-
const uint dim = sharedInfo.outputSizeX / N_WORK_GROUPS;
3638
const uint xOffset = sharedInfo.inputOffset;
3739
const uint yOffset = sharedInfo.outputOffset;
40+
const uint dim = sharedDim;
3841

3942
for (uint i = threadIndex; i < dim; i += N_THREADS) {
4043
const uint d = (workGroupIndex * dim) + i;

src/nn/vulkan/merge-add-forward-f32-f32.comp

+27-14
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
#extension GL_EXT_control_flow_attributes : enable
44

5-
#define N_THREADS 256
5+
#define N_THREADS 64
66

77
layout(local_size_x = N_THREADS, local_size_y = 1, local_size_z = 1) in;
88

@@ -17,29 +17,42 @@ layout(binding = 0) readonly buffer inputBuffer { float x[]; };
1717
layout(binding = 1) writeonly buffer outputBuffer { float y[]; };
1818
layout(binding = 2) readonly buffer batchInfosBuffer { BatchInfo infos[]; };
1919

20-
shared BatchInfo sharedInfo;
20+
shared uint sharedDim;
21+
shared uint sharedOutputSizeX;
22+
shared uint sharedParts;
23+
shared uint sharedXOffset;
24+
shared uint sharedYOffset;
2125

2226
void main() {
2327
const uint threadIndex = gl_LocalInvocationID.x;
24-
const uint batchIndex = gl_GlobalInvocationID.y;
2528

2629
if (threadIndex == 0) {
27-
sharedInfo = infos[batchIndex];
30+
const uint nWorkGroups = gl_NumWorkGroups.z;
31+
const uint batchIndex = gl_WorkGroupID.y;
32+
const uint workGroupIndex = gl_WorkGroupID.z;
33+
34+
const BatchInfo info = infos[batchIndex];
35+
sharedDim = info.outputSizeX / nWorkGroups;
36+
sharedOutputSizeX = info.outputSizeX;
37+
sharedParts = info.inputSizeX / info.outputSizeX;
38+
sharedXOffset = info.inputOffset + sharedDim * workGroupIndex;
39+
sharedYOffset = info.outputOffset + sharedDim * workGroupIndex;
2840
}
29-
memoryBarrierShared();
41+
3042
barrier();
43+
memoryBarrierShared();
3144

32-
const uint inputSizeX = sharedInfo.inputSizeX;
33-
const uint inputOffset = sharedInfo.inputOffset;
34-
const uint outputOffset = sharedInfo.outputOffset;
35-
const uint outputSizeX = sharedInfo.outputSizeX;
36-
const uint nNodes = inputSizeX / outputSizeX;
45+
const uint dim = sharedDim;
46+
const uint outputSizeX = sharedOutputSizeX;
47+
const uint parts = sharedParts;
48+
const uint xOffset = sharedXOffset;
49+
const uint yOffset = sharedYOffset;
3750

38-
for (uint i = threadIndex; i < outputSizeX; i += N_THREADS) {
51+
for (uint i = threadIndex; i < dim; i += N_THREADS) {
3952
float sum = 0.0;
40-
const uint iOffset = inputOffset + i;
41-
const uint oOffset = outputOffset + i;
42-
for (uint n = 0; n < nNodes; n++) {
53+
const uint iOffset = xOffset + i;
54+
const uint oOffset = yOffset + i;
55+
for (uint n = 0; n < parts; n++) {
4356
sum += x[n * outputSizeX + iOffset];
4457
}
4558
y[oOffset] += sum;

src/nn/vulkan/mul-forward-f32-f32.comp

+18-9
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#version 450
22

3-
#define N_THREADS 256
3+
#define N_THREADS 64
44

55
layout(local_size_x = N_THREADS, local_size_y = 1, local_size_z = 1) in;
66

@@ -19,23 +19,32 @@ layout(binding = 3) readonly uniform configBuffer {
1919
};
2020
layout(binding = 4) readonly buffer multiplierBuffer { float m[]; };
2121

22-
shared BatchInfo sharedInfo;
22+
shared uint sharedDim;
23+
shared uint sharedXyOffset;
24+
shared uint sharedMOffset;
2325

2426
void main() {
2527
const uint threadIndex = gl_LocalInvocationID.x;
26-
const uint batchIndex = gl_GlobalInvocationID.y;
2728

2829
if (threadIndex == 0) {
29-
sharedInfo = infos[batchIndex];
30+
const uint nWorkGroups = gl_NumWorkGroups.z;
31+
const uint batchIndex = gl_WorkGroupID.y;
32+
const uint workGroupIndex = gl_WorkGroupID.z;
33+
34+
const BatchInfo info = infos[batchIndex];
35+
sharedDim = info.inputSizeX / nWorkGroups;
36+
sharedXyOffset = info.inputOffset + sharedDim * workGroupIndex;
37+
sharedMOffset = info.inputSizeX * batchIndex + sharedDim * workGroupIndex;
3038
}
31-
memoryBarrierShared();
39+
3240
barrier();
41+
memoryBarrierShared();
3342

34-
const uint inputSizeX = sharedInfo.inputSizeX;
35-
const uint xyOffset = sharedInfo.inputOffset;
36-
const uint mOffset = inputSizeX * batchIndex;
43+
const uint dim = sharedDim;
44+
const uint xyOffset = sharedXyOffset;
45+
const uint mOffset = sharedMOffset;
3746

38-
for (uint i = threadIndex; i < inputSizeX; i += N_THREADS) {
47+
for (uint i = threadIndex; i < dim; i += N_THREADS) {
3948
y[xyOffset + i] = x[xyOffset + i] * m[mOffset + i];
4049
}
4150
}

src/nn/vulkan/silu-forward-f32-f32.comp

+16-6
Original file line numberDiff line numberDiff line change
@@ -15,23 +15,33 @@ layout(binding = 0) readonly buffer inputBuffer { float x[]; };
1515
layout(binding = 1) writeonly buffer outputBuffer { float y[]; };
1616
layout(binding = 2) readonly buffer batchInfosBuffer { BatchInfo infos[]; };
1717

18-
shared BatchInfo sharedInfo;
18+
shared uint sharedDim;
19+
shared uint sharedXOffset;
20+
shared uint sharedYOffset;
1921

2022
void main() {
2123
const uint threadIndex = gl_LocalInvocationID.x;
2224
const uint batchIndex = gl_GlobalInvocationID.y;
2325

2426
if (threadIndex == 0) {
25-
sharedInfo = infos[batchIndex];
27+
const uint nWorkGroups = gl_NumWorkGroups.z;
28+
const uint batchIndex = gl_WorkGroupID.y;
29+
const uint workGroupIndex = gl_WorkGroupID.z;
30+
31+
const BatchInfo info = infos[batchIndex];
32+
sharedDim = info.inputSizeX / nWorkGroups;
33+
sharedXOffset = info.inputOffset + sharedDim * workGroupIndex;
34+
sharedYOffset = info.outputOffset + sharedDim * workGroupIndex;
2635
}
2736

2837
barrier();
38+
memoryBarrierShared();
2939

30-
const uint inputSizeX = sharedInfo.inputSizeX;
31-
const uint xOffset = sharedInfo.inputOffset;
32-
const uint yOffset = sharedInfo.outputOffset;
40+
const uint dim = sharedDim;
41+
const uint xOffset = sharedXOffset;
42+
const uint yOffset = sharedYOffset;
3343

34-
for (uint i = threadIndex; i < inputSizeX; i += N_THREADS) {
44+
for (uint i = threadIndex; i < dim; i += N_THREADS) {
3545
float v = x[xOffset + i];
3646
y[yOffset + i] = v / (1.0 + exp(-v));
3747
}

0 commit comments

Comments
 (0)