Commit 8909be9

feat: optimize vulkan tiled quantized matmul. (#200)
1 parent 50dfb13 commit 8909be9

File tree

3 files changed: +87 -54 lines changed

src/nn/nn-vulkan-test.cpp

Lines changed: 5 additions & 5 deletions
@@ -529,8 +529,8 @@ void testMatmul_F32_F32_F32() {
 }
 
 void testMatmul_Q80_Q40_F32() {
-    #define MATMUL_Q80_Q40_N 512
-    #define MATMUL_Q80_Q40_D 512
+    #define MATMUL_Q80_Q40_N 4096
+    #define MATMUL_Q80_Q40_D 4096
     execute(
         [](NnNetConfigBuilder *netBuilder, NnNodeConfigBuilder *nodeBuilder, NnSegmentConfigBuilder *segmentBuilder) {
             NnUint xPipeIndex = netBuilder->addPipe("X", size2D(F_Q80, N_BATCHES, MATMUL_Q80_Q40_N));
@@ -557,9 +557,9 @@ void testMatmul_Q80_Q40_F32() {
     std::unique_ptr<NnBlockQ40[]> weightQ40(new NnBlockQ40[weightBlocks]);
 
     for (NnUint i = 0; i < xSize; i++)
-        x[i] = i * 0.001f;
+        x[i] = i * 0.00001f;
     for (NnUint i = 0; i < weightSize; i++)
-        weight[i] = i * 0.0001f;
+        weight[i] = i * 0.000001f;
 
     quantizeF32toQ80(x.get(), xPipe, xSize, 1, 0);
     quantizeF32toQ40(weight.get(), weightQ40.get(), weightSize, 1, 0);
@@ -576,7 +576,7 @@ void testMatmul_Q80_Q40_F32() {
             for (NnUint n = 0; n < MATMUL_Q80_Q40_N; n++)
                 sum += x[b * MATMUL_Q80_Q40_N + n] * weight[d * MATMUL_Q80_Q40_N + n];
             const NnUint p = b * MATMUL_Q80_Q40_D + d;
-            const float tolerance = sum * 0.025f;
+            const float tolerance = sum * 0.035f;
             assertFloat(p, yPipe[p], sum, tolerance);
         }
     }
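The test now drives a 4096x4096 matmul instead of 512x512. The input magnitudes are scaled down by 100x (presumably to keep the shader's float16 partial sums well inside fp16 range at the larger N), and the relative tolerance is widened from 2.5% to 3.5% to absorb the extra fp16 accumulation error over 4096 terms. For orientation, a minimal CPU sketch of the Q80/Q40 block formats the test quantizes into; the field layout is an assumption inferred from the shader's indexing (llama.cpp-style Q8_0/Q4_0), not copied from the project's headers:

#include <cstdint>

// Assumed block layouts: 32 elements per block, fp16 scale `d`.
struct BlockQ80 {
    uint16_t d;     // fp16 scale bits
    int8_t qs[32];  // one signed 8-bit value per element
};

struct BlockQ40 {
    uint16_t d;     // fp16 scale bits
    uint8_t qs[16]; // two 4-bit elements per byte
};

// Dequantize element k (0..31) of a Q40 block: byte j holds element j in its
// low nibble and element j + 16 in its high nibble, both biased by 8.
// `scale` is the float value of b.d (fp16 decoding omitted for brevity).
float dequantizeQ40(const BlockQ40 &b, int k, float scale) {
    const int q = (k < 16) ? (b.qs[k] & 0xF) : (b.qs[k - 16] >> 4);
    return float(q - 8) * scale;
}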

src/nn/nn-vulkan.cpp

Lines changed: 17 additions & 2 deletions
@@ -499,9 +499,17 @@ static void resolveShaderGroups(const NnOpConfig *opConfig, const NnUint batchSi
         opConfig->code == OP_MUL ||
         opConfig->code == OP_SILU ||
         opConfig->code == OP_SHIFT ||
-        opConfig->code == OP_MERGE_ADD ||
-        opConfig->code == OP_MATMUL)
+        opConfig->code == OP_MERGE_ADD)
         groupCount[2] = 32;
+    else if (opConfig->code == OP_MATMUL) {
+        if (opConfig->weightSize.floatType == F_Q40) {
+            constexpr NnUint tileSizeD = 16; // Must be synced with the shader
+            assert(opConfig->weightSize.x % tileSizeD == 0);
+            groupCount[2] = opConfig->weightSize.x / tileSizeD;
+        } else {
+            groupCount[2] = 32;
+        }
+    }
     else if (opConfig->code == OP_MULTIHEAD_ATT)
         groupCount[2] = ((NnMultiHeadAttOpConfig *)opConfig->config)->nHeads;
 }
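For the Q40 path this replaces the fixed 32-group dispatch with one work group per 16-row output tile; with the test's D = 4096 that resolves to 4096 / 16 = 256 groups on the z axis. A standalone sketch of the same rule (illustrative only; NnUint assumed to be a 32-bit unsigned):

#include <cassert>
#include <cstdint>

using NnUint = uint32_t; // assumption: matches the project's typedef

// Mirror of the Q40 branch above: one work group per TILE_SIZE_D output rows.
NnUint q40MatmulGroupCountZ(NnUint outputRows) {
    constexpr NnUint tileSizeD = 16;     // must stay in sync with TILE_SIZE_D in the shader
    assert(outputRows % tileSizeD == 0); // the op requires a tile-aligned output dimension
    return outputRows / tileSizeD;       // e.g. 4096 / 16 = 256
}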
@@ -599,6 +607,12 @@ NnVulkanDeviceSegment::NnVulkanDeviceSegment(NnVulkanContext *context, NnVulkanD
     std::vector<vk::PipelineShaderStageCreateInfo> shaderCreateInfos(segmentConfig->nOps);
     std::vector<std::vector<NnVulkanBuffer *>> opBuffers(segmentConfig->nOps);
 
+    constexpr NnUint maxConsts = 3;
+    std::vector<NnUint> nConsts(segmentConfig->nOps);
+    std::vector<int> consts(segmentConfig->nOps * maxConsts);
+    std::vector<vk::SpecializationInfo> specInfos(segmentConfig->nOps);
+    std::vector<vk::SpecializationMapEntry> specMapEntries(segmentConfig->nOps * maxConsts);
+
     for (NnUint opIndex = 0; opIndex < segmentConfig->nOps; opIndex++) {
         NnOpConfig *opConfig = &segmentConfig->ops[opIndex];
         NnSize2D inputSize = data->resolveBufferSize(&opConfig->input);
@@ -620,6 +634,7 @@ NnVulkanDeviceSegment::NnVulkanDeviceSegment(NnVulkanContext *context, NnVulkanD
             code.size(),
             code.data()
         );
+
         vk::ShaderModule shaderModule = context->device.createShaderModule(shaderModuleCreateInfo);
         vk::PipelineShaderStageCreateInfo shaderCreateInfo(
             vk::PipelineShaderStageCreateFlags(),
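The remaining two hunks only add scaffolding: per-op storage for up to three int specialization constants (nConsts, consts, specInfos, specMapEntries) plus a whitespace tweak; the code that fills and consumes these vectors is outside this diff. For readers unfamiliar with the mechanism, a generic Vulkan-Hpp sketch of wiring one int specialization constant into a compute stage (not this project's code; names are hypothetical):

#include <vulkan/vulkan.hpp>

// Bind one int constant to layout(constant_id = 0) in the shader, then build
// a compute pipeline from it. `device`, `shaderModule` and `pipelineLayout`
// are assumed to already exist.
vk::Pipeline buildPipeline(vk::Device device, vk::ShaderModule shaderModule,
                           vk::PipelineLayout pipelineLayout, const int &tileSize) {
    vk::SpecializationMapEntry entry(0 /*constantID*/, 0 /*offset*/, sizeof(int));
    vk::SpecializationInfo specInfo(1, &entry, sizeof(int), &tileSize);
    vk::PipelineShaderStageCreateInfo stage(
        vk::PipelineShaderStageCreateFlags(),
        vk::ShaderStageFlagBits::eCompute,
        shaderModule, "main", &specInfo);
    vk::ComputePipelineCreateInfo createInfo(vk::PipelineCreateFlags(), stage, pipelineLayout);
    return device.createComputePipeline(vk::PipelineCache(), createInfo).value;
}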

src/nn/vulkan/matmul-forward-q80-q40-f32.comp

Lines changed: 65 additions & 47 deletions
@@ -4,11 +4,11 @@
 #extension GL_EXT_shader_16bit_storage : enable
 #extension GL_EXT_shader_explicit_arithmetic_types : enable
 
-#define Q80_Q40_BLOCK_SIZE 32
-#define N_THREADS 256
+#define N_THREADS 64
+#define TILE_SIZE_X 2
+#define TILE_SIZE_D 16
 
-#define N_OUTPUTS_PER_ITER 64
-#define N_THREADS_PER_OUTPUT (N_THREADS / N_OUTPUTS_PER_ITER)
+#define Q80_Q40_BLOCK_SIZE 32
 
 layout(local_size_x = N_THREADS, local_size_y = 1, local_size_z = 1) in;
 
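The reworked constants trade a wide work group for a tiled one: 64 threads (down from 256) now cooperate on a tile of TILE_SIZE_D = 16 output rows, each thread walking its share of input blocks in pairs (TILE_SIZE_X = 2). The shared sums array in the next hunk grows accordingly to N_THREADS * TILE_SIZE_D = 64 x 16 float16 partials, i.e. 2 KiB of shared memory per work group.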

@@ -34,80 +34,98 @@ layout(binding = 1) writeonly buffer outputBuffer { float y[]; };
 layout(binding = 2) readonly buffer batchInfosBuffer { BatchInfo infos[]; };
 layout(binding = 3) readonly buffer weightBuffer { BlockQ40 weight[]; };
 
-shared uint sharedStart;
-shared uint sharedEnd;
+shared uint sharedXSlice;
+shared uint sharedXRest;
 shared uint sharedInputOffset;
 shared uint sharedInputSizeX;
 shared uint sharedOutputOffset;
-shared uint sharedInputSizeXPerGroup;
-shared float16_t sums[N_THREADS];
+shared uint sharedD;
+shared float16_t sums[N_THREADS * TILE_SIZE_D];
 
 void main() {
     const uint threadIndex = gl_LocalInvocationID.x;
 
     if (threadIndex == 0) {
-        const uint nWorkGroups = gl_NumWorkGroups.z;
         const uint batchIndex = gl_WorkGroupID.y;
         const uint workGroupIndex = gl_WorkGroupID.z;
 
         const BatchInfo info = infos[batchIndex];
+
+        const uint xTiles = info.inputSizeX / TILE_SIZE_X;
+        sharedXSlice = xTiles / N_THREADS;
+        sharedXRest = xTiles % N_THREADS;
+
         sharedInputOffset = info.inputOffset;
         sharedInputSizeX = info.inputSizeX;
         sharedOutputOffset = info.outputOffset;
-        sharedInputSizeXPerGroup = (sharedInputSizeX + N_THREADS_PER_OUTPUT - 1) / N_THREADS_PER_OUTPUT;
-
-        const uint ySlice = info.outputSizeX / nWorkGroups;
-        const uint yRest = info.outputSizeX % nWorkGroups;
-        sharedStart = workGroupIndex * ySlice + (workGroupIndex < yRest ? workGroupIndex : yRest);
-        sharedEnd = sharedStart + ySlice + (workGroupIndex < yRest ? 1 : 0);
+        sharedD = TILE_SIZE_D * workGroupIndex;
     }
 
     barrier();
     memoryBarrierShared();
 
-    const uint dEnd = sharedEnd;
+    const uint xSlice = sharedXSlice;
+    const uint xRest = sharedXRest;
+    const uint xStart = (threadIndex * xSlice + min(threadIndex, xRest)) * TILE_SIZE_X;
+    const uint xEnd = xStart + (xSlice + (threadIndex < xRest ? 1 : 0)) * TILE_SIZE_X;
+
     const uint inputOffset = sharedInputOffset;
     const uint inputSizeX = sharedInputSizeX;
     const uint outputOffset = sharedOutputOffset;
-    const uint inputSizeXPerGroup = sharedInputSizeXPerGroup;
+    const uint d = sharedD;
 
-    const uint dGroup = threadIndex / N_THREADS_PER_OUTPUT;
-    const uint iGroup = threadIndex % N_THREADS_PER_OUTPUT;
-    const uint iStart = inputSizeXPerGroup * iGroup;
-    const uint iEnd = min(iStart + inputSizeXPerGroup, inputSizeX);
+    f16vec4 xTemp[Q80_Q40_BLOCK_SIZE / 4];
 
-    for (uint dBatch = sharedStart; dBatch < dEnd; dBatch += N_OUTPUTS_PER_ITER) {
-        const uint d = dBatch + dGroup;
-        if (d >= dEnd) {
-            break;
-        }
+    for (uint dt = 0; dt < TILE_SIZE_D; dt++) {
+        sums[threadIndex * TILE_SIZE_D + dt] = float16_t(0.0f);
+    }
+
+    for (uint i = xStart; i < xEnd; i += TILE_SIZE_X) {
+        [[unroll]] for (uint it = 0; it < TILE_SIZE_X; it++) {
+            const uint xi = inputOffset + i + it;
+            const float16_t xScale = x[xi].d;
+            [[unroll]] for (uint j = 0; j < Q80_Q40_BLOCK_SIZE / 4; j++) {
+                xTemp[j] = f16vec4(
+                    x[xi].qs[j * 2],
+                    x[xi].qs[j * 2 + Q80_Q40_BLOCK_SIZE / 2],
+                    x[xi].qs[j * 2 + 1],
+                    x[xi].qs[j * 2 + 1 + Q80_Q40_BLOCK_SIZE / 2]
+                );
+            }
 
-        float16_t sum = float16_t(0.0f);
-        for (uint i = iStart; i < iEnd; i++) {
-            const uint xi = inputOffset + i;
-            const uint wi = d * inputSizeX + i;
-            const float16_t scale = x[xi].d * weight[wi].d;
-            [[unroll]] for (uint j = 0; j < Q80_Q40_BLOCK_SIZE / 2; j++) {
-                sum += (
-                    float16_t(x[xi].qs[j]) * (float16_t(weight[wi].qs[j] & 0xF) - float16_t(8.0f)) +
-                    float16_t(x[xi].qs[j + Q80_Q40_BLOCK_SIZE / 2]) * (float16_t(weight[wi].qs[j] >> 4) - float16_t(8.0f))
-                ) * scale;
+            [[unroll]] for (uint dt = 0; dt < TILE_SIZE_D; dt++) {
+                const uint wi = (d + dt) * inputSizeX + (i + it);
+                const BlockQ40 wBlock = weight[wi];
+
+                float16_t s = float16_t(0);
+                [[unroll]] for (uint j = 0; j < Q80_Q40_BLOCK_SIZE / 4; j++) {
+                    uint w0 = wBlock.qs[j * 2];
+                    uint w1 = wBlock.qs[j * 2 + 1];
+                    ivec4 w = ivec4(
+                        w0 & 0xFu,
+                        w0 >> 4,
+                        w1 & 0xFu,
+                        w1 >> 4
+                    ) - ivec4(8);
+                    s += dot(xTemp[j], f16vec4(w));
+                }
+                sums[threadIndex * TILE_SIZE_D + dt] += s * xScale * wBlock.d;
             }
         }
-        sums[threadIndex] = sum;
+    }
 
-        barrier();
-        memoryBarrierShared();
+    barrier();
+    memoryBarrierShared();
 
-        [[unroll]] for (uint i = N_THREADS_PER_OUTPUT / 2; i > 0; i >>= 1) {
-            if (iGroup < i)
-                sums[threadIndex] += sums[threadIndex + i];
-            barrier();
-        }
-        if (iGroup == 0) {
-            y[outputOffset + d] = float(sums[threadIndex]);
+    [[unroll]] for (uint i = N_THREADS / 2; i > 0; i >>= 1) {
+        for (uint dt = 0; dt < TILE_SIZE_D; dt++) {
+            if (threadIndex < i) {
+                sums[threadIndex * TILE_SIZE_D + dt] += sums[(threadIndex + i) * TILE_SIZE_D + dt];
+            }
         }
-
         barrier();
     }
+    for (uint dt = threadIndex; dt < TILE_SIZE_D; dt += N_THREADS) {
+        y[outputOffset + d + dt] = float(sums[dt]);
+    }
 }
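Summary of the new kernel: each work group owns a 16-row output tile (sharedD), each of its 64 threads accumulates partials for all 16 rows over its slice of input blocks, a tree reduction folds the per-thread partials in shared memory, and the first TILE_SIZE_D slots (thread 0's tile) are written out. The per-block arithmetic is unchanged from the old shader, just regrouped into f16vec4 dot products; for reference, a scalar C++ rendering of one (activation block, weight block) step, with a hypothetical helper name and fp16 scales passed as float:

#include <cstdint>

// One Q80 activation block times one Q40 weight block (32 elements).
// Weight byte j packs element j in its low nibble and element j + 16 in its
// high nibble, biased by 8 -- the same pairing the shader's f16vec4s encode.
float dotQ80Q40Block(const int8_t xQs[32], float xScale,
                     const uint8_t wQs[16], float wScale) {
    float sum = 0.0f;
    for (int j = 0; j < 16; j++) {
        sum += float(xQs[j])      * float(int(wQs[j] & 0xF) - 8)
             + float(xQs[j + 16]) * float(int(wQs[j] >> 4) - 8);
    }
    return sum * xScale * wScale;
}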
