@@ -72,8 +72,10 @@ void testRmsNorm_F32_F32_F32() {
72
72
float *xPipe = (float *)execution->pipes [0 ];
73
73
for (NnUint b = 0 ; b < batchSize; b++) {
74
74
float *xBatchPipe = &xPipe[b * RMS_NORM_DIM];
75
- for (NnUint i = 0 ; i < RMS_NORM_DIM; i++)
76
- xBatchPipe[i] = (float )(RMS_NORM_DIM - i) / (float )(RMS_NORM_DIM / 2 );
75
+ for (NnUint i = 0 ; i < RMS_NORM_DIM; i++) {
76
+ float u = (float )(RMS_NORM_DIM - i + b) / (float )(RMS_NORM_DIM / 2 );
77
+ xBatchPipe[i] = u;
78
+ }
77
79
}
78
80
79
81
// act
@@ -83,22 +85,20 @@ void testRmsNorm_F32_F32_F32() {
83
85
float invRmsBuffer[N_BATCHES];
84
86
device->data ->buffers [0 ].get ()->read ((NnByte *)invRmsBuffer);
85
87
88
+ float expectedS[N_BATCHES];
89
+ expectedS[0 ] = 0 .863493f ;
90
+ expectedS[1 ] = 0 .858468f ;
91
+
86
92
for (NnUint b = 0 ; b < batchSize; b++) {
87
93
float *xBatchPipe = &xPipe[b * RMS_NORM_DIM];
88
94
89
- float t = 0 .000001f ;
90
- assertFloat (b, invRmsBuffer[b], 0 .863493f , t);
91
- assertFloat (0 , xBatchPipe[0 ], 0 .001687f , t);
92
- assertFloat (1 , xBatchPipe[1 ], 0 .008400f , t);
93
- assertFloat (2 , xBatchPipe[2 ], 0 .015060f , t);
94
- assertFloat (35 , xBatchPipe[35 ], 0 .205286f , t);
95
- assertFloat (36 , xBatchPipe[36 ], 0 .210155f , t);
96
- assertFloat (119 , xBatchPipe[119 ], 0 .430514f , t);
97
- assertFloat (123 , xBatchPipe[123 ], 0 .431964f , t);
98
- assertFloat (234 , xBatchPipe[234 ], 0 .135804f , t);
99
- assertFloat (242 , xBatchPipe[242 ], 0 .089372f , t);
100
- assertFloat (249 , xBatchPipe[249 ], 0 .045977f , t);
101
- assertFloat (255 , xBatchPipe[255 ], 0 .006726f , t);
95
+ const float t = 0 .000001f ;
96
+ const float s = expectedS[b];
97
+ assertFloat (b, invRmsBuffer[b], s, t);
98
+ for (NnUint i = 0 ; i < RMS_NORM_DIM; i++) {
99
+ float u = (float )(RMS_NORM_DIM - i + b) / (float )(RMS_NORM_DIM / 2 );
100
+ assertFloat (b * RMS_NORM_DIM + i, xBatchPipe[i], (u * s) * normWeight[i], t);
101
+ }
102
102
}
103
103
printOk (" testRmsNorm_F32_F32_F32" );
104
104
});
@@ -165,7 +165,7 @@ void testMul_F32_F32() {
165
165
float sBuffer [MUL_DIM * N_BATCHES];
166
166
for (NnUint i = 0 ; i < MUL_DIM * N_BATCHES; i++) {
167
167
xPipe[i] = (float )i;
168
- sBuffer [i] = cosf (( float )i) ;
168
+ sBuffer [i] = (i % 8 ) / 10 . 0f ;
169
169
}
170
170
171
171
device->data ->buffers [0 ].get ()->write ((NnByte *)sBuffer );
@@ -175,7 +175,7 @@ void testMul_F32_F32() {
175
175
176
176
// assert
177
177
for (NnUint i = 0 ; i < MUL_DIM * N_BATCHES; i++)
178
- assertFloat (i, xPipe[i], i * cosf (( float )i) , 0 .00001f );
178
+ assertFloat (i, xPipe[i], i * ((i % 8 ) / 10 . 0f ) , 0 .000001f );
179
179
printOk (" testMul_F32_F32" );
180
180
});
181
181
}
0 commit comments