Skip to content

Commit 0b82d94

Browse files
committed
optimization.
1 parent ddde606 commit 0b82d94

4 files changed

+43
-21
lines changed

src/nn/nn-vulkan-test.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ void testEmbedding_F32_F32() {
266266
}
267267

268268
void testShift_F32_F32() {
269-
#define SHIFT_DIM 48
269+
#define SHIFT_DIM 64
270270
execute(
271271
[](NnNetConfigBuilder *netBuilder, NnNodeConfigBuilder *nodeBuilder, NnSegmentConfigBuilder *segmentBuilder) {
272272
NnUint posPipeIndex = netBuilder->addPipe("POS", size2D(F_32, N_BATCHES, 1));

src/nn/nn-vulkan.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -459,8 +459,10 @@ static void resolveShaderGroups(const NnOpConfig *opConfig, const NnUint batchSi
459459
groupCount[2] = 1;
460460

461461
if (opConfig->code == OP_CAST ||
462+
opConfig->code == OP_RMS_NORM ||
462463
opConfig->code == OP_MUL ||
463464
opConfig->code == OP_SILU ||
465+
opConfig->code == OP_SHIFT ||
464466
opConfig->code == OP_MERGE_ADD ||
465467
opConfig->code == OP_MATMUL)
466468
groupCount[2] = 32;

src/nn/vulkan/rms-norm-forward-f32-f32-f32.comp

+23-11
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#version 450
22

3-
#define N_THREADS 256
3+
#define N_THREADS 64
44

55
layout(local_size_x = N_THREADS, local_size_y = 1, local_size_z = 1) in;
66

@@ -20,26 +20,38 @@ layout(binding = 4) readonly uniform configBuffer {
2020
};
2121
layout(binding = 5) readonly buffer invRmsBuffer { float invRms[]; };
2222

23-
shared BatchInfo sharedInfo;
24-
shared float s;
23+
shared uint sharedDim;
24+
shared uint sharedDimOffset;
25+
shared uint sharedXOffset;
26+
shared uint sharedYOffset;
27+
shared float sharedS;
2528

2629
void main() {
2730
const uint threadIndex = uint(gl_LocalInvocationID.x);
28-
const uint batchIndex = uint(gl_GlobalInvocationID.y);
2931

3032
if (threadIndex == 0) {
31-
sharedInfo = infos[batchIndex];
32-
s = invRms[batchIndex];
33+
const uint nWorkGroups = gl_NumWorkGroups.z;
34+
const uint batchIndex = gl_WorkGroupID.y;
35+
const uint workGroupIndex = gl_WorkGroupID.z;
36+
37+
const BatchInfo info = infos[batchIndex];
38+
sharedDim = info.inputSizeX / nWorkGroups;
39+
sharedDimOffset = sharedDim * workGroupIndex;
40+
sharedXOffset = info.inputOffset + sharedDimOffset;
41+
sharedYOffset = info.outputOffset + sharedDimOffset;
42+
sharedS = invRms[batchIndex];
3343
}
3444

3545
barrier();
3646
memoryBarrierShared();
3747

38-
const uint inputSizeX = sharedInfo.inputSizeX;
39-
const uint xOffset = sharedInfo.inputOffset;
40-
const uint yOffset = sharedInfo.outputOffset;
48+
const uint dim = sharedDim;
49+
const uint dimOffset = sharedDimOffset;
50+
const uint xOffset = sharedXOffset;
51+
const uint yOffset = sharedYOffset;
52+
const float s = sharedS;
4153

42-
for (uint i = threadIndex; i < inputSizeX; i += N_THREADS) {
43-
y[yOffset + i] = (x[xOffset + i] * s) * weight[i];
54+
for (uint i = threadIndex; i < dim; i += N_THREADS) {
55+
y[yOffset + i] = (x[xOffset + i] * s) * weight[i + dimOffset];
4456
}
4557
}

src/nn/vulkan/shift-forward-f32-f32.comp

+17-9
Original file line numberDiff line numberDiff line change
@@ -19,26 +19,34 @@ layout(binding = 3) readonly uniform configBuffer {
1919
};
2020
layout(binding = 4) readonly buffer indexBuffer { float indexes[]; };
2121

22-
shared uint sharedIndex;
23-
shared BatchInfo sharedInfo;
22+
shared uint sharedDim;
23+
shared uint sharedXOffset;
24+
shared uint sharedYOffset;
2425

2526
void main() {
2627
const uint threadIndex = gl_LocalInvocationID.x;
27-
const uint batchIndex = gl_GlobalInvocationID.y;
2828

2929
if (threadIndex == 0) {
30-
sharedIndex = uint(indexes[batchIndex]);
31-
sharedInfo = infos[batchIndex];
30+
const uint nWorkGroups = gl_NumWorkGroups.z;
31+
const uint batchIndex = gl_WorkGroupID.y;
32+
const uint workGroupIndex = gl_WorkGroupID.z;
33+
34+
const uint index = uint(indexes[batchIndex]);
35+
BatchInfo info = infos[batchIndex];
36+
sharedDim = info.inputSizeX / nWorkGroups;
37+
const uint dimOffset = sharedDim * workGroupIndex;
38+
sharedXOffset = info.inputOffset + dimOffset;
39+
sharedYOffset = index * info.inputSizeX + dimOffset;
3240
}
3341

3442
barrier();
3543
memoryBarrierShared();
3644

37-
const uint inputSizeX = sharedInfo.inputSizeX;
38-
const uint xOffset = sharedInfo.inputOffset;
39-
const uint yOffset = sharedIndex * inputSizeX;
45+
const uint dim = sharedDim;
46+
const uint xOffset = sharedXOffset;
47+
const uint yOffset = sharedYOffset;
4048

41-
for (uint i = threadIndex; i < inputSizeX; i += N_THREADS) {
49+
for (uint i = threadIndex; i < dim; i += N_THREADS) {
4250
y[yOffset + i] = x[xOffset + i];
4351
}
4452
}

0 commit comments

Comments
 (0)