@@ -4,11 +4,11 @@
 #extension GL_EXT_shader_16bit_storage : enable
 #extension GL_EXT_shader_explicit_arithmetic_types : enable
 
-#define Q80_Q40_BLOCK_SIZE 32
-#define N_THREADS 256
+#define N_THREADS 64
+#define TILE_SIZE_X 2
+#define TILE_SIZE_D 16
 
-#define N_OUTPUTS_PER_ITER 64
-#define N_THREADS_PER_OUTPUT (N_THREADS / N_OUTPUTS_PER_ITER)
+#define Q80_Q40_BLOCK_SIZE 32
 
 layout(local_size_x = N_THREADS, local_size_y = 1, local_size_z = 1) in;
 
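The defines above replace the old work decomposition, where each workgroup ran 256 threads and groups of N_THREADS_PER_OUTPUT threads cooperated on one output at a time, with a 2D tiling: a 64-thread workgroup now produces TILE_SIZE_D = 16 consecutive output rows, and each thread steps through the input in tiles of TILE_SIZE_X = 2 quantized blocks. Q80_Q40_BLOCK_SIZE (32 values per quantized block) is unchanged, only moved. The hunk below rewires main() accordingly.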
@@ -34,80 +34,98 @@ layout(binding = 1) writeonly buffer outputBuffer { float y[]; };
 layout(binding = 2) readonly buffer batchInfosBuffer { BatchInfo infos[]; };
 layout(binding = 3) readonly buffer weightBuffer { BlockQ40 weight[]; };
 
-shared uint sharedStart;
-shared uint sharedEnd;
+shared uint sharedXSlice;
+shared uint sharedXRest;
 shared uint sharedInputOffset;
 shared uint sharedInputSizeX;
 shared uint sharedOutputOffset;
-shared uint sharedInputSizeXPerGroup;
-shared float16_t sums[N_THREADS];
+shared uint sharedD;
+shared float16_t sums[N_THREADS * TILE_SIZE_D];
 
 void main() {
     const uint threadIndex = gl_LocalInvocationID.x;
 
     if (threadIndex == 0) {
-        const uint nWorkGroups = gl_NumWorkGroups.z;
         const uint batchIndex = gl_WorkGroupID.y;
         const uint workGroupIndex = gl_WorkGroupID.z;
 
         const BatchInfo info = infos[batchIndex];
+
+        const uint xTiles = info.inputSizeX / TILE_SIZE_X;
+        sharedXSlice = xTiles / N_THREADS;
+        sharedXRest = xTiles % N_THREADS;
+
         sharedInputOffset = info.inputOffset;
         sharedInputSizeX = info.inputSizeX;
         sharedOutputOffset = info.outputOffset;
-        sharedInputSizeXPerGroup = (sharedInputSizeX + N_THREADS_PER_OUTPUT - 1) / N_THREADS_PER_OUTPUT;
-
-        const uint ySlice = info.outputSizeX / nWorkGroups;
-        const uint yRest = info.outputSizeX % nWorkGroups;
-        sharedStart = workGroupIndex * ySlice + (workGroupIndex < yRest ? workGroupIndex : yRest);
-        sharedEnd = sharedStart + ySlice + (workGroupIndex < yRest ? 1 : 0);
+        sharedD = TILE_SIZE_D * workGroupIndex;
     }
 
     barrier();
     memoryBarrierShared();
 
-    const uint dEnd = sharedEnd;
+    const uint xSlice = sharedXSlice;
+    const uint xRest = sharedXRest;
+    const uint xStart = (threadIndex * xSlice + min(threadIndex, xRest)) * TILE_SIZE_X;
+    const uint xEnd = xStart + (xSlice + (threadIndex < xRest ? 1 : 0)) * TILE_SIZE_X;
+
     const uint inputOffset = sharedInputOffset;
     const uint inputSizeX = sharedInputSizeX;
     const uint outputOffset = sharedOutputOffset;
-    const uint inputSizeXPerGroup = sharedInputSizeXPerGroup;
+    const uint d = sharedD;
 
-    const uint dGroup = threadIndex / N_THREADS_PER_OUTPUT;
-    const uint iGroup = threadIndex % N_THREADS_PER_OUTPUT;
-    const uint iStart = inputSizeXPerGroup * iGroup;
-    const uint iEnd = min(iStart + inputSizeXPerGroup, inputSizeX);
+    f16vec4 xTemp[Q80_Q40_BLOCK_SIZE / 4];
 
-    for (uint dBatch = sharedStart; dBatch < dEnd; dBatch += N_OUTPUTS_PER_ITER) {
-        const uint d = dBatch + dGroup;
-        if (d >= dEnd) {
-            break;
-        }
+    for (uint dt = 0; dt < TILE_SIZE_D; dt++) {
+        sums[threadIndex * TILE_SIZE_D + dt] = float16_t(0.0f);
+    }
+
+    for (uint i = xStart; i < xEnd; i += TILE_SIZE_X) {
+        [[unroll]] for (uint it = 0; it < TILE_SIZE_X; it++) {
+            const uint xi = inputOffset + i + it;
+            const float16_t xScale = x[xi].d;
+            [[unroll]] for (uint j = 0; j < Q80_Q40_BLOCK_SIZE / 4; j++) {
+                xTemp[j] = f16vec4(
+                    x[xi].qs[j * 2],
+                    x[xi].qs[j * 2 + Q80_Q40_BLOCK_SIZE / 2],
+                    x[xi].qs[j * 2 + 1],
+                    x[xi].qs[j * 2 + 1 + Q80_Q40_BLOCK_SIZE / 2]
+                );
+            }
 
-        float16_t sum = float16_t(0.0f);
-        for (uint i = iStart; i < iEnd; i++) {
-            const uint xi = inputOffset + i;
-            const uint wi = d * inputSizeX + i;
-            const float16_t scale = x[xi].d * weight[wi].d;
-            [[unroll]] for (uint j = 0; j < Q80_Q40_BLOCK_SIZE / 2; j++) {
-                sum += (
-                    float16_t(x[xi].qs[j]) * (float16_t(weight[wi].qs[j] & 0xF) - float16_t(8.0f)) +
-                    float16_t(x[xi].qs[j + Q80_Q40_BLOCK_SIZE / 2]) * (float16_t(weight[wi].qs[j] >> 4) - float16_t(8.0f))
-                ) * scale;
+            [[unroll]] for (uint dt = 0; dt < TILE_SIZE_D; dt++) {
+                const uint wi = (d + dt) * inputSizeX + (i + it);
+                const BlockQ40 wBlock = weight[wi];
+
+                float16_t s = float16_t(0);
+                [[unroll]] for (uint j = 0; j < Q80_Q40_BLOCK_SIZE / 4; j++) {
+                    uint w0 = wBlock.qs[j * 2];
+                    uint w1 = wBlock.qs[j * 2 + 1];
+                    ivec4 w = ivec4(
+                        w0 & 0xFu,
+                        w0 >> 4,
+                        w1 & 0xFu,
+                        w1 >> 4
+                    ) - ivec4(8);
+                    s += dot(xTemp[j], f16vec4(w));
+                }
+                sums[threadIndex * TILE_SIZE_D + dt] += s * xScale * wBlock.d;
             }
         }
-        sums[threadIndex] = sum;
+    }
 
-        barrier();
-        memoryBarrierShared();
+    barrier();
+    memoryBarrierShared();
 
-        [[unroll]] for (uint i = N_THREADS_PER_OUTPUT / 2; i > 0; i >>= 1) {
-            if (iGroup < i)
-                sums[threadIndex] += sums[threadIndex + i];
-            barrier();
-        }
-        if (iGroup == 0) {
-            y[outputOffset + d] = float(sums[threadIndex]);
+    [[unroll]] for (uint i = N_THREADS / 2; i > 0; i >>= 1) {
+        for (uint dt = 0; dt < TILE_SIZE_D; dt++) {
+            if (threadIndex < i) {
+                sums[threadIndex * TILE_SIZE_D + dt] += sums[(threadIndex + i) * TILE_SIZE_D + dt];
+            }
         }
-
         barrier();
     }
+    for (uint dt = threadIndex; dt < TILE_SIZE_D; dt += N_THREADS) {
+        y[outputOffset + d + dt] = float(sums[dt]);
+    }
 }
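How the new decomposition divides work: gl_WorkGroupID.z selects a tile of TILE_SIZE_D = 16 output rows (sharedD), and the input row of Q8_0 blocks is cut into xTiles = inputSizeX / TILE_SIZE_X tiles spread across the 64 threads, with the first xRest = xTiles % N_THREADS threads each taking one extra tile. For example, with inputSizeX = 1024 blocks, xTiles = 512, xSlice = 8 and xRest = 0, so every thread covers 16 blocks. Note this tiling assumes inputSizeX is a multiple of TILE_SIZE_X. A minimal CPU sketch of the same split; the function name and output format are illustrative, not from the PR:

#include <stdio.h>

#define N_THREADS   64
#define TILE_SIZE_X 2

/* CPU sketch of the shader's xStart/xEnd computation: xTiles tiles are
 * split evenly across threads, and the first (xTiles % N_THREADS)
 * threads each take one extra tile. */
static void printThreadRanges(unsigned inputSizeX) {
    const unsigned xTiles = inputSizeX / TILE_SIZE_X;
    const unsigned xSlice = xTiles / N_THREADS;
    const unsigned xRest = xTiles % N_THREADS;
    for (unsigned t = 0; t < N_THREADS; t++) {
        const unsigned minT = t < xRest ? t : xRest; /* min(threadIndex, xRest) */
        const unsigned xStart = (t * xSlice + minT) * TILE_SIZE_X;
        const unsigned xEnd = xStart + (xSlice + (t < xRest ? 1 : 0)) * TILE_SIZE_X;
        printf("thread %2u: input blocks [%u, %u)\n", t, xStart, xEnd);
    }
}

int main(void) {
    printThreadRanges(1024); /* e.g. 1024 Q8_0 blocks = a 32768-wide input row */
    return 0;
}

Each thread accumulates its TILE_SIZE_D partial sums in the shared sums array; the tree reduction at the end folds the 64 partial rows into thread 0's slots, after which thread dt writes y[outputOffset + d + dt] from sums[dt].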
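For reference, the Q4_0 layout the inner loop decodes: each BlockQ40 holds a float16 scale d and 16 packed bytes, where byte j stores weight j in its low nibble and weight j + 16 in its high nibble, both biased by +8. The rewritten loop pre-swizzles the Q8_0 activations into f16vec4 lanes (x[2j], x[2j + 16], x[2j + 1], x[2j + 17]) so each pair of weight bytes expands into one ivec4 and is consumed by a single dot(). A scalar C sketch of one block product follows; the struct layouts are assumptions mirroring the shader's buffers, with scales widened from float16_t to float:

#include <stdint.h>

#define Q80_Q40_BLOCK_SIZE 32

/* Assumed CPU mirrors of the shader's BlockQ80/BlockQ40 buffer blocks. */
typedef struct { float d; int8_t  qs[Q80_Q40_BLOCK_SIZE];     } BlockQ80;
typedef struct { float d; uint8_t qs[Q80_Q40_BLOCK_SIZE / 2]; } BlockQ40;

/* Reference dot product of one Q8_0 input block with one Q4_0 weight
 * block: byte j of w->qs packs weight j (low nibble) and weight
 * j + 16 (high nibble), each stored with a +8 bias. */
float blockDotQ80Q40(const BlockQ80 *x, const BlockQ40 *w) {
    int32_t s = 0;
    for (int j = 0; j < Q80_Q40_BLOCK_SIZE / 2; j++) {
        s += x->qs[j]                          * ((int32_t)(w->qs[j] & 0xF) - 8);
        s += x->qs[j + Q80_Q40_BLOCK_SIZE / 2] * ((int32_t)(w->qs[j] >> 4)  - 8);
    }
    return (float)s * x->d * w->d;
}

The shader's f16vec4 path computes exactly this sum, four products per dot() call, which is the point of the activation swizzle.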