
Commit e208f50

feat: vulkan matmul optimization. (#192)
1 parent 12df7b4 commit e208f50

File tree

1 file changed: +55 −31 lines changed

src/nn/vulkan/matmul-forward-q80-q40-f32.comp

Lines changed: 55 additions & 31 deletions
@@ -4,9 +4,11 @@
 #extension GL_EXT_shader_16bit_storage : enable
 #extension GL_EXT_shader_explicit_arithmetic_types : enable
 
-#define Q80_BLOCK_SIZE 32
-#define Q40_BLOCK_SIZE 32
-#define N_THREADS 128
+#define Q80_Q40_BLOCK_SIZE 32
+#define N_THREADS 256
+
+#define GROUP_SIZE 64
+#define N_THREADS_PER_GROUP (N_THREADS / GROUP_SIZE)
 
 layout(local_size_x = N_THREADS, local_size_y = 1, local_size_z = 1) in;
 
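The two identical block-size constants collapse into a single Q80_Q40_BLOCK_SIZE (both Q8_0 and Q4_0 pack 32 weights per block), the workgroup grows from 128 to 256 threads, and the new GROUP_SIZE / N_THREADS_PER_GROUP pair partitions those threads into 64 groups of 4 lanes that cooperate on one output row, instead of one thread per row as before. A minimal C illustration of the resulting index math, assuming the shader's defines (the program is a sketch, not part of the commit):

#include <stdio.h>

#define N_THREADS 256
#define GROUP_SIZE 64
#define N_THREADS_PER_GROUP (N_THREADS / GROUP_SIZE) /* = 4 */

int main(void) {
    /* Mirrors dGroup/iGroup in the shader: consecutive threads share a row. */
    for (unsigned threadIndex = 0; threadIndex < 8; threadIndex++) {
        unsigned dGroup = threadIndex / N_THREADS_PER_GROUP; /* which output row */
        unsigned iGroup = threadIndex % N_THREADS_PER_GROUP; /* lane within that row */
        printf("thread %u -> row group %u, lane %u\n", threadIndex, dGroup, iGroup);
    }
    return 0;
}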
@@ -19,12 +21,12 @@ struct BatchInfo {
 
 struct BlockQ80 {
     float16_t d;
-    int8_t qs[Q80_BLOCK_SIZE];
+    int8_t qs[Q80_Q40_BLOCK_SIZE];
 };
 
 struct BlockQ40 {
     float16_t d;
-    uint8_t qs[Q40_BLOCK_SIZE / 2];
+    uint8_t qs[Q80_Q40_BLOCK_SIZE / 2];
 };
 
 layout(binding = 0) readonly buffer inputBuffer { BlockQ80 x[]; };
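Both structs keep their layout; only the shared size constant is renamed. For reference, a BlockQ80 is a float16 scale plus 32 int8 quants, while a BlockQ40 is a scale plus 32 4-bit quants packed two per byte (low nibble first, stored with a +8 bias). A plain-C sketch of the per-block dot product that the shader's inner loop computes, assuming these layouts; block_dot is an illustrative name, and float scales stand in for float16_t:

#include <stdint.h>

#define BLOCK_SIZE 32 /* Q80_Q40_BLOCK_SIZE in the shader */

typedef struct { float d; int8_t qs[BLOCK_SIZE]; } BlockQ80;
typedef struct { float d; uint8_t qs[BLOCK_SIZE / 2]; } BlockQ40;

/* Dot product of one Q8_0 activation block with one Q4_0 weight block.
   The low nibble of qs[j] holds element j; the high nibble holds element j + 16. */
static float block_dot(const BlockQ80 *x, const BlockQ40 *w) {
    float s = 0.0f;
    for (int j = 0; j < BLOCK_SIZE / 2; j++) {
        float w0 = (float)(w->qs[j] & 0xF) - 8.0f; /* low nibble, bias removed */
        float w1 = (float)(w->qs[j] >> 4) - 8.0f;  /* high nibble, bias removed */
        s += (float)x->qs[j] * w0 + (float)x->qs[j + BLOCK_SIZE / 2] * w1;
    }
    return s * x->d * w->d; /* apply both block scales once per block */
}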
@@ -34,7 +36,11 @@ layout(binding = 3) readonly buffer weightBuffer { BlockQ40 weight[]; };
 
 shared uint sharedStart;
 shared uint sharedEnd;
-shared BatchInfo sharedInfo;
+shared uint sharedInputOffset;
+shared uint sharedInputSizeX;
+shared uint sharedOutputOffset;
+shared uint sharedInputSizeXPerGroup;
+shared float16_t sums[N_THREADS];
 
 void main() {
     const uint threadIndex = gl_LocalInvocationID.x;
@@ -44,44 +50,62 @@ void main() {
         const uint batchIndex = gl_WorkGroupID.y;
         const uint workGroupIndex = gl_WorkGroupID.z;
 
-        const BatchInfo info = infos[batchIndex];
+        sharedInputOffset = infos[batchIndex].inputOffset;
+        sharedInputSizeX = infos[batchIndex].inputSizeX;
+        sharedOutputOffset = infos[batchIndex].outputOffset;
+        sharedInputSizeXPerGroup = (sharedInputSizeX + N_THREADS_PER_GROUP - 1) / N_THREADS_PER_GROUP;
 
-        const uint ySlice = info.outputSizeX / nWorkGroups;
-        const uint yRest = info.outputSizeX % nWorkGroups;
+        const uint ySlice = infos[batchIndex].outputSizeX / nWorkGroups;
+        const uint yRest = infos[batchIndex].outputSizeX % nWorkGroups;
         sharedStart = workGroupIndex * ySlice + (workGroupIndex < yRest ? workGroupIndex : yRest);
         sharedEnd = sharedStart + ySlice + (workGroupIndex < yRest ? 1 : 0);
-        sharedInfo = info;
     }
 
     barrier();
     memoryBarrierShared();
 
-    const uint end = sharedEnd;
-    const uint inputOffset = sharedInfo.inputOffset;
-    const uint inputSizeX = sharedInfo.inputSizeX;
-    const uint outputOffset = sharedInfo.outputOffset;
+    const uint dEnd = sharedEnd;
+    const uint inputOffset = sharedInputOffset;
+    const uint inputSizeX = sharedInputSizeX;
+    const uint outputOffset = sharedOutputOffset;
+    const uint inputSizeXPerGroup = sharedInputSizeXPerGroup;
+
+    const uint dGroup = threadIndex / N_THREADS_PER_GROUP;
+    const uint iGroup = threadIndex % N_THREADS_PER_GROUP;
+    const uint iStart = inputSizeXPerGroup * iGroup;
+    const uint iEnd = min(iStart + inputSizeXPerGroup, inputSizeX);
+
+    for (uint dBatch = sharedStart; dBatch < dEnd; dBatch += GROUP_SIZE) {
+        const uint d = dBatch + dGroup;
+        if (d >= dEnd) {
+            break;
+        }
 
-    for (uint d = sharedStart + threadIndex; d < end; d += N_THREADS) {
         float16_t sum = float16_t(0.0f);
-        const uint wOffset = d * inputSizeX;
-
-        for (uint i = 0; i < inputSizeX; i++) {
-            const BlockQ80 xi = x[inputOffset + i];
-            const BlockQ40 wi = weight[wOffset + i];
+        for (uint i = iStart; i < iEnd; i++) {
+            const uint xi = inputOffset + i;
+            const uint wi = d * inputSizeX + i;
+            [[unroll]] for (uint j = 0; j < Q80_Q40_BLOCK_SIZE / 2; j++) {
+                sum += (
+                    float16_t(x[xi].qs[j]) * (float16_t(weight[wi].qs[j] & 0xF) - float16_t(8.0f)) +
+                    float16_t(x[xi].qs[j + Q80_Q40_BLOCK_SIZE / 2]) * (float16_t(weight[wi].qs[j] >> 4) - float16_t(8.0f))
+                ) * x[xi].d * weight[wi].d;
+            }
+        }
+        sums[threadIndex] = sum;
 
-            float16_t s = float16_t(0.0f);
-            [[unroll]] for (uint j = 0; j < Q40_BLOCK_SIZE / 2; j++) {
-                const float16_t x0 = float16_t(xi.qs[j]);
-                const float16_t x1 = float16_t(xi.qs[j + Q80_BLOCK_SIZE / 2]);
+        barrier();
+        memoryBarrierShared();
 
-                const uint8_t wq = wi.qs[j];
-                const float16_t w0 = float16_t(wq & 0xF) - float16_t(8.0f);
-                const float16_t w1 = float16_t(wq >> 4) - float16_t(8.0f);
-                s += x0 * w0 + x1 * w1;
-            }
-            sum += s * xi.d * wi.d;
+        [[unroll]] for (uint i = N_THREADS_PER_GROUP / 2; i > 0; i >>= 1) {
+            if (iGroup < i)
+                sums[threadIndex] += sums[threadIndex + i];
+            barrier();
+        }
+        if (iGroup == 0) {
+            y[outputOffset + d] = float(sums[threadIndex]);
        }
 
-        y[outputOffset + d] = float(sum);
+        barrier();
    }
 }
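After the inner loop, each of a group's 4 lanes holds a partial sum for row d in sums[], and the [[unroll]] loop folds them pairwise so that lane 0 (iGroup == 0) can write the final value; the trailing barrier() keeps one iteration's reads of sums[] from racing with the next iteration's writes. A scalar C sketch of that tree reduction, assuming N_THREADS_PER_GROUP = 4 (group_reduce is an illustrative name; on the GPU, barrier() separates the halving steps):

#define N_THREADS_PER_GROUP 4

/* Folds the upper half of the array into the lower half, halving the
   active range each step; sums[0] ends up holding the group total. */
static float group_reduce(float sums[N_THREADS_PER_GROUP]) {
    for (int stride = N_THREADS_PER_GROUP / 2; stride > 0; stride >>= 1) {
        for (int lane = 0; lane < stride; lane++) /* lanes < stride are active */
            sums[lane] += sums[lane + stride];
        /* barrier() here in the shader */
    }
    return sums[0];
}

With 4 lanes this takes two steps (stride 2, then stride 1), so a row's total costs two synchronizations instead of one thread walking the entire input on its own.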
