4
4
#extension GL_EXT_shader_16bit_storage : enable
5
5
#extension GL_EXT_shader_explicit_arithmetic_types : enable
6
6
7
// Compile-time sizes used by this shader.
// Q80_Q40_BLOCK_SIZE: number of int8 entries in BlockQ80.qs; BlockQ40.qs holds half as many bytes.
// N_THREADS: workgroup size along x (used as local_size_x in the layout declaration).
// GROUP_SIZE: how many output indices main() advances per pass of its outer loop.
// N_THREADS_PER_GROUP: how many consecutive threads share one output index and split its input range.
- #define Q80_BLOCK_SIZE 32
8
- #define Q40_BLOCK_SIZE 32
9
- #define N_THREADS 128
7
+ #define Q80_Q40_BLOCK_SIZE 32
8
+ #define N_THREADS 256
9
+
10
+ #define GROUP_SIZE 64
11
+ #define N_THREADS_PER_GROUP (N_THREADS / GROUP_SIZE)
10
12
11
13
layout(local_size_x = N_THREADS, local_size_y = 1, local_size_z = 1) in;
12
14
@@ -19,12 +21,12 @@ struct BatchInfo {
19
21
20
22
// One block of the input buffer `x`: a half-precision factor `d` plus a run of signed
// 8-bit values `qs`. main() multiplies every qs-derived product by `d`.
struct BlockQ80 {
21
23
float16_t d;
22
- int8_t qs[Q80_BLOCK_SIZE ];
24
+ int8_t qs[Q80_Q40_BLOCK_SIZE ];
23
25
};
24
26
25
27
// One block of the `weight` buffer: a half-precision factor `d` plus packed bytes `qs`.
// main() reads two 4-bit values from each byte: the low nibble (qs[j] & 0xF) is paired
// with x.qs[j] and the high nibble (qs[j] >> 4) with x.qs[j + block_size / 2]; 8 is
// subtracted from each nibble before the multiply.
struct BlockQ40 {
26
28
float16_t d;
27
- uint8_t qs[Q40_BLOCK_SIZE / 2];
29
+ uint8_t qs[Q80_Q40_BLOCK_SIZE / 2];
28
30
};
29
31
30
32
layout(binding = 0) readonly buffer inputBuffer { BlockQ80 x[]; };
@@ -34,7 +36,11 @@ layout(binding = 3) readonly buffer weightBuffer { BlockQ40 weight[]; };
34
36
35
37
// Workgroup-shared values. The scalar entries are filled in at the top of main(), ahead of
// its first barrier(), and copied into per-thread constants afterwards.
// sharedStart / sharedEnd: half-open range of output indices handled by this workgroup.
// sums: one partial sum per thread, combined inside main() to produce each output value.
shared uint sharedStart;
36
38
shared uint sharedEnd;
37
- shared BatchInfo sharedInfo;
39
+ shared uint sharedInputOffset;
40
+ shared uint sharedInputSizeX;
41
+ shared uint sharedOutputOffset;
42
+ shared uint sharedInputSizeXPerGroup;
43
+ shared float16_t sums[N_THREADS];
38
44
39
45
void main() {
40
46
const uint threadIndex = gl_LocalInvocationID.x;
@@ -44,44 +50,62 @@ void main() {
44
50
const uint batchIndex = gl_WorkGroupID.y;
45
51
const uint workGroupIndex = gl_WorkGroupID.z;
46
52
47
- const BatchInfo info = infos[batchIndex];
53
+ sharedInputOffset = infos[batchIndex].inputOffset;
54
+ sharedInputSizeX = infos[batchIndex].inputSizeX;
55
+ sharedOutputOffset = infos[batchIndex].outputOffset;
56
+ sharedInputSizeXPerGroup = (sharedInputSizeX + N_THREADS_PER_GROUP - 1) / N_THREADS_PER_GROUP;
48
57
49
- const uint ySlice = info .outputSizeX / nWorkGroups;
50
- const uint yRest = info .outputSizeX % nWorkGroups;
58
+ const uint ySlice = infos[batchIndex] .outputSizeX / nWorkGroups;
59
+ const uint yRest = infos[batchIndex] .outputSizeX % nWorkGroups;
51
60
sharedStart = workGroupIndex * ySlice + (workGroupIndex < yRest ? workGroupIndex : yRest);
52
61
sharedEnd = sharedStart + ySlice + (workGroupIndex < yRest ? 1 : 0);
53
- sharedInfo = info;
54
62
}
55
63
56
64
barrier();
57
65
memoryBarrierShared();
58
66
59
- const uint end = sharedEnd;
60
- const uint inputOffset = sharedInfo.inputOffset;
61
- const uint inputSizeX = sharedInfo.inputSizeX;
62
- const uint outputOffset = sharedInfo.outputOffset;
67
+ const uint dEnd = sharedEnd;
68
+ const uint inputOffset = sharedInputOffset;
69
+ const uint inputSizeX = sharedInputSizeX;
70
+ const uint outputOffset = sharedOutputOffset;
71
+ const uint inputSizeXPerGroup = sharedInputSizeXPerGroup;
72
+
73
+ const uint dGroup = threadIndex / N_THREADS_PER_GROUP;
74
+ const uint iGroup = threadIndex % N_THREADS_PER_GROUP;
75
+ const uint iStart = inputSizeXPerGroup * iGroup;
76
+ const uint iEnd = min(iStart + inputSizeXPerGroup, inputSizeX);
77
+
78
// Row loop. Each pass covers up to GROUP_SIZE output indices; the N_THREADS_PER_GROUP
// threads that share an index each accumulate their slice [iStart, iEnd) of the input
// blocks, and the partial sums are then combined through the shared `sums` array.
//
// sharedStart and dEnd are the same for every invocation in the workgroup, so all
// invocations run the same number of passes. A thread whose index falls past dEnd must
// NOT leave the loop early: barrier() is only defined when every invocation of the
// workgroup reaches it (uniform control flow). Such threads stay in the loop, publish a
// zero partial sum, and simply skip the final store.
for (uint dBatch = sharedStart; dBatch < dEnd; dBatch += GROUP_SIZE) {
    const uint d = dBatch + dGroup;
    const bool hasRow = d < dEnd;

    float16_t sum = float16_t(0.0f);
    if (hasRow) {
        for (uint i = iStart; i < iEnd; i++) {
            const uint xi = inputOffset + i;
            const uint wi = d * inputSizeX + i;
            [[unroll]] for (uint j = 0; j < Q80_Q40_BLOCK_SIZE / 2; j++) {
                // Low nibble pairs with the first half of x.qs, high nibble with the second half.
                sum += (
                    float16_t(x[xi].qs[j]) * (float16_t(weight[wi].qs[j] & 0xF) - float16_t(8.0f)) +
                    float16_t(x[xi].qs[j + Q80_Q40_BLOCK_SIZE / 2]) * (float16_t(weight[wi].qs[j] >> 4) - float16_t(8.0f))
                ) * x[xi].d * weight[wi].d;
            }
        }
    }
    sums[threadIndex] = sum;

    // Make the partial sums visible, then wait for every thread in the workgroup.
    memoryBarrierShared();
    barrier();

    // Pairwise reduction within each run of N_THREADS_PER_GROUP threads; the result
    // ends up in the slot of the thread with iGroup == 0.
    [[unroll]] for (uint i = N_THREADS_PER_GROUP / 2; i > 0; i >>= 1) {
        if (iGroup < i)
            sums[threadIndex] += sums[threadIndex + i];
        memoryBarrierShared();
        barrier();
    }
    if (hasRow && iGroup == 0) {
        y[outputOffset + d] = float(sums[threadIndex]);
    }

    // Keep the next pass from overwriting `sums` before the store above has read it.
    barrier();
}
87
111
}
0 commit comments