Skip to content

Commit 8b52aa5

Browse files
author
Viviane Potocnik
committed
FA-2: revert latest minifloat changes due to missing TCDM dyn alloc
1 parent 05b35ed commit 8b52aa5

File tree

3 files changed

+90
-40
lines changed

3 files changed

+90
-40
lines changed

sw/dnn/flashattention_2/src/flashattention_2_fp16.h

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -40,21 +40,35 @@ static inline void flashattention_2_fp16(flashattention_2_layer_t layer) {
4040
uint32_t m_i_size = B_r * sizeof(float);
4141
uint32_t m_i_prev_size = m_i_size;
4242
uint32_t l_i_size = B_r * sizeof(float);
43+
uint32_t shifted_exp_size = B_r * sizeof(float);
4344

4445
// allocate memory in TCDM
45-
__fp16 *Q_fa = snrt_l1_alloc_cluster_local(q_fa_size, sizeof(__fp16));
46-
__fp16 *K_fa = snrt_l1_alloc_cluster_local(k_fa_size, sizeof(__fp16));
47-
__fp16 *V_fa = snrt_l1_alloc_cluster_local(v_fa_size, sizeof(__fp16));
48-
__fp16 *S_fa = snrt_l1_alloc_cluster_local(s_fa_size, sizeof(__fp16));
49-
__fp16 *P_fa = snrt_l1_alloc_cluster_local(p_fa_size, sizeof(__fp16));
50-
__fp16 *O_fa = snrt_l1_alloc_cluster_local(o_fa_size, sizeof(__fp16));
51-
float *m_i = snrt_l1_alloc_cluster_local(m_i_size, sizeof(float));
52-
float *m_i_prev = snrt_l1_alloc_cluster_local(m_i_prev_size, sizeof(float));
53-
float *l_i = snrt_l1_alloc_cluster_local(l_i_size, sizeof(float));
54-
55-
// allocate space for V^t when using optimized kernels
46+
void *tcdm_ptr = (__fp16 *)snrt_l1_next();
47+
__fp16 *Q_fa = tcdm_ptr;
48+
tcdm_ptr += q_fa_size;
49+
__fp16 *K_fa = tcdm_ptr;
50+
tcdm_ptr += k_fa_size;
51+
__fp16 *V_fa = tcdm_ptr;
52+
tcdm_ptr += v_fa_size;
53+
__fp16 *S_fa = tcdm_ptr;
54+
tcdm_ptr += s_fa_size;
55+
__fp16 *P_fa = tcdm_ptr;
56+
tcdm_ptr += p_fa_size;
57+
__fp16 *O_fa = tcdm_ptr;
58+
tcdm_ptr += o_fa_size;
59+
float *m_i = tcdm_ptr;
60+
tcdm_ptr += m_i_size;
61+
float *m_i_prev = tcdm_ptr;
62+
tcdm_ptr += m_i_prev_size;
63+
float *l_i = tcdm_ptr;
64+
tcdm_ptr += l_i_size;
65+
66+
// Allocate space for V^t
5667
__fp16 *V_t;
57-
if (!baseline) V_t = snrt_l1_alloc_cluster_local(v_fa_size, sizeof(__fp16));
68+
if (!baseline) {
69+
V_t = tcdm_ptr;
70+
tcdm_ptr += B_c * d * sizeof(__fp16);
71+
}
5872

5973
float shifted_exp;
6074
float row_sum;
@@ -105,6 +119,7 @@ static inline void flashattention_2_fp16(flashattention_2_layer_t layer) {
105119

106120
// Iterate column blocks of K (corresponding to row blocks of V)
107121
for (int t_c = 0; t_c < T_c; t_c++) {
122+
108123
// DMA copy K column block (B_c, d) and V row block (B_c, d) to
109124
// TCDM. Both K and V are stored in (S, d) form in memory
110125
if (!snrt_is_compute_core()) {
@@ -199,7 +214,7 @@ static inline void flashattention_2_fp16(flashattention_2_layer_t layer) {
199214
beta = 0;
200215
else
201216
beta = 1;
202-
sc_st_gemm(dtype, 1, 0, 0, B_r, d, B_c, 1, P_fa, B_c, V_fa,
217+
sc_st_gemm(dtype, 0, 0, 0, B_r, d, B_c, 1, P_fa, B_c, V_fa,
203218
d, beta, O_fa, d, gemm_implementation);
204219
} else {
205220
// The SIMD-optimized GEMM kernel performs the A*B^t
@@ -217,7 +232,7 @@ static inline void flashattention_2_fp16(flashattention_2_layer_t layer) {
217232
beta = 0;
218233
else
219234
beta = 1;
220-
sc_st_gemm(dtype, 1, 0, 1, B_r, d, B_c, 1, P_fa, B_c, V_t,
235+
sc_st_gemm(dtype, 0, 0, 1, B_r, d, B_c, 1, P_fa, B_c, V_t,
221236
B_c, beta, O_fa, d, gemm_implementation);
222237
}
223238
} else {
@@ -231,6 +246,7 @@ static inline void flashattention_2_fp16(flashattention_2_layer_t layer) {
231246
snrt_mcycle();
232247
} // end of T_c loop
233248

249+
234250
// Rescaling for last t_c iteration
235251
// O_i = diag(l_i_Tc)^-1 * O_i
236252
if (snrt_is_compute_core()) {
@@ -240,6 +256,7 @@ static inline void flashattention_2_fp16(flashattention_2_layer_t layer) {
240256
}
241257
}
242258
}
259+
243260
snrt_fpu_fence();
244261
snrt_cluster_hw_barrier();
245262

sw/dnn/flashattention_2/src/flashattention_2_fp32.h

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -40,21 +40,35 @@ static inline void flashattention_2_fp32(flashattention_2_layer_t layer) {
4040
uint32_t m_i_size = B_r * sizeof(float);
4141
uint32_t m_i_prev_size = m_i_size;
4242
uint32_t l_i_size = B_r * sizeof(float);
43+
uint32_t shifted_exp_size = B_r * sizeof(float);
4344

4445
// allocate memory in TCDM
45-
float *Q_fa = snrt_l1_alloc_cluster_local(q_fa_size, sizeof(float));
46-
float *K_fa = snrt_l1_alloc_cluster_local(k_fa_size, sizeof(float));
47-
float *V_fa = snrt_l1_alloc_cluster_local(v_fa_size, sizeof(float));
48-
float *S_fa = snrt_l1_alloc_cluster_local(s_fa_size, sizeof(float));
49-
float *P_fa = snrt_l1_alloc_cluster_local(p_fa_size, sizeof(float));
50-
float *O_fa = snrt_l1_alloc_cluster_local(o_fa_size, sizeof(float));
51-
float *m_i = snrt_l1_alloc_cluster_local(m_i_size, sizeof(float));
52-
float *m_i_prev = snrt_l1_alloc_cluster_local(m_i_prev_size, sizeof(float));
53-
float *l_i = snrt_l1_alloc_cluster_local(l_i_size, sizeof(float));
46+
void *tcdm_ptr = (float *)snrt_l1_next();
47+
float *Q_fa = tcdm_ptr;
48+
tcdm_ptr += q_fa_size;
49+
float *K_fa = tcdm_ptr;
50+
tcdm_ptr += k_fa_size;
51+
float *V_fa = tcdm_ptr;
52+
tcdm_ptr += v_fa_size;
53+
float *S_fa = tcdm_ptr;
54+
tcdm_ptr += s_fa_size;
55+
float *P_fa = tcdm_ptr;
56+
tcdm_ptr += p_fa_size;
57+
float *O_fa = tcdm_ptr;
58+
tcdm_ptr += o_fa_size;
59+
float *m_i = tcdm_ptr;
60+
tcdm_ptr += m_i_size;
61+
float *m_i_prev = tcdm_ptr;
62+
tcdm_ptr += m_i_prev_size;
63+
float *l_i = tcdm_ptr;
64+
tcdm_ptr += l_i_size;
5465

5566
// allocate space for V^t when using optimized kernels
5667
float *V_t;
57-
if (!baseline) V_t = snrt_l1_alloc_cluster_local(v_fa_size, sizeof(float));
68+
if (!baseline) {
69+
V_t = tcdm_ptr;
70+
tcdm_ptr += B_c * d * sizeof(float);
71+
}
5872

5973
float shifted_exp;
6074
float row_sum;
@@ -196,7 +210,7 @@ static inline void flashattention_2_fp32(flashattention_2_layer_t layer) {
196210
beta = 0;
197211
else
198212
beta = 1;
199-
sc_st_gemm(dtype, 1, 0, 0, B_r, d, B_c, 1, P_fa, B_c, V_fa,
213+
sc_st_gemm(dtype, 0, 0, 0, B_r, d, B_c, 1, P_fa, B_c, V_fa,
200214
d, beta, O_fa, d, gemm_implementation);
201215
} else {
202216
// The SIMD-optimized GEMM kernel performs the A*B^t
@@ -214,7 +228,7 @@ static inline void flashattention_2_fp32(flashattention_2_layer_t layer) {
214228
beta = 0;
215229
else
216230
beta = 1;
217-
sc_st_gemm(dtype, 1, 0, 1, B_r, d, B_c, 1, P_fa, B_c, V_t,
231+
sc_st_gemm(dtype, 0, 0, 1, B_r, d, B_c, 1, P_fa, B_c, V_t,
218232
B_c, beta, O_fa, d, gemm_implementation);
219233
}
220234
} else {

sw/dnn/flashattention_2/src/flashattention_2_fp8.h

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -61,21 +61,35 @@ static inline void flashattention_2_fp8(flashattention_2_layer_t layer) {
6161
uint32_t m_i_size = B_r * sizeof(float);
6262
uint32_t m_i_prev_size = m_i_size;
6363
uint32_t l_i_size = B_r * sizeof(float);
64+
uint32_t shifted_exp_size = B_r * sizeof(float);
6465

6566
// allocate memory in TCDM
66-
char *Q_fa = snrt_l1_alloc_cluster_local(q_fa_size, sizeof(char));
67-
char *K_fa = snrt_l1_alloc_cluster_local(k_fa_size, sizeof(char));
68-
char *V_fa = snrt_l1_alloc_cluster_local(v_fa_size, sizeof(char));
69-
char *S_fa = snrt_l1_alloc_cluster_local(s_fa_size, sizeof(char));
70-
char *P_fa = snrt_l1_alloc_cluster_local(p_fa_size, sizeof(char));
71-
char *O_fa = snrt_l1_alloc_cluster_local(o_fa_size, sizeof(char));
72-
float *m_i = snrt_l1_alloc_cluster_local(m_i_size, sizeof(float));
73-
float *m_i_prev = snrt_l1_alloc_cluster_local(m_i_prev_size, sizeof(float));
74-
float *l_i = snrt_l1_alloc_cluster_local(l_i_size, sizeof(float));
75-
76-
// allocate space for V^t when using optimized kernels
67+
void *tcdm_ptr = (char *)snrt_l1_next();
68+
char *Q_fa = tcdm_ptr;
69+
tcdm_ptr += q_fa_size;
70+
char *K_fa = tcdm_ptr;
71+
tcdm_ptr += k_fa_size;
72+
char *V_fa = tcdm_ptr;
73+
tcdm_ptr += v_fa_size;
74+
char *S_fa = tcdm_ptr;
75+
tcdm_ptr += s_fa_size;
76+
char *P_fa = tcdm_ptr;
77+
tcdm_ptr += p_fa_size;
78+
char *O_fa = tcdm_ptr;
79+
tcdm_ptr += o_fa_size;
80+
float *m_i = tcdm_ptr;
81+
tcdm_ptr += m_i_size;
82+
float *m_i_prev = tcdm_ptr;
83+
tcdm_ptr += m_i_prev_size;
84+
float *l_i = tcdm_ptr;
85+
tcdm_ptr += l_i_size;
86+
87+
// Allocate space for V^t
7788
char *V_t;
78-
if (!baseline) V_t = snrt_l1_alloc_cluster_local(v_fa_size, sizeof(char));
89+
if(!baseline) {
90+
V_t = tcdm_ptr;
91+
tcdm_ptr += B_c * d * sizeof(char);
92+
}
7993

8094
float shifted_exp;
8195
float row_sum;
@@ -85,6 +99,7 @@ static inline void flashattention_2_fp8(flashattention_2_layer_t layer) {
8599
// Iterate row blocks of Q
86100
for (int t_r = 0; t_r < T_r; t_r++) {
87101
// DMA copy Q row block to TCDM
102+
uint32_t start_dma = snrt_mcycle();
88103
if (snrt_is_dm_core()) {
89104
snrt_dma_load_2d_tile(Q_fa, // dst
90105
Q_l3, // src
@@ -97,6 +112,8 @@ static inline void flashattention_2_fp8(flashattention_2_layer_t layer) {
97112
);
98113
snrt_dma_wait_all();
99114
}
115+
uint32_t end_dma = snrt_mcycle();
116+
100117
snrt_cluster_hw_barrier();
101118

102119
snrt_mcycle();
@@ -224,7 +241,7 @@ static inline void flashattention_2_fp8(flashattention_2_layer_t layer) {
224241
beta = 0;
225242
else
226243
beta = 1;
227-
sc_st_gemm(dtype, 1, 0, 0, B_r, d, B_c, 1, P_fa, B_c, V_fa,
244+
sc_st_gemm(dtype, 0, 0, 0, B_r, d, B_c, 1, P_fa, B_c, V_fa,
228245
d, beta, O_fa, d, gemm_implementation);
229246
} else {
230247
// The SIMD-optimized GEMM kernel performs the A*B^t
@@ -242,7 +259,7 @@ static inline void flashattention_2_fp8(flashattention_2_layer_t layer) {
242259
beta = 0;
243260
else
244261
beta = 1;
245-
sc_st_gemm(dtype, 1, 0, 1, B_r, d, B_c, 1, P_fa, B_c, V_t,
262+
sc_st_gemm(dtype, 0, 0, 1, B_r, d, B_c, 1, P_fa, B_c, V_t,
246263
B_c, beta, O_fa, d, gemm_implementation);
247264
}
248265
} else {
@@ -267,6 +284,7 @@ static inline void flashattention_2_fp8(flashattention_2_layer_t layer) {
267284
}
268285
}
269286
}
287+
270288
snrt_fpu_fence();
271289
snrt_cluster_hw_barrier();
272290

@@ -285,6 +303,7 @@ static inline void flashattention_2_fp8(flashattention_2_layer_t layer) {
285303
);
286304
snrt_dma_wait_all();
287305
}
306+
288307
snrt_cluster_hw_barrier();
289308

290309
snrt_mcycle();

0 commit comments

Comments
 (0)