@@ -48,32 +48,29 @@ static inline int fused_concat_linear_baseline(fused_concat_linear_layer_t l) {
48
48
uint32_t k = l .input_shape [1 ] * l .num_inputs ;
49
49
uint32_t n = l .output_shape [1 ];
50
50
51
- gemm_args_t gemm_args ;
52
- gemm_args_t * local_args = (gemm_args_t * )& gemm_args ;
53
-
54
- local_args -> alpha = 1.0 ;
55
- local_args -> prec = l .dtype ;
56
- local_args -> setup_ssr = 0 ;
57
- local_args -> parallelize_m = 1 ;
58
- local_args -> parallelize_k = 0 ;
59
- local_args -> m_tiles = snrt_cluster_core_num ();
60
- local_args -> n_tiles = 1 ;
61
- local_args -> k_tiles = 1 ;
62
- local_args -> load_a = 0 ;
63
- local_args -> load_b = 1 ;
64
- local_args -> load_c = 1 ;
65
- local_args -> transa = 0 ;
66
- local_args -> transb = 0 ;
67
- local_args -> M = m ;
68
- local_args -> N = n ;
69
- local_args -> K = k ;
70
- local_args -> a = l .concat_output ;
71
- local_args -> b = l .weights ;
72
- local_args -> beta = 0 ;
73
- local_args -> c = l .linear_output ;
74
- local_args -> gemm_fp = l .gemm_implementation ;
75
-
76
- gemm (& local_args );
51
+ gemm_args_t gemm_args = {.alpha = 1.0 ,
52
+ .prec = l .dtype ,
53
+ .setup_ssr = 0 ,
54
+ .parallelize_m = 1 ,
55
+ .parallelize_k = 0 ,
56
+ .m_tiles = snrt_cluster_num (),
57
+ .n_tiles = 1 ,
58
+ .k_tiles = 1 ,
59
+ .load_a = 0 ,
60
+ .load_b = 1 ,
61
+ .load_c = 1 ,
62
+ .transa = 0 ,
63
+ .transb = 0 ,
64
+ .M = m ,
65
+ .N = n ,
66
+ .K = k ,
67
+ .a = l .concat_output ,
68
+ .b = l .weights ,
69
+ .beta = 0 ,
70
+ .c = l .linear_output ,
71
+ .gemm_fp = l .gemm_implementation };
72
+
73
+ gemm (& gemm_args );
77
74
78
75
snrt_global_barrier ();
79
76
@@ -97,32 +94,29 @@ static inline int fused_concat_linear_optimized(fused_concat_linear_layer_t l) {
97
94
}
98
95
snrt_cluster_hw_barrier ();
99
96
100
- gemm_args_t gemm_args ;
101
- gemm_args_t * local_args = (gemm_args_t * )& gemm_args ;
102
-
103
- local_args -> alpha = 1.0 ;
104
- local_args -> prec = l .dtype ;
105
- local_args -> setup_ssr = 0 ;
106
- local_args -> parallelize_m = 0 ;
107
- local_args -> parallelize_k = 1 ;
108
- local_args -> m_tiles = 1 ;
109
- local_args -> n_tiles = 1 ;
110
- local_args -> k_tiles = l .num_inputs ;
111
- local_args -> load_a = 0 ;
112
- local_args -> load_b = 1 ;
113
- local_args -> load_c = 1 ;
114
- local_args -> transa = 0 ;
115
- local_args -> transb = 0 ;
116
- local_args -> M = m ;
117
- local_args -> N = n ;
118
- local_args -> K = concat_k ;
119
- local_args -> a = a ;
120
- local_args -> b = l .weights ;
121
- local_args -> beta = 0 ;
122
- local_args -> c = l .linear_output ;
123
- local_args -> gemm_fp = l .gemm_implementation ;
124
-
125
- gemm (& local_args );
97
+ gemm_args_t gemm_args = {.alpha = 1.0 ,
98
+ .prec = l .dtype ,
99
+ .setup_ssr = 0 ,
100
+ .parallelize_m = 0 ,
101
+ .parallelize_k = 1 ,
102
+ .m_tiles = 1 ,
103
+ .n_tiles = 1 ,
104
+ .k_tiles = l .num_inputs ,
105
+ .load_a = 0 ,
106
+ .load_b = 1 ,
107
+ .load_c = 1 ,
108
+ .transa = 0 ,
109
+ .transb = 0 ,
110
+ .M = m ,
111
+ .N = n ,
112
+ .K = concat_k ,
113
+ .a = a ,
114
+ .b = l .weights ,
115
+ .beta = 0 ,
116
+ .c = l .linear_output ,
117
+ .gemm_fp = l .gemm_implementation };
118
+
119
+ gemm (& gemm_args );
126
120
127
121
snrt_global_barrier ();
128
122
0 commit comments