Skip to content

Commit bee26bb

Browse files
committed
testMergeAdd_Q80_F32.
1 parent da41445 commit bee26bb

6 files changed

+146
-13
lines changed

src/nn/nn-vulkan-test.cpp

+53-9
Original file line numberDiff line numberDiff line change
@@ -182,12 +182,12 @@ void testMul_F32_F32() {
182182
}
183183

184184
void testMergeAdd_F32_F32() {
185-
#define MERGE_ADD_NODES 2
186-
#define MERGE_ADD_DIM 64
185+
#define MERGE_ADD_F32_NODES 2
186+
#define MERGE_ADD_F32_DIM 64
187187
execute(
188188
[](NnNetConfigBuilder *netBuilder, NnNodeConfigBuilder *nodeBuilder, NnSegmentConfigBuilder *segmentBuilder) {
189-
NnUint zPipeIndex = netBuilder->addPipe("Z", size2D(F_32, N_BATCHES, MERGE_ADD_DIM * MERGE_ADD_NODES));
190-
NnUint xPipeIndex = netBuilder->addPipe("X", size2D(F_32, N_BATCHES, MERGE_ADD_DIM));
189+
NnUint zPipeIndex = netBuilder->addPipe("Z", size2D(F_32, N_BATCHES, MERGE_ADD_F32_DIM * MERGE_ADD_F32_NODES));
190+
NnUint xPipeIndex = netBuilder->addPipe("X", size2D(F_32, N_BATCHES, MERGE_ADD_F32_DIM));
191191
segmentBuilder->addOp(OP_MERGE_ADD, "mergeAdd", 0,
192192
pointerBatchConfig(SRC_PIPE, zPipeIndex),
193193
pointerBatchConfig(SRC_PIPE, xPipeIndex),
@@ -201,9 +201,9 @@ void testMergeAdd_F32_F32() {
201201
float *zPipe = (float *)execution->pipes[0];
202202
float *xPipe = (float *)execution->pipes[1];
203203
for (NnUint b = 0; b < N_BATCHES; b++) {
204-
for (NnUint n = 0; n < MERGE_ADD_NODES; n++) {
205-
for (NnUint i = 0; i < MERGE_ADD_DIM; i++)
206-
zPipe[b * MERGE_ADD_NODES * MERGE_ADD_DIM + n * MERGE_ADD_DIM + i] = (float)(b + 1);
204+
for (NnUint n = 0; n < MERGE_ADD_F32_NODES; n++) {
205+
for (NnUint i = 0; i < MERGE_ADD_F32_DIM; i++)
206+
zPipe[b * MERGE_ADD_F32_NODES * MERGE_ADD_F32_DIM + n * MERGE_ADD_F32_DIM + i] = (float)(b + 1);
207207
}
208208
}
209209

@@ -212,15 +212,58 @@ void testMergeAdd_F32_F32() {
212212

213213
// assert
214214
for (NnUint b = 0; b < N_BATCHES; b++) {
215-
for (NnUint i = 0; i < MERGE_ADD_DIM; i++) {
216-
NnUint j = b * MERGE_ADD_DIM + i;
215+
for (NnUint i = 0; i < MERGE_ADD_F32_DIM; i++) {
216+
NnUint j = b * MERGE_ADD_F32_DIM + i;
217217
assertFloat(j, xPipe[j], (float)(2 * b + 2), 0.00001f);
218218
}
219219
}
220220
printOk("testMergeAdd_F32_F32");
221221
});
222222
}
223223

224+
static void testMergeAdd_Q80_F32() {
225+
#define MERGE_ADD_Q80_NODES 2
226+
#define MERGE_ADD_Q80_DIM 64
227+
execute(
228+
[](NnNetConfigBuilder *netBuilder, NnNodeConfigBuilder *nodeBuilder, NnSegmentConfigBuilder *segmentBuilder) {
229+
const NnUint zPipeIndex = netBuilder->addPipe("Z", size2D(F_Q80, N_BATCHES, MERGE_ADD_Q80_DIM * MERGE_ADD_Q80_NODES));
230+
const NnUint xPipeIndex = netBuilder->addPipe("X", size2D(F_32, N_BATCHES, MERGE_ADD_Q80_DIM));
231+
segmentBuilder->addOp(OP_MERGE_ADD, "mergeAdd", 0,
232+
pointerBatchConfig(SRC_PIPE, zPipeIndex),
233+
pointerBatchConfig(SRC_PIPE, xPipeIndex),
234+
size0(),
235+
NnMergeAddOpCodeConfig{});
236+
},
237+
[](NnExecutor *executor, NnNetExecution *execution, NnVulkanDevice *device) {
238+
// arrange
239+
execution->setBatchSize(N_BATCHES);
240+
241+
float z[N_BATCHES * MERGE_ADD_Q80_DIM * MERGE_ADD_Q80_NODES];
242+
for (NnUint b = 0; b < N_BATCHES; b++) {
243+
for (NnUint n = 0; n < MERGE_ADD_Q80_NODES; n++) {
244+
for (NnUint i = 0; i < MERGE_ADD_Q80_DIM; i++)
245+
z[b * MERGE_ADD_Q80_NODES * MERGE_ADD_Q80_DIM + n * MERGE_ADD_Q80_DIM + i] = (float)(b + 1);
246+
}
247+
}
248+
249+
NnBlockQ80 *zPipe = (NnBlockQ80 *)execution->pipes[0];
250+
const float *xPipe = (float *)execution->pipes[1];
251+
quantizeF32toQ80(z, zPipe, N_BATCHES * MERGE_ADD_Q80_DIM * MERGE_ADD_Q80_NODES, 1, 0);
252+
253+
// act
254+
executor->forward();
255+
256+
// assert
257+
for (NnUint b = 0; b < N_BATCHES; b++) {
258+
for (NnUint i = 0; i < MERGE_ADD_Q80_DIM; i++) {
259+
NnUint j = b * MERGE_ADD_Q80_DIM + i;
260+
assertFloat(j, xPipe[j], (float)(2 * b + 2), 0.00001f);
261+
}
262+
}
263+
printOk("testMergeAdd_Q80_F32");
264+
});
265+
}
266+
224267
void testEmbedding_F32_F32() {
225268
#define EMBEDDING_DIM 16
226269
#define EMBEDDING_LEN 8
@@ -528,6 +571,7 @@ int main() {
528571
testSilu_F32_F32();
529572
testMul_F32_F32();
530573
testMergeAdd_F32_F32();
574+
testMergeAdd_Q80_F32();
531575
testEmbedding_F32_F32();
532576
testShift_F32_F32();
533577
testCast_F32_F32();

src/nn/nn-vulkan.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,7 @@ NnDeviceSegment *NnVulkanDevice::createSegment(NnUint segmentIndex) {
375375
static const char *getShaderFileName(const NnOpCode opCode, const NnOpQuantType quantType) {
376376
if (opCode == OP_MERGE_ADD) {
377377
if (quantType == F32_F32_F32) return "merge-add-forward-f32-f32.spv";
378+
if (quantType == Q80_Q80_F32) return "merge-add-forward-q80-f32.spv";
378379
}
379380
if (opCode == OP_EMBEDDING) {
380381
if (quantType == F32_F32_F32) return "embedding-forward-f32-f32.spv";
@@ -493,6 +494,7 @@ static std::vector<uint32_t> readShader(const char *fileName) {
493494
constexpr size_t maxSize = 16384;
494495
uint32_t chunk[maxSize];
495496
size_t bytesRead = fread(chunk, 1, maxSize, file);
497+
assert(bytesRead < maxSize); // Check if the file is too large
496498
if (bytesRead > 0)
497499
code.insert(code.end(), chunk, chunk + bytesRead);
498500
if (ferror(file)) {

src/nn/vulkan/cast-forward-f32-q80.comp

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#version 450
22

3+
#extension GL_EXT_control_flow_attributes : enable
34
#extension GL_EXT_shader_16bit_storage : enable
45
#extension GL_EXT_shader_explicit_arithmetic_types : enable
56

@@ -70,7 +71,7 @@ void main() {
7071

7172
y[yiOffset].d = float16_t(d);
7273

73-
for (uint j = 0; j < Q80_BLOCK_SIZE; ++j) {
74+
[[unroll]] for (uint j = 0; j < Q80_BLOCK_SIZE; ++j) {
7475
const float v = x[xiOffset + j];
7576
y[yiOffset].qs[j] = int8_t(clamp(round(v * id), -127.0, 127.0));
7677
}

src/nn/vulkan/matmul-forward-f32-f32-f32.comp

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#version 450
22

3-
#define N_THREADS 256
3+
#define N_THREADS 128
44

55
layout(local_size_x = N_THREADS, local_size_y = 1, local_size_z = 1) in;
66

src/nn/vulkan/merge-add-forward-f32-f32.comp

-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
#version 450
22

3-
#extension GL_EXT_control_flow_attributes : enable
4-
53
#define N_THREADS 64
64

75
layout(local_size_x = N_THREADS, local_size_y = 1, local_size_z = 1) in;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
#version 450

#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : enable
#extension GL_EXT_shader_explicit_arithmetic_types : enable

#define Q80_BLOCK_SIZE 32
#define N_THREADS 64

layout(local_size_x = N_THREADS, local_size_y = 1, local_size_z = 1) in;

// Per-batch slice descriptor; offsets/sizes for the input are expressed in
// Q8.0 blocks, output values are plain floats.
struct BatchInfo {
    uint inputOffset; // number of Q80 blocks
    uint inputSizeX;  // number of Q80 blocks
    uint outputOffset;
    uint outputSizeX;
};

// One Q8.0 block: a shared fp16 scale plus 32 signed 8-bit quants.
struct BlockQ80 {
    float16_t d;
    int8_t qs[Q80_BLOCK_SIZE];
};

layout(binding = 0) readonly buffer inputBuffer { BlockQ80 x[]; };
layout(binding = 1) buffer outputBuffer { float y[]; };
layout(binding = 2) readonly buffer batchInfosBuffer { BatchInfo infos[]; };

// Work-group-wide parameters computed once by thread 0 and broadcast
// through shared memory to every invocation.
shared uint wgBlockBegin;  // first Q80 block index this work group handles
shared uint wgBlockEnd;    // one past the last block index
shared uint wgPartCount;   // how many input parts are merged per output block
shared uint wgPartStride;  // distance (in blocks) between consecutive parts
shared uint wgInOffset;    // batch base offset into the input buffer (blocks)
shared uint wgOutOffset;   // batch base offset into the output buffer (floats)

void main() {
    const uint localIndex = gl_LocalInvocationID.x;

    if (localIndex == 0) {
        const uint groupCount = gl_NumWorkGroups.z;
        const uint batchIndex = gl_WorkGroupID.y;
        const uint groupIndex = gl_WorkGroupID.z;

        const BatchInfo info = infos[batchIndex];
        // Output width in Q80 blocks == stride between parts in the input.
        const uint partStride = info.outputSizeX / Q80_BLOCK_SIZE;
        const uint partCount = info.inputSizeX / partStride;
        // Split the output blocks across work groups; the first `remainder`
        // groups take one extra block each.
        const uint perGroup = partStride / groupCount;
        const uint remainder = partStride % groupCount;

        wgBlockBegin = groupIndex * perGroup + (groupIndex < remainder ? groupIndex : remainder);
        wgBlockEnd = wgBlockBegin + perGroup + (groupIndex < remainder ? 1 : 0);
        wgPartCount = partCount;
        wgPartStride = partStride;
        wgInOffset = info.inputOffset;
        wgOutOffset = info.outputOffset;
    }

    barrier();
    memoryBarrierShared();

    const uint firstBlock = wgBlockBegin + localIndex;
    const uint lastBlock = wgBlockEnd;
    const uint partStride = wgPartStride;
    const uint partCount = wgPartCount;
    const uint inOffset = wgInOffset;
    const uint outOffset = wgOutOffset;
    float16_t acc[Q80_BLOCK_SIZE];

    // Each thread walks its blocks with an N_THREADS stride; for every block
    // it dequantizes and sums the corresponding block of each input part,
    // then accumulates the result into the F32 output.
    for (uint blockIndex = firstBlock; blockIndex < lastBlock; blockIndex += N_THREADS) {
        const uint inBase = inOffset + blockIndex;
        const uint outBase = outOffset + blockIndex * Q80_BLOCK_SIZE;

        [[unroll]] for (uint k = 0; k < Q80_BLOCK_SIZE; k++) {
            acc[k] = float16_t(0.0);
        }
        for (uint part = 0; part < partCount; part++) {
            const BlockQ80 blk = x[inBase + part * partStride];
            const float16_t scale = blk.d;

            [[unroll]] for (uint k = 0; k < Q80_BLOCK_SIZE; k++) {
                acc[k] += float16_t(blk.qs[k]) * scale;
            }
        }

        [[unroll]] for (uint k = 0; k < Q80_BLOCK_SIZE; k++) {
            y[outBase + k] += float(acc[k]);
        }
    }
}

0 commit comments

Comments
 (0)