Skip to content

Commit 703de24

Browse files
Authored and committed by Viviane Potocnik
[sw] DNN: fix header include order in DNN header
[sw] FCL: fix struct declaration
1 parent 2b8c554 commit 703de24

File tree

4 files changed: +48 −56 lines changed

.github/workflows/lint.yml

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -77,7 +77,6 @@ jobs:
7777
# yamllint enable rule:line-length
7878
match_regex: true
7979
exclude_paths: |
80-
sw/dnn/flashattention_2/src/flashattention_2.h
8180
sw/snRuntime/src/omp/interface.h
8281
sw/math/arch/generic/*
8382
sw/math/arch/riscv64/bits/*
@@ -133,7 +132,7 @@ jobs:
133132
with:
134133
exclude: |
135134
./sw/saris
136-
./sw/dnn/flashattention_2/src/flashattention_2.h
135+
./sw/dnn/src/dnn.h
137136
clangFormatVersion: 10
138137

139138
######################

sw/dnn/flashattention_2/src/flashattention_2.h

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -48,7 +48,6 @@ typedef struct {
4848
void *gemm_implementation;
4949
} flashattention_2_layer_t;
5050

51-
#include "../transpose/src/transpose.h"
5251
#include "../flashattention_2/src/flashattention_2_fp16.h"
5352
#include "../flashattention_2/src/flashattention_2_fp32.h"
5453
#include "../flashattention_2/src/flashattention_2_fp8.h"

sw/dnn/fused_concat_linear/src/fused_concat_linear.h

Lines changed: 46 additions & 52 deletions
Original file line number | Diff line number | Diff line change
@@ -48,32 +48,29 @@ static inline int fused_concat_linear_baseline(fused_concat_linear_layer_t l) {
4848
uint32_t k = l.input_shape[1] * l.num_inputs;
4949
uint32_t n = l.output_shape[1];
5050

51-
gemm_args_t gemm_args;
52-
gemm_args_t *local_args = (gemm_args_t *)&gemm_args;
53-
54-
local_args->alpha = 1.0;
55-
local_args->prec = l.dtype;
56-
local_args->setup_ssr = 0;
57-
local_args->parallelize_m = 1;
58-
local_args->parallelize_k = 0;
59-
local_args->m_tiles = snrt_cluster_core_num();
60-
local_args->n_tiles = 1;
61-
local_args->k_tiles = 1;
62-
local_args->load_a = 0;
63-
local_args->load_b = 1;
64-
local_args->load_c = 1;
65-
local_args->transa = 0;
66-
local_args->transb = 0;
67-
local_args->M = m;
68-
local_args->N = n;
69-
local_args->K = k;
70-
local_args->a = l.concat_output;
71-
local_args->b = l.weights;
72-
local_args->beta = 0;
73-
local_args->c = l.linear_output;
74-
local_args->gemm_fp = l.gemm_implementation;
75-
76-
gemm(&local_args);
51+
gemm_args_t gemm_args = {.alpha = 1.0,
52+
.prec = l.dtype,
53+
.setup_ssr = 0,
54+
.parallelize_m = 1,
55+
.parallelize_k = 0,
56+
.m_tiles = snrt_cluster_num(),
57+
.n_tiles = 1,
58+
.k_tiles = 1,
59+
.load_a = 0,
60+
.load_b = 1,
61+
.load_c = 1,
62+
.transa = 0,
63+
.transb = 0,
64+
.M = m,
65+
.N = n,
66+
.K = k,
67+
.a = l.concat_output,
68+
.b = l.weights,
69+
.beta = 0,
70+
.c = l.linear_output,
71+
.gemm_fp = l.gemm_implementation};
72+
73+
gemm(&gemm_args);
7774

7875
snrt_global_barrier();
7976

@@ -97,32 +94,29 @@ static inline int fused_concat_linear_optimized(fused_concat_linear_layer_t l) {
9794
}
9895
snrt_cluster_hw_barrier();
9996

100-
gemm_args_t gemm_args;
101-
gemm_args_t *local_args = (gemm_args_t *)&gemm_args;
102-
103-
local_args->alpha = 1.0;
104-
local_args->prec = l.dtype;
105-
local_args->setup_ssr = 0;
106-
local_args->parallelize_m = 0;
107-
local_args->parallelize_k = 1;
108-
local_args->m_tiles = 1;
109-
local_args->n_tiles = 1;
110-
local_args->k_tiles = l.num_inputs;
111-
local_args->load_a = 0;
112-
local_args->load_b = 1;
113-
local_args->load_c = 1;
114-
local_args->transa = 0;
115-
local_args->transb = 0;
116-
local_args->M = m;
117-
local_args->N = n;
118-
local_args->K = concat_k;
119-
local_args->a = a;
120-
local_args->b = l.weights;
121-
local_args->beta = 0;
122-
local_args->c = l.linear_output;
123-
local_args->gemm_fp = l.gemm_implementation;
124-
125-
gemm(&local_args);
97+
gemm_args_t gemm_args = {.alpha = 1.0,
98+
.prec = l.dtype,
99+
.setup_ssr = 0,
100+
.parallelize_m = 0,
101+
.parallelize_k = 1,
102+
.m_tiles = 1,
103+
.n_tiles = 1,
104+
.k_tiles = l.num_inputs,
105+
.load_a = 0,
106+
.load_b = 1,
107+
.load_c = 1,
108+
.transa = 0,
109+
.transb = 0,
110+
.M = m,
111+
.N = n,
112+
.K = concat_k,
113+
.a = a,
114+
.b = l.weights,
115+
.beta = 0,
116+
.c = l.linear_output,
117+
.gemm_fp = l.gemm_implementation};
118+
119+
gemm(&gemm_args);
126120

127121
snrt_global_barrier();
128122

sw/dnn/src/dnn.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -201,10 +201,10 @@ typedef struct network_single_cluster_t_ {
201201
#include "../batchnorm/src/batchnorm.h"
202202
#include "../concat/src/concat.h"
203203
#include "../conv2d/src/conv2d.h"
204+
#include "../transpose/src/transpose.h"
204205
#include "../flashattention_2/src/flashattention_2.h"
205206
#include "../fused_concat_linear/src/fused_concat_linear.h"
206207
#include "../gelu/src/gelu.h"
207208
#include "../layernorm/src/layernorm.h"
208209
#include "../maxpool/src/maxpool.h"
209210
#include "../softmax/src/softmax.h"
210-
#include "../transpose/src/transpose.h"

0 commit comments

Comments (0)