Skip to content

Commit 9a024dc

Browse files
timber77colluca
andauthored
sw: Fix GEMM bug when N<unroll (#222)
--------- Co-authored-by: Luca Colagrande <luca.colagrande3@gmail.com>
1 parent 1d2684a commit 9a024dc

File tree

5 files changed

+141
-116
lines changed

5 files changed

+141
-116
lines changed

sw/blas/gemm/src/gemm_fp16.h

Lines changed: 51 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -116,28 +116,33 @@ void gemm_fp16_opt(uint32_t M, uint32_t N, uint32_t K, void* A_p, uint32_t ldA,
116116
// for maximum utilization
117117
const uint32_t unroll = 8;
118118

119-
// SSR strides and bounds only have to be configured
120-
// once in the beginning
121-
if (setup_SSR) {
122-
uint32_t ssr0_b[4] = {unroll, K / 4, N / unroll, M};
123-
uint32_t ssr0_i[4] = {0, sizeof(__fp16) * 4, 0, sizeof(__fp16) * ldA};
124-
125-
uint32_t ssr1_b[4] = {unroll, K / 4, N / unroll, M};
126-
uint32_t ssr1_i[4] = {sizeof(__fp16) * ldB, sizeof(__fp16) * 4,
127-
sizeof(__fp16) * unroll * ldB, 0};
128-
129-
snrt_ssr_loop_3d(SNRT_SSR_DM0, ssr0_b[1], ssr0_b[2], ssr0_b[3],
130-
ssr0_i[1], ssr0_i[2], ssr0_i[3]);
131-
snrt_ssr_repeat(SNRT_SSR_DM0, unroll);
132-
133-
snrt_ssr_loop_4d(SNRT_SSR_DM1, ssr1_b[0], ssr1_b[1], ssr1_b[2],
134-
ssr1_b[3], ssr1_i[0], ssr1_i[1], ssr1_i[2], ssr1_i[3]);
135-
}
119+
// Don't enable the SSRs if the stream won't be used
120+
if (N >= unroll) {
121+
// SSR strides and bounds only have to be configured
122+
// once in the beginning
123+
if (setup_SSR) {
124+
uint32_t ssr0_b[4] = {unroll, K / 4, N / unroll, M};
125+
uint32_t ssr0_i[4] = {0, sizeof(__fp16) * 4, 0,
126+
sizeof(__fp16) * ldA};
127+
128+
uint32_t ssr1_b[4] = {unroll, K / 4, N / unroll, M};
129+
uint32_t ssr1_i[4] = {sizeof(__fp16) * ldB, sizeof(__fp16) * 4,
130+
sizeof(__fp16) * unroll * ldB, 0};
131+
132+
snrt_ssr_loop_3d(SNRT_SSR_DM0, ssr0_b[1], ssr0_b[2], ssr0_b[3],
133+
ssr0_i[1], ssr0_i[2], ssr0_i[3]);
134+
snrt_ssr_repeat(SNRT_SSR_DM0, unroll);
135+
136+
snrt_ssr_loop_4d(SNRT_SSR_DM1, ssr1_b[0], ssr1_b[1], ssr1_b[2],
137+
ssr1_b[3], ssr1_i[0], ssr1_i[1], ssr1_i[2],
138+
ssr1_i[3]);
139+
}
136140

137-
// SSR start address need to be configured each time
138-
snrt_ssr_read(SNRT_SSR_DM0, SNRT_SSR_4D, A);
139-
snrt_ssr_read(SNRT_SSR_DM1, SNRT_SSR_4D, B);
140-
snrt_ssr_enable();
141+
// SSR start address need to be configured each time
142+
snrt_ssr_read(SNRT_SSR_DM0, SNRT_SSR_4D, A);
143+
snrt_ssr_read(SNRT_SSR_DM1, SNRT_SSR_4D, B);
144+
snrt_ssr_enable();
145+
}
141146

142147
// Kernel progresses by 4 values each step
143148
const uint32_t n_frep = K / 4 - 1;
@@ -303,29 +308,34 @@ void gemm_fp16_opt_ex(uint32_t M, uint32_t N, uint32_t K, void* A_p,
303308
// for maximum utilization
304309
const uint32_t unroll = 8;
305310

306-
// SSR strides and bounds only have to be configured
307-
// once in the beginning
308-
if (setup_SSR) {
309-
uint32_t ssr0_b[4] = {unroll, K / 4, N / unroll, M};
310-
uint32_t ssr0_i[4] = {0, sizeof(__fp16) * 4, 0, sizeof(__fp16) * ldA};
311-
312-
uint32_t ssr1_b[4] = {unroll, K / 4, N / unroll, M};
313-
uint32_t ssr1_i[4] = {sizeof(__fp16) * ldB, sizeof(__fp16) * 4,
314-
sizeof(__fp16) * unroll * ldB, 0};
315-
316-
snrt_ssr_loop_3d(SNRT_SSR_DM0, ssr0_b[1], ssr0_b[2], ssr0_b[3],
317-
ssr0_i[1], ssr0_i[2], ssr0_i[3]);
318-
snrt_ssr_repeat(SNRT_SSR_DM0, unroll);
311+
// Don't enable the SSRs if the stream won't be used
312+
if (N >= unroll) {
313+
// SSR strides and bounds only have to be configured
314+
// once in the beginning
315+
if (setup_SSR) {
316+
uint32_t ssr0_b[4] = {unroll, K / 4, N / unroll, M};
317+
uint32_t ssr0_i[4] = {0, sizeof(__fp16) * 4, 0,
318+
sizeof(__fp16) * ldA};
319+
320+
uint32_t ssr1_b[4] = {unroll, K / 4, N / unroll, M};
321+
uint32_t ssr1_i[4] = {sizeof(__fp16) * ldB, sizeof(__fp16) * 4,
322+
sizeof(__fp16) * unroll * ldB, 0};
323+
324+
snrt_ssr_loop_3d(SNRT_SSR_DM0, ssr0_b[1], ssr0_b[2], ssr0_b[3],
325+
ssr0_i[1], ssr0_i[2], ssr0_i[3]);
326+
snrt_ssr_repeat(SNRT_SSR_DM0, unroll);
327+
328+
snrt_ssr_loop_4d(SNRT_SSR_DM1, ssr1_b[0], ssr1_b[1], ssr1_b[2],
329+
ssr1_b[3], ssr1_i[0], ssr1_i[1], ssr1_i[2],
330+
ssr1_i[3]);
331+
}
319332

320-
snrt_ssr_loop_4d(SNRT_SSR_DM1, ssr1_b[0], ssr1_b[1], ssr1_b[2],
321-
ssr1_b[3], ssr1_i[0], ssr1_i[1], ssr1_i[2], ssr1_i[3]);
333+
// SSR start address need to be configured each time
334+
snrt_ssr_read(SNRT_SSR_DM0, SNRT_SSR_4D, A);
335+
snrt_ssr_read(SNRT_SSR_DM1, SNRT_SSR_4D, B);
336+
snrt_ssr_enable();
322337
}
323338

324-
// SSR start address need to be configured each time
325-
snrt_ssr_read(SNRT_SSR_DM0, SNRT_SSR_4D, A);
326-
snrt_ssr_read(SNRT_SSR_DM1, SNRT_SSR_4D, B);
327-
snrt_ssr_enable();
328-
329339
// Kernel progresses by 4 values each step
330340
const uint32_t n_frep = K / 4 - 1;
331341

sw/blas/gemm/src/gemm_fp32.h

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -216,28 +216,32 @@ void gemm_fp32_opt(uint32_t M, uint32_t N, uint32_t K, void* A_p, uint32_t ldA,
216216
// for maximum utilization
217217
const uint32_t unroll = 8;
218218

219-
// SSR strides and bounds only have to be configured
220-
// once in the beginning
221-
if (setup_SSR) {
222-
uint32_t ssr0_b[4] = {unroll, K / 2, N / unroll, M};
223-
uint32_t ssr0_i[4] = {0, sizeof(float) * 2, 0, sizeof(float) * ldA};
219+
// Don't enable the SSRs if the stream won't be used
220+
if (N >= unroll) {
221+
// SSR strides and bounds only have to be configured
222+
// once in the beginning
223+
if (setup_SSR) {
224+
uint32_t ssr0_b[4] = {unroll, K / 2, N / unroll, M};
225+
uint32_t ssr0_i[4] = {0, sizeof(float) * 2, 0, sizeof(float) * ldA};
224226

225-
uint32_t ssr1_b[4] = {unroll, K / 2, N / unroll, M};
226-
uint32_t ssr1_i[4] = {sizeof(float) * ldB, sizeof(float) * 2,
227-
sizeof(float) * unroll * ldB, 0};
227+
uint32_t ssr1_b[4] = {unroll, K / 2, N / unroll, M};
228+
uint32_t ssr1_i[4] = {sizeof(float) * ldB, sizeof(float) * 2,
229+
sizeof(float) * unroll * ldB, 0};
228230

229-
snrt_ssr_loop_3d(SNRT_SSR_DM0, ssr0_b[1], ssr0_b[2], ssr0_b[3],
230-
ssr0_i[1], ssr0_i[2], ssr0_i[3]);
231-
snrt_ssr_repeat(SNRT_SSR_DM0, unroll);
231+
snrt_ssr_loop_3d(SNRT_SSR_DM0, ssr0_b[1], ssr0_b[2], ssr0_b[3],
232+
ssr0_i[1], ssr0_i[2], ssr0_i[3]);
233+
snrt_ssr_repeat(SNRT_SSR_DM0, unroll);
232234

233-
snrt_ssr_loop_4d(SNRT_SSR_DM1, ssr1_b[0], ssr1_b[1], ssr1_b[2],
234-
ssr1_b[3], ssr1_i[0], ssr1_i[1], ssr1_i[2], ssr1_i[3]);
235-
}
235+
snrt_ssr_loop_4d(SNRT_SSR_DM1, ssr1_b[0], ssr1_b[1], ssr1_b[2],
236+
ssr1_b[3], ssr1_i[0], ssr1_i[1], ssr1_i[2],
237+
ssr1_i[3]);
238+
}
236239

237-
// SSR start address need to be configured each time
238-
snrt_ssr_read(SNRT_SSR_DM0, SNRT_SSR_4D, A);
239-
snrt_ssr_read(SNRT_SSR_DM1, SNRT_SSR_4D, B);
240-
snrt_ssr_enable();
240+
// SSR start address need to be configured each time
241+
snrt_ssr_read(SNRT_SSR_DM0, SNRT_SSR_4D, A);
242+
snrt_ssr_read(SNRT_SSR_DM1, SNRT_SSR_4D, B);
243+
snrt_ssr_enable();
244+
}
241245

242246
// Kernel progresses by 2 values each step
243247
const uint32_t n_frep = K / 2 - 1;

sw/blas/gemm/src/gemm_fp64.h

Lines changed: 42 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -70,48 +70,51 @@ void gemm_fp64_opt(uint32_t M, uint32_t N, uint32_t K, void* A_p, uint32_t ldA,
7070
// for maximum utilization
7171
const uint32_t unroll = 8;
7272

73-
// SSR strides and bounds only have to be configured
74-
// once in the beginning
75-
if (setup_SSR) {
76-
// First matrix is stored in transposed format
77-
if (ta) {
78-
const uint32_t ssr0_b[4] = {unroll, K, N / unroll, M};
79-
const uint32_t ssr0_i[4] = {0, 8 * ldA, 0, 8 * 8};
80-
81-
snrt_ssr_loop_3d(SNRT_SSR_DM0, ssr0_b[1], ssr0_b[2], ssr0_b[3],
82-
ssr0_i[1], ssr0_i[2], ssr0_i[3]);
83-
snrt_ssr_repeat(SNRT_SSR_DM0, unroll);
84-
} else {
85-
const uint32_t ssr0_b[4] = {unroll, K, N / unroll, M};
86-
const uint32_t ssr0_i[4] = {0, 8, 0, 8 * ldA};
87-
88-
snrt_ssr_loop_3d(SNRT_SSR_DM0, ssr0_b[1], ssr0_b[2], ssr0_b[3],
89-
ssr0_i[1], ssr0_i[2], ssr0_i[3]);
90-
snrt_ssr_repeat(SNRT_SSR_DM0, unroll);
91-
}
73+
// Don't enable the SSRs if the stream won't be used
74+
if (N >= unroll) {
75+
// SSR strides and bounds only have to be configured
76+
// once in the beginning
77+
if (setup_SSR) {
78+
// First matrix is stored in transposed format
79+
if (ta) {
80+
const uint32_t ssr0_b[4] = {unroll, K, N / unroll, M};
81+
const uint32_t ssr0_i[4] = {0, 8 * ldA, 0, 8 * 8};
82+
83+
snrt_ssr_loop_3d(SNRT_SSR_DM0, ssr0_b[1], ssr0_b[2], ssr0_b[3],
84+
ssr0_i[1], ssr0_i[2], ssr0_i[3]);
85+
snrt_ssr_repeat(SNRT_SSR_DM0, unroll);
86+
} else {
87+
const uint32_t ssr0_b[4] = {unroll, K, N / unroll, M};
88+
const uint32_t ssr0_i[4] = {0, 8, 0, 8 * ldA};
89+
90+
snrt_ssr_loop_3d(SNRT_SSR_DM0, ssr0_b[1], ssr0_b[2], ssr0_b[3],
91+
ssr0_i[1], ssr0_i[2], ssr0_i[3]);
92+
snrt_ssr_repeat(SNRT_SSR_DM0, unroll);
93+
}
94+
95+
// Second matrix is stored in transposed format
96+
if (tb) {
97+
const uint32_t ssr1_b[4] = {unroll, K, N / unroll, M};
98+
const uint32_t ssr1_i[4] = {8 * ldB, 8, 8 * ldB * unroll, 0};
9299

93-
// Second matrix is stored in transposed format
94-
if (tb) {
95-
const uint32_t ssr1_b[4] = {unroll, K, N / unroll, M};
96-
const uint32_t ssr1_i[4] = {8 * ldB, 8, 8 * ldB * unroll, 0};
97-
98-
snrt_ssr_loop_4d(SNRT_SSR_DM1, ssr1_b[0], ssr1_b[1], ssr1_b[2],
99-
ssr1_b[3], ssr1_i[0], ssr1_i[1], ssr1_i[2],
100-
ssr1_i[3]);
101-
} else {
102-
const uint32_t ssr1_b[4] = {unroll, K, N / unroll, M};
103-
const uint32_t ssr1_i[4] = {8, 8 * ldB, 8 * unroll, 0};
104-
105-
snrt_ssr_loop_4d(SNRT_SSR_DM1, ssr1_b[0], ssr1_b[1], ssr1_b[2],
106-
ssr1_b[3], ssr1_i[0], ssr1_i[1], ssr1_i[2],
107-
ssr1_i[3]);
100+
snrt_ssr_loop_4d(SNRT_SSR_DM1, ssr1_b[0], ssr1_b[1], ssr1_b[2],
101+
ssr1_b[3], ssr1_i[0], ssr1_i[1], ssr1_i[2],
102+
ssr1_i[3]);
103+
} else {
104+
const uint32_t ssr1_b[4] = {unroll, K, N / unroll, M};
105+
const uint32_t ssr1_i[4] = {8, 8 * ldB, 8 * unroll, 0};
106+
107+
snrt_ssr_loop_4d(SNRT_SSR_DM1, ssr1_b[0], ssr1_b[1], ssr1_b[2],
108+
ssr1_b[3], ssr1_i[0], ssr1_i[1], ssr1_i[2],
109+
ssr1_i[3]);
110+
}
108111
}
109-
}
110112

111-
// SSR start address need to be configured each time
112-
snrt_ssr_read(SNRT_SSR_DM0, SNRT_SSR_4D, A);
113-
snrt_ssr_read(SNRT_SSR_DM1, SNRT_SSR_4D, B);
114-
snrt_ssr_enable();
113+
// SSR start address need to be configured each time
114+
snrt_ssr_read(SNRT_SSR_DM0, SNRT_SSR_4D, A);
115+
snrt_ssr_read(SNRT_SSR_DM1, SNRT_SSR_4D, B);
116+
snrt_ssr_enable();
117+
}
115118

116119
for (uint32_t m = 0; m < M; m++) {
117120
uint32_t n = 0;

sw/blas/gemm/src/gemm_fp8.h

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -126,28 +126,32 @@ void gemm_fp8_opt_ex(uint32_t M, uint32_t N, uint32_t K, void* A_p,
126126
// for maximum utilization
127127
const uint32_t unroll = 8;
128128

129-
// SSR strides and bounds only have to be configured
130-
// once in the beginning
131-
if (setup_SSR) {
132-
uint32_t ssr0_b[4] = {unroll, K / 8, N / unroll, M};
133-
uint32_t ssr0_i[4] = {0, sizeof(char) * 8, 0, sizeof(char) * ldA};
129+
// Don't enable the SSRs if the stream won't be used
130+
if (N >= unroll) {
131+
// SSR strides and bounds only have to be configured
132+
// once in the beginning
133+
if (setup_SSR) {
134+
uint32_t ssr0_b[4] = {unroll, K / 8, N / unroll, M};
135+
uint32_t ssr0_i[4] = {0, sizeof(char) * 8, 0, sizeof(char) * ldA};
134136

135-
uint32_t ssr1_b[4] = {unroll, K / 8, N / unroll, M};
136-
uint32_t ssr1_i[4] = {sizeof(char) * ldB, sizeof(char) * 8,
137-
sizeof(char) * unroll * ldB, 0};
137+
uint32_t ssr1_b[4] = {unroll, K / 8, N / unroll, M};
138+
uint32_t ssr1_i[4] = {sizeof(char) * ldB, sizeof(char) * 8,
139+
sizeof(char) * unroll * ldB, 0};
138140

139-
snrt_ssr_loop_3d(SNRT_SSR_DM0, ssr0_b[1], ssr0_b[2], ssr0_b[3],
140-
ssr0_i[1], ssr0_i[2], ssr0_i[3]);
141-
snrt_ssr_repeat(SNRT_SSR_DM0, unroll);
141+
snrt_ssr_loop_3d(SNRT_SSR_DM0, ssr0_b[1], ssr0_b[2], ssr0_b[3],
142+
ssr0_i[1], ssr0_i[2], ssr0_i[3]);
143+
snrt_ssr_repeat(SNRT_SSR_DM0, unroll);
142144

143-
snrt_ssr_loop_4d(SNRT_SSR_DM1, ssr1_b[0], ssr1_b[1], ssr1_b[2],
144-
ssr1_b[3], ssr1_i[0], ssr1_i[1], ssr1_i[2], ssr1_i[3]);
145-
}
145+
snrt_ssr_loop_4d(SNRT_SSR_DM1, ssr1_b[0], ssr1_b[1], ssr1_b[2],
146+
ssr1_b[3], ssr1_i[0], ssr1_i[1], ssr1_i[2],
147+
ssr1_i[3]);
148+
}
146149

147-
// SSR start address need to be configured each time
148-
snrt_ssr_read(SNRT_SSR_DM0, SNRT_SSR_4D, A);
149-
snrt_ssr_read(SNRT_SSR_DM1, SNRT_SSR_4D, B);
150-
snrt_ssr_enable();
150+
// SSR start address need to be configured each time
151+
snrt_ssr_read(SNRT_SSR_DM0, SNRT_SSR_4D, A);
152+
snrt_ssr_read(SNRT_SSR_DM1, SNRT_SSR_4D, B);
153+
snrt_ssr_enable();
154+
}
151155

152156
// Kernel progresses by 8 values each step
153157
const uint32_t n_frep = K / 8 - 1;

sw/snRuntime/src/ssr.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@
2525
* The convenience functions provided in this file can be used to set up such
2626
* access patterns. The function argument names reflect the variable names
2727
* presented in these sample code snippets.
28+
*
29+
* Note: The exact number of elements configured in an (I)SSR stream must be
30+
* consumed. Failure to comply with this requirement will result in undefined
31+
* behaviour.
2832
*/
2933

3034
#pragma once

0 commit comments

Comments
 (0)