@@ -182,12 +182,12 @@ void testMul_F32_F32() {
182
182
}
183
183
184
184
void testMergeAdd_F32_F32 () {
185
- #define MERGE_ADD_NODES 2
186
- #define MERGE_ADD_DIM 64
185
+ #define MERGE_ADD_F32_NODES 2
186
+ #define MERGE_ADD_F32_DIM 64
187
187
execute (
188
188
[](NnNetConfigBuilder *netBuilder, NnNodeConfigBuilder *nodeBuilder, NnSegmentConfigBuilder *segmentBuilder) {
189
- NnUint zPipeIndex = netBuilder->addPipe (" Z" , size2D (F_32, N_BATCHES, MERGE_ADD_DIM * MERGE_ADD_NODES ));
190
- NnUint xPipeIndex = netBuilder->addPipe (" X" , size2D (F_32, N_BATCHES, MERGE_ADD_DIM ));
189
+ NnUint zPipeIndex = netBuilder->addPipe (" Z" , size2D (F_32, N_BATCHES, MERGE_ADD_F32_DIM * MERGE_ADD_F32_NODES ));
190
+ NnUint xPipeIndex = netBuilder->addPipe (" X" , size2D (F_32, N_BATCHES, MERGE_ADD_F32_DIM ));
191
191
segmentBuilder->addOp (OP_MERGE_ADD, " mergeAdd" , 0 ,
192
192
pointerBatchConfig (SRC_PIPE, zPipeIndex),
193
193
pointerBatchConfig (SRC_PIPE, xPipeIndex),
@@ -201,9 +201,9 @@ void testMergeAdd_F32_F32() {
201
201
float *zPipe = (float *)execution->pipes [0 ];
202
202
float *xPipe = (float *)execution->pipes [1 ];
203
203
for (NnUint b = 0 ; b < N_BATCHES; b++) {
204
- for (NnUint n = 0 ; n < MERGE_ADD_NODES ; n++) {
205
- for (NnUint i = 0 ; i < MERGE_ADD_DIM ; i++)
206
- zPipe[b * MERGE_ADD_NODES * MERGE_ADD_DIM + n * MERGE_ADD_DIM + i] = (float )(b + 1 );
204
+ for (NnUint n = 0 ; n < MERGE_ADD_F32_NODES ; n++) {
205
+ for (NnUint i = 0 ; i < MERGE_ADD_F32_DIM ; i++)
206
+ zPipe[b * MERGE_ADD_F32_NODES * MERGE_ADD_F32_DIM + n * MERGE_ADD_F32_DIM + i] = (float )(b + 1 );
207
207
}
208
208
}
209
209
@@ -212,15 +212,58 @@ void testMergeAdd_F32_F32() {
212
212
213
213
// assert
214
214
for (NnUint b = 0 ; b < N_BATCHES; b++) {
215
- for (NnUint i = 0 ; i < MERGE_ADD_DIM ; i++) {
216
- NnUint j = b * MERGE_ADD_DIM + i;
215
+ for (NnUint i = 0 ; i < MERGE_ADD_F32_DIM ; i++) {
216
+ NnUint j = b * MERGE_ADD_F32_DIM + i;
217
217
assertFloat (j, xPipe[j], (float )(2 * b + 2 ), 0 .00001f );
218
218
}
219
219
}
220
220
printOk (" testMergeAdd_F32_F32" );
221
221
});
222
222
}
223
223
224
+ static void testMergeAdd_Q80_F32 () { // Test: runs OP_MERGE_ADD with a Q80 "Z" pipe and checks the F32 "X" pipe afterwards.
225
+ #define MERGE_ADD_Q80_NODES 2 // Number of DIM-wide slices stored back to back in each batch row of Z.
226
+ #define MERGE_ADD_Q80_DIM 64 // Width of one slice of Z and of one batch row of X.
227
+ execute (
228
+ [](NnNetConfigBuilder *netBuilder, NnNodeConfigBuilder *nodeBuilder, NnSegmentConfigBuilder *segmentBuilder) { // First callback: declares the pipes and the single op under test.
229
+ const NnUint zPipeIndex = netBuilder->addPipe (" Z" , size2D (F_Q80, N_BATCHES, MERGE_ADD_Q80_DIM * MERGE_ADD_Q80_NODES)); // Z pipe: F_Q80, sized by N_BATCHES and DIM * NODES.
230
+ const NnUint xPipeIndex = netBuilder->addPipe (" X" , size2D (F_32, N_BATCHES, MERGE_ADD_Q80_DIM)); // X pipe: F_32, sized by N_BATCHES and DIM.
231
+ segmentBuilder->addOp (OP_MERGE_ADD, " mergeAdd" , 0 ,
232
+ pointerBatchConfig (SRC_PIPE, zPipeIndex), // Refers to the Z pipe; presumably the op input -- confirm against addOp's parameter order.
233
+ pointerBatchConfig (SRC_PIPE, xPipeIndex), // Refers to the X pipe; presumably the op output -- confirm against addOp's parameter order.
234
+ size0 (),
235
+ NnMergeAddOpCodeConfig{});
236
+ },
237
+ [](NnExecutor *executor, NnNetExecution *execution, NnVulkanDevice *device) { // Second callback: fills Z, runs the executor, verifies X.
238
+ // arrange: build F32 reference data, then quantize it into the Z pipe
239
+ execution->setBatchSize (N_BATCHES);
240
+
241
+ float z[N_BATCHES * MERGE_ADD_Q80_DIM * MERGE_ADD_Q80_NODES]; // F32 staging buffer; quantized into the Z pipe below.
242
+ for (NnUint b = 0 ; b < N_BATCHES; b++) {
243
+ for (NnUint n = 0 ; n < MERGE_ADD_Q80_NODES; n++) {
244
+ for (NnUint i = 0 ; i < MERGE_ADD_Q80_DIM; i++)
245
+ z[b * MERGE_ADD_Q80_NODES * MERGE_ADD_Q80_DIM + n * MERGE_ADD_Q80_DIM + i] = (float )(b + 1 ); // Every element of batch b, in every slice n, holds b + 1.
246
+ }
247
+ }
248
+
249
+ NnBlockQ80 *zPipe = (NnBlockQ80 *)execution->pipes [0 ]; // Assumes pipe 0 is Z because it was added first -- TODO confirm.
250
+ const float *xPipe = (float *)execution->pipes [1 ]; // Assumes pipe 1 is X (added second); only read in this test.
251
+ quantizeF32toQ80 (z, zPipe, N_BATCHES * MERGE_ADD_Q80_DIM * MERGE_ADD_Q80_NODES, 1 , 0 ); // Quantizes the whole staging buffer into zPipe; the meaning of the trailing 1, 0 arguments is not visible here.
252
+
253
+ // act: run the configured op
254
+ executor->forward ();
255
+
256
+ // assert: every X element of batch b must equal 2 * b + 2
257
+ for (NnUint b = 0 ; b < N_BATCHES; b++) {
258
+ for (NnUint i = 0 ; i < MERGE_ADD_Q80_DIM; i++) {
259
+ NnUint j = b * MERGE_ADD_Q80_DIM + i; // Flat index of element i in batch row b of X.
260
+ assertFloat (j, xPipe[j], (float )(2 * b + 2 ), 0 .00001f ); // 2 * (b + 1) matches the sum of the two Z slices. NOTE(review): X is never initialized here, so this presumes X starts at zero or is overwritten; also confirm the tight tolerance survives the Q80 round-trip.
261
+ }
262
+ }
263
+ printOk (" testMergeAdd_Q80_F32" );
264
+ });
265
+ }
266
+
224
267
void testEmbedding_F32_F32 () {
225
268
#define EMBEDDING_DIM 16
226
269
#define EMBEDDING_LEN 8
@@ -528,6 +571,7 @@ int main() {
528
571
testSilu_F32_F32 ();
529
572
testMul_F32_F32 ();
530
573
testMergeAdd_F32_F32 ();
574
+ testMergeAdd_Q80_F32 ();
531
575
testEmbedding_F32_F32 ();
532
576
testShift_F32_F32 ();
533
577
testCast_F32_F32 ();
0 commit comments