Skip to content

Commit 2b8c554

Browse files
author
Viviane Potocnik
committed
[sw] FCL: explicitly declare struct fields
1 parent 2fdb93e commit 2b8c554

File tree

1 file changed

+53
-44
lines changed

1 file changed

+53
-44
lines changed

sw/dnn/fused_concat_linear/src/fused_concat_linear.h

Lines changed: 53 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
// SPDX-License-Identifier: Apache-2.0
44
//
55
// Luca Colagrande <colluca@iis.ee.ethz.ch>
6+
// Viviane Potocnik <vivianep@iis.ee.ethz.ch>
67

78
#include "snrt.h"
89

@@ -47,28 +48,32 @@ static inline int fused_concat_linear_baseline(fused_concat_linear_layer_t l) {
4748
uint32_t k = l.input_shape[1] * l.num_inputs;
4849
uint32_t n = l.output_shape[1];
4950

50-
gemm_args_t gemm_args = {1.0,
51-
l.dtype,
52-
0,
53-
1,
54-
0,
55-
snrt_cluster_num(),
56-
1,
57-
1,
58-
0,
59-
1,
60-
1,
61-
0,
62-
0,
63-
m,
64-
n,
65-
k,
66-
l.concat_output,
67-
l.weights,
68-
0,
69-
l.linear_output,
70-
l.gemm_implementation};
71-
gemm(&gemm_args);
51+
gemm_args_t gemm_args;
52+
gemm_args_t *local_args = (gemm_args_t *)&gemm_args;
53+
54+
local_args->alpha = 1.0;
55+
local_args->prec = l.dtype;
56+
local_args->setup_ssr = 0;
57+
local_args->parallelize_m = 1;
58+
local_args->parallelize_k = 0;
59+
local_args->m_tiles = snrt_cluster_num();
60+
local_args->n_tiles = 1;
61+
local_args->k_tiles = 1;
62+
local_args->load_a = 0;
63+
local_args->load_b = 1;
64+
local_args->load_c = 1;
65+
local_args->transa = 0;
66+
local_args->transb = 0;
67+
local_args->M = m;
68+
local_args->N = n;
69+
local_args->K = k;
70+
local_args->a = l.concat_output;
71+
local_args->b = l.weights;
72+
local_args->beta = 0;
73+
local_args->c = l.linear_output;
74+
local_args->gemm_fp = l.gemm_implementation;
75+
76+
gemm(local_args);
7277

7378
snrt_global_barrier();
7479

@@ -92,28 +97,32 @@ static inline int fused_concat_linear_optimized(fused_concat_linear_layer_t l) {
9297
}
9398
snrt_cluster_hw_barrier();
9499

95-
gemm_args_t gemm_args = {1.0,
96-
l.dtype,
97-
0,
98-
0,
99-
1,
100-
1,
101-
1,
102-
l.num_inputs,
103-
0,
104-
1,
105-
1,
106-
0,
107-
0,
108-
m,
109-
n,
110-
concat_k,
111-
a,
112-
l.weights,
113-
0,
114-
l.linear_output,
115-
l.gemm_implementation};
116-
gemm(&gemm_args);
100+
gemm_args_t gemm_args;
101+
gemm_args_t *local_args = (gemm_args_t *)&gemm_args;
102+
103+
local_args->alpha = 1.0;
104+
local_args->prec = l.dtype;
105+
local_args->setup_ssr = 0;
106+
local_args->parallelize_m = 0;
107+
local_args->parallelize_k = 1;
108+
local_args->m_tiles = 1;
109+
local_args->n_tiles = 1;
110+
local_args->k_tiles = l.num_inputs;
111+
local_args->load_a = 0;
112+
local_args->load_b = 1;
113+
local_args->load_c = 1;
114+
local_args->transa = 0;
115+
local_args->transb = 0;
116+
local_args->M = m;
117+
local_args->N = n;
118+
local_args->K = concat_k;
119+
local_args->a = a;
120+
local_args->b = l.weights;
121+
local_args->beta = 0;
122+
local_args->c = l.linear_output;
123+
local_args->gemm_fp = l.gemm_implementation;
124+
125+
gemm(local_args);
117126

118127
snrt_global_barrier();
119128

0 commit comments

Comments
 (0)