2
2
3
3
#extension GL_EXT_control_flow_attributes : enable
4
4
5
- #define N_THREADS 256
5
// Invocations per work group along x (used as local_size_x) and the stride of
// the per-invocation loop in main().
+ #define N_THREADS 64
6
6
7
7
// Work-group shape: N_THREADS invocations along x, a single invocation along y and z.
layout(local_size_x = N_THREADS, local_size_y = 1, local_size_z = 1) in;
8
8
@@ -17,29 +17,42 @@ layout(binding = 0) readonly buffer inputBuffer { float x[]; };
17
17
// Output storage buffer. Declared without `writeonly` because main() accumulates
// into it with `y[...] += sum`, which loads the element before storing it; GLSL
// does not permit reads through a writeonly-qualified buffer.
layout(binding = 1) buffer outputBuffer { float y[]; };
18
18
// Read-only array of BatchInfo records; main() loads infos[batchIndex] to obtain
// the input/output sizes and offsets it works with.
layout(binding = 2) readonly buffer batchInfosBuffer { BatchInfo infos[]; };
19
19
20
// Work-group shared values: invocation 0 fills them in at the top of main(), and
// every invocation copies them into locals after the barrier.
- shared BatchInfo sharedInfo;
20
+ shared uint sharedDim; // info.outputSizeX / gl_NumWorkGroups.z; upper bound of the per-invocation loop
21
+ shared uint sharedOutputSizeX; // info.outputSizeX; stride between successive parts when reading x
22
+ shared uint sharedParts; // info.inputSizeX / info.outputSizeX; number of terms summed per output element
23
+ shared uint sharedXOffset; // info.inputOffset + sharedDim * gl_WorkGroupID.z
24
+ shared uint sharedYOffset; // info.outputOffset + sharedDim * gl_WorkGroupID.z
21
25
22
26
void main() {
23
27
const uint threadIndex = gl_LocalInvocationID.x;
24
- const uint batchIndex = gl_GlobalInvocationID.y;
25
28
26
29
if (threadIndex == 0) {
27
- sharedInfo = infos[batchIndex];
30
+ const uint nWorkGroups = gl_NumWorkGroups.z;
31
+ const uint batchIndex = gl_WorkGroupID.y;
32
+ const uint workGroupIndex = gl_WorkGroupID.z;
33
+
34
+ const BatchInfo info = infos[batchIndex];
35
+ sharedDim = info.outputSizeX / nWorkGroups;
36
+ sharedOutputSizeX = info.outputSizeX;
37
+ sharedParts = info.inputSizeX / info.outputSizeX;
38
+ sharedXOffset = info.inputOffset + sharedDim * workGroupIndex;
39
+ sharedYOffset = info.outputOffset + sharedDim * workGroupIndex;
28
40
}
29
- memoryBarrierShared();
41
+
30
42
barrier();
43
+ memoryBarrierShared();
31
44
32
- const uint inputSizeX = sharedInfo.inputSizeX ;
33
- const uint inputOffset = sharedInfo.inputOffset ;
34
- const uint outputOffset = sharedInfo.outputOffset ;
35
- const uint outputSizeX = sharedInfo.outputSizeX ;
36
- const uint nNodes = inputSizeX / outputSizeX ;
45
+ const uint dim = sharedDim ;
46
+ const uint outputSizeX = sharedOutputSizeX ;
47
+ const uint parts = sharedParts ;
48
+ const uint xOffset = sharedXOffset ;
49
+ const uint yOffset = sharedYOffset ;
37
50
38
- for (uint i = threadIndex; i < outputSizeX ; i += N_THREADS) {
51
+ for (uint i = threadIndex; i < dim ; i += N_THREADS) {
39
52
float sum = 0.0;
40
- const uint iOffset = inputOffset + i;
41
- const uint oOffset = outputOffset + i;
42
- for (uint n = 0; n < nNodes ; n++) {
53
+ const uint iOffset = xOffset + i;
54
+ const uint oOffset = yOffset + i;
55
+ for (uint n = 0; n < parts ; n++) {
43
56
sum += x[n * outputSizeX + iOffset];
44
57
}
45
58
y[oOffset] += sum;
0 commit comments