3
3
// SPDX-License-Identifier: Apache-2.0
4
4
//
5
5
// Luca Colagrande <colluca@iis.ee.ethz.ch>
6
+ // Viviane Potocnik <vivianep@iis.ee.ethz.ch>
6
7
7
8
#include "snrt.h"
8
9
@@ -47,28 +48,32 @@ static inline int fused_concat_linear_baseline(fused_concat_linear_layer_t l) {
47
48
uint32_t k = l .input_shape [1 ] * l .num_inputs ;
48
49
uint32_t n = l .output_shape [1 ];
49
50
50
- gemm_args_t gemm_args = {1.0 ,
51
- l .dtype ,
52
- 0 ,
53
- 1 ,
54
- 0 ,
55
- snrt_cluster_num (),
56
- 1 ,
57
- 1 ,
58
- 0 ,
59
- 1 ,
60
- 1 ,
61
- 0 ,
62
- 0 ,
63
- m ,
64
- n ,
65
- k ,
66
- l .concat_output ,
67
- l .weights ,
68
- 0 ,
69
- l .linear_output ,
70
- l .gemm_implementation };
71
- gemm (& gemm_args );
51
+ gemm_args_t gemm_args ;
52
+ gemm_args_t * local_args = (gemm_args_t * )& gemm_args ;
53
+
54
+ local_args -> alpha = 1.0 ;
55
+ local_args -> prec = l .dtype ;
56
+ local_args -> setup_ssr = 0 ;
57
+ local_args -> parallelize_m = 1 ;
58
+ local_args -> parallelize_k = 0 ;
59
+ local_args -> m_tiles = snrt_cluster_core_num ();
60
+ local_args -> n_tiles = 1 ;
61
+ local_args -> k_tiles = 1 ;
62
+ local_args -> load_a = 0 ;
63
+ local_args -> load_b = 1 ;
64
+ local_args -> load_c = 1 ;
65
+ local_args -> transa = 0 ;
66
+ local_args -> transb = 0 ;
67
+ local_args -> M = m ;
68
+ local_args -> N = n ;
69
+ local_args -> K = k ;
70
+ local_args -> a = l .concat_output ;
71
+ local_args -> b = l .weights ;
72
+ local_args -> beta = 0 ;
73
+ local_args -> c = l .linear_output ;
74
+ local_args -> gemm_fp = l .gemm_implementation ;
75
+
76
+ gemm (local_args );  // local_args already is a gemm_args_t * aliasing gemm_args; passing &local_args would hand gemm a gemm_args_t **
72
77
73
78
snrt_global_barrier ();
74
79
@@ -92,28 +97,32 @@ static inline int fused_concat_linear_optimized(fused_concat_linear_layer_t l) {
92
97
}
93
98
snrt_cluster_hw_barrier ();
94
99
95
- gemm_args_t gemm_args = {1.0 ,
96
- l .dtype ,
97
- 0 ,
98
- 0 ,
99
- 1 ,
100
- 1 ,
101
- 1 ,
102
- l .num_inputs ,
103
- 0 ,
104
- 1 ,
105
- 1 ,
106
- 0 ,
107
- 0 ,
108
- m ,
109
- n ,
110
- concat_k ,
111
- a ,
112
- l .weights ,
113
- 0 ,
114
- l .linear_output ,
115
- l .gemm_implementation };
116
- gemm (& gemm_args );
100
+ gemm_args_t gemm_args ;
101
+ gemm_args_t * local_args = (gemm_args_t * )& gemm_args ;
102
+
103
+ local_args -> alpha = 1.0 ;
104
+ local_args -> prec = l .dtype ;
105
+ local_args -> setup_ssr = 0 ;
106
+ local_args -> parallelize_m = 0 ;
107
+ local_args -> parallelize_k = 1 ;
108
+ local_args -> m_tiles = 1 ;
109
+ local_args -> n_tiles = 1 ;
110
+ local_args -> k_tiles = l .num_inputs ;
111
+ local_args -> load_a = 0 ;
112
+ local_args -> load_b = 1 ;
113
+ local_args -> load_c = 1 ;
114
+ local_args -> transa = 0 ;
115
+ local_args -> transb = 0 ;
116
+ local_args -> M = m ;
117
+ local_args -> N = n ;
118
+ local_args -> K = concat_k ;
119
+ local_args -> a = a ;
120
+ local_args -> b = l .weights ;
121
+ local_args -> beta = 0 ;
122
+ local_args -> c = l .linear_output ;
123
+ local_args -> gemm_fp = l .gemm_implementation ;
124
+
125
+ gemm (local_args );  // local_args already is a gemm_args_t * aliasing gemm_args; passing &local_args would hand gemm a gemm_args_t **
117
126
118
127
snrt_global_barrier ();
119
128
0 commit comments