
Commit 4682694

gemm: Pass implementation function as parameter
Parent: 749673b

26 files changed: +522 −440 lines

sw/blas/gemm/data/params.json

Lines changed: 18 additions & 16 deletions
@@ -1,22 +1,24 @@
-// Copyright 2023 ETH Zurich and University of Bologna.
-// Solderpad Hardware License, Version 0.51, see LICENSE for details.
-// SPDX-License-Identifier: SHL-0.51
-
-// Parameters for a GEMM
+// Copyright 2024 ETH Zurich and University of Bologna.
+// Licensed under the Apache License, Version 2.0, see LICENSE for details.
+// SPDX-License-Identifier: Apache-2.0
 
 {
-    M: 192,
+    prec: "FP32",
+    setup_ssr: 1,
+    parallelize_m: 0,
+    parallelize_k: 0,
+    m_tiles: 2, // number of tiles in M dimension
+    n_tiles: 1, // number of tiles in N dimension
+    k_tiles: 1, // number of tiles in K dimension
+    load_a: 1,
+    load_b: 1,
+    load_c: 1,
+    transa: false,
+    transb: true, // must be true for SIMD
+    M: 16,
     N: 16,
     K: 16,
+    alpha: 1,
     beta: 0,
-    ta: false,
-    tb: true, // must be true for SIMD
-    prec: "FP64",
-    expand: 0,
-    m_tiles: 2, // number of tiles in M dimension
-    k_tiles: 1, // number of tiles in K dimension
-    n_tiles: 1, // number of tiles in N dimension
-    parallelize_k: 0,
-    parallelize_m: 0,
-    implementation: "NAIVE"
+    gemm_fp: "gemm_fp32_opt"
 }
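Note the schema change: the old `implementation` enum and `expand` flag are gone, and the new `gemm_fp` field names the kernel function directly, encoding both precision and implementation variant. A minimal sketch of the naming convention this implies — `kernel_name` is a hypothetical helper, not part of the commit, but the format matches the regex added in datagen.py below:

```python
# Hypothetical helper illustrating the gemm_fp<bits>_<impl> naming scheme
# that datagen.py's infer_implementation() parses.
def kernel_name(bits: int, impl: str) -> str:
    return f'gemm_fp{bits}_{impl}'

assert kernel_name(32, 'opt') == 'gemm_fp32_opt'      # the new value in params.json
assert kernel_name(64, 'naive') == 'gemm_fp64_naive'  # what the old prec: "FP64" +
                                                      # implementation: "NAIVE" would become
```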

sw/blas/gemm/scripts/datagen.py

Lines changed: 62 additions & 49 deletions
@@ -8,21 +8,20 @@
 # Viviane Potocnik <vivianep@iis.ee.ethz.ch>
 
 import numpy as np
-import argparse
-import pathlib
-import json5
 import sys
 import os
+import re
 import pyflexfloat as ff
 
 sys.path.append(os.path.join(os.path.dirname(__file__), "../../../../util/sim/"))
 import data_utils  # noqa: E402
-from data_utils import emit_license, format_scalar_definition, \
+from data_utils import DataGen, format_array_declaration, format_struct_definition, \
     format_array_definition, format_ifdef_wrapper  # noqa: E402
 
 
 np.random.seed(42)
 
+
 class GemmDataGen(DataGen):
 
     # AXI splits bursts crossing 4KB address boundaries. To minimize
@@ -41,51 +40,59 @@ def exact_golden_model(self, alpha, a, b, beta, c):
                     result[m][n] += a[m][k] * b[k][n]
         return result
 
-    def validate_config(self, prec, implementation, parallelize_m, parallelize_k, m_tiles, n_tiles, k_tiles, ta,
-                        tb, M, N, K, beta, **kwargs):
+    def infer_implementation(self, gemm_fp):
+        # gemm_fp: "gemm_fp64_opt"
+        # create a regex with fp_<type>_<implementation>
+        prec, impl = re.search(r'gemm_fp(\d+)_(\w+)', gemm_fp).group(1, 2)
+        return (int(prec) / 8), impl
+
+    def validate_config(self, gemm_fp, parallelize_m,
+                        parallelize_k, m_tiles, n_tiles, k_tiles, transa,
+                        transb, M, N, K, beta, **kwargs):
         frac_m = M / m_tiles
         frac_n = N / n_tiles
 
+        dtype, impl = self.infer_implementation(gemm_fp)
+
         assert (M % m_tiles) == 0, 'M is not an integer multiple of tile size'
         assert (N % n_tiles) == 0, 'N is not an integer multiple of tile size'
         assert (K % k_tiles) == 0, 'K is not an integer multiple of tile size'
         assert (frac_m % 8) == 0, 'frac_m is not an integer multiple of the number of cores per' \
-            ' cluster'
+               ' cluster'
         assert not (parallelize_m and parallelize_k), 'Cannot parallelize K and M simultaneously'
-        assert not ta, 'SIMD kernels don\'t support transposed A matrix'
-        assert (prec == "FP64") or (implementation == 'BASELINE') or (implementation == 'NAIVE') \
-            or tb, 'Optimized SIMD kernels only support transposed B matrix'
-        assert not tb or n_tiles == 1, 'Tiling in the N dimension supported only if B is' \
-            ' not transposed'
-        assert not tb or k_tiles == 1, 'Tiling in the K dimension supported only if B is' \
-            ' not transposed'
-        assert (implementation == 'BASELINE') or (implementation == 'NAIVE') or frac_n >= 8, \
-            'N dimension of tile size must be greater or equal to the unrolling factor (8) ' \
-            'when using optimized kernels'
+        assert not transa, 'SIMD kernels don\'t support transposed A matrix'
+        assert (dtype == 8) or (impl == 'baseline') or (impl == 'naive') \
+            or transb, 'Optimized SIMD kernels only support transposed B matrix'
+        assert not transb or n_tiles == 1, 'Tiling in the N dimension not supported' \
+            ' if B is transposed'
+        assert not transb or k_tiles == 1, 'Tiling in the K dimension not supported' \
+            ' if B is transposed'
+        assert (impl == 'baseline') or (impl == 'naive') or frac_n >= 8, \
+            'N dimension of tile size must be greater or equal to the unrolling factor (8) ' \
+            'when using optimized kernels'
         assert beta == 0 or beta == 1, 'Only values of 0 or 1 supported for beta'
-        assert not (prec == "FP64" and implementation == "BASELINE"), 'No baseline implemented' \
-            ' for FP64 (switch to NAIVE)'
-        assert not (((prec == "FP64") or (prec == "FP32")) and implementation == "OPT_EX"), \
+        assert not (dtype == 8 and impl == "baseline"), 'No baseline implemented' \
+            ' for FP64 (switch to NAIVE)'
+        assert not (((dtype == 8) or (dtype == 4)) and impl == "OPT_EX"), \
             'Expanding GEMM kernels' \
             ' not supported for FP64 and FP32'
-        assert not (((prec == "FP16") or (prec == "FP8")) and implementation == "NAIVE"), \
+        assert not (((dtype == 2) or (dtype == 1)) and impl == "NAIVE"), \
             'FP16 and FP8 not supported' \
             ' in naive implementation'
-        assert not (prec == "FP8" and implementation == "OPT"), 'FP8 not supported in' \
-            ' optimized implementation' \
-            ' (switch to OPT_EX)'
-
+        assert not (dtype == 1 and impl == "OPT"), 'FP8 not supported in' \
+            ' optimized implementation' \
+            ' (switch to OPT_EX)'
 
     def emit_header(self, **kwargs):
         header = [super().emit_header()]
 
         # Validate parameters
         self.validate_config(**kwargs)
 
-        # Generate random input matrices
-        prec = kwargs['prec']
         M, N, K = kwargs['M'], kwargs['N'], kwargs['K']
 
+        prec, _ = self.infer_implementation(kwargs['gemm_fp'])
+
         ff_desc = data_utils.ff_desc_from_precision_t(prec)
         ctype = data_utils.ctype_from_precision_t(prec)
 
@@ -95,28 +102,34 @@ def emit_header(self, **kwargs):
         result = self.exact_golden_model(1, a, b, kwargs['beta'], c)
 
         # Store matrices in transposed form if requested
-        a = a.T if kwargs['ta'] else a
-        b = b.T if kwargs['tb'] else b
-
-        header += [format_scalar_definition('uint32_t', 'M', M)]
-        header += [format_scalar_definition('uint32_t', 'N', N)]
-        header += [format_scalar_definition('uint32_t', 'K', K)]
-        header += [format_scalar_definition('uint32_t', 'TA', int(kwargs['ta']))]
-        header += [format_scalar_definition('uint32_t', 'TB', int(kwargs['tb']))]
-        header += [format_scalar_definition('uint32_t', 'BETA', kwargs['beta'])]
-        header += [format_scalar_definition('uint32_t', 'dtype_size', prec)]
-        header += [format_scalar_definition('uint32_t', 'expand', int(kwargs['expand']))]
-        header += [format_scalar_definition('uint32_t', 'm_tiles', kwargs['m_tiles'])]
-        header += [format_scalar_definition('uint32_t', 'n_tiles', kwargs['n_tiles'])]
-        header += [format_scalar_definition('uint32_t', 'k_tiles', kwargs['k_tiles'])]
-        header += [format_scalar_definition('uint32_t', 'parallelize_m', kwargs['parallelize_m'])]
-        header += [format_scalar_definition('uint32_t', 'parallelize_k', kwargs['parallelize_k'])]
-        header += [format_scalar_definition('implementation_t', 'implementation', kwargs['implementation'])]
-        header += [format_array_definition(ctype, 'a', a.flatten(), alignment=self.BURST_ALIGNMENT,
+        a = a.T if kwargs['transa'] else a
+        b = b.T if kwargs['transb'] else b
+
+        a_uid = 'a'
+        b_uid = 'b'
+        c_uid = 'c'
+
+        cfg = {
+            'prec': prec,
+            **kwargs,
+            'a': a_uid,
+            'b': b_uid,
+            'c': c_uid,
+        }
+
+        a = a.flatten()
+        b = b.flatten()
+        c = c.flatten()
+
+        header += [format_array_declaration(ctype, a_uid, a.shape)]
+        header += [format_array_declaration(ctype, b_uid, b.shape)]
+        header += [format_array_declaration(ctype, c_uid, c.shape)]
+        header += [format_struct_definition('gemm_args_t', 'args', cfg)]
+        header += [format_array_definition(ctype, a_uid, a,
                                            section=kwargs['section'])]
-        header += [format_array_definition(ctype, 'b', b.flatten(), alignment=self.BURST_ALIGNMENT,
+        header += [format_array_definition(ctype, b_uid, b,
                                            section=kwargs['section'])]
-        header += [format_array_definition(ctype, 'c', c.flatten(), alignment=self.BURST_ALIGNMENT,
+        header += [format_array_definition(ctype, c_uid, c,
                                            section=kwargs['section'])]
         result_def = format_array_definition(ctype, 'result', result.flatten())
         header += [format_ifdef_wrapper('BIST', result_def)]
@@ -125,5 +138,5 @@ def emit_header(self, **kwargs):
         return header
 
 
-if __name__ == '__main__':
-    main()
+if __name__ == "__main__":
+    sys.exit(GemmDataGen().main())
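For reference, the new `infer_implementation` helper recovers the element byte width and the implementation variant from the kernel name. A self-contained sketch of its behavior, mirroring the code above so it can be run outside the repo:

```python
import re

def infer_implementation(gemm_fp):
    # 'gemm_fp64_opt' -> bits '64', variant 'opt'; bits are converted to bytes
    prec, impl = re.search(r'gemm_fp(\d+)_(\w+)', gemm_fp).group(1, 2)
    return (int(prec) / 8), impl

assert infer_implementation('gemm_fp64_opt') == (8.0, 'opt')
assert infer_implementation('gemm_fp32_naive') == (4.0, 'naive')
# True division yields a float (8.0, not 8); comparisons like dtype == 8 in
# validate_config() still hold, since 8.0 == 8 in Python.
```

One observable quirk in the diff: the regex yields lowercase suffixes from names like "gemm_fp32_opt", so the lowercase comparisons (`impl == 'baseline'`, `impl == 'naive'`) can match, while the uppercase ones (`"NAIVE"`, `"OPT"`, `"OPT_EX"`) carried over from the old enum values never will.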

sw/blas/gemm/scripts/verify.py

Lines changed: 43 additions & 15 deletions
@@ -8,7 +8,7 @@
 import numpy as np
 import sys
 from pathlib import Path
-from datagen import golden_model
+from datagen import GemmDataGen
 
 sys.path.append(str(Path(__file__).parent / '../../../../util/sim/'))
 from verif_utils import Verifier  # noqa: E402
@@ -19,27 +19,54 @@ class GemmVerifier(Verifier):
 
     OUTPUT_UIDS = ['c']
     ERR_THRESHOLD = {
-        0: {8: 1e-6, 4: 1e-6, 2: 1e-2, 1: 1e-4},
-        1: {8: 0, 4: 0, 2: 0, 1: 0}
+        1: 1e-4,
+        2: 1e-2,
+        4: 1e-6,
+        8: 1e-6
     }
 
     def __init__(self):
         super().__init__()
-        self.prec = self.get_input_from_symbol('dtype_size', 'uint32_t')[0]
-        self.baseline = self.get_input_from_symbol('baseline', 'uint32_t')[0]
+        self.func_args = {
+            'alpha': 'd',
+            'prec': 'I',
+            'setup_ssr': 'I',
+            'parallelize_m': 'I',
+            'parallelize_k': 'I',
+            'm_tiles': 'I',
+            'n_tiles': 'I',
+            'k_tiles': 'I',
+            'load_a': 'I',
+            'load_b': 'I',
+            'load_c': 'I',
+            'transa': 'I',
+            'transb': 'I',
+            'M': 'I',
+            'N': 'I',
+            'K': 'I',
+            'a': 'I',
+            'b': 'I',
+            'beta': 'I',
+            'c': 'I',
+            'gemm_fp': 'I'
+        }
+        self.func_args = self.get_input_from_symbol('args', self.func_args)
 
     def get_actual_results(self):
-        return self.get_output_from_symbol(self.OUTPUT_UIDS[0], ctype_from_precision_t(self.prec))
+        prec = self.func_args['prec']
+        return self.get_output_from_symbol(self.OUTPUT_UIDS[0], ctype_from_precision_t(prec))
 
     def get_expected_results(self):
-        a = self.get_input_from_symbol('a', ctype_from_precision_t(self.prec))
-        b = self.get_input_from_symbol('b', ctype_from_precision_t(self.prec))
-        c = self.get_input_from_symbol('c', ctype_from_precision_t(self.prec))
-        beta = self.get_input_from_symbol('BETA', 'uint32_t')[0]
-        m = self.get_input_from_symbol('M', 'uint32_t')[0]
-        n = self.get_input_from_symbol('N', 'uint32_t')[0]
-        k = self.get_input_from_symbol('K', 'uint32_t')[0]
-        tb = self.get_input_from_symbol('TB', 'uint32_t')[0]
+        prec = self.func_args['prec']
+        a = self.get_input_from_symbol('a', ctype_from_precision_t(prec))
+        b = self.get_input_from_symbol('b', ctype_from_precision_t(prec))
+        c = self.get_input_from_symbol('c', ctype_from_precision_t(prec))
+        beta = self.func_args['beta']
+        m = self.func_args['M']
+        n = self.func_args['N']
+        k = self.func_args['K']
+        tb = self.func_args['transb']
+
         a = np.reshape(a, (m, k))
         if tb:
             b = np.reshape(b, (n, k))
@@ -50,7 +77,8 @@ def get_expected_results(self):
         return GemmDataGen().exact_golden_model(1, a, b, beta, c).flatten()
 
     def check_results(self, *args):
-        return super().check_results(*args, rtol=self.ERR_THRESHOLD[self.baseline][self.prec])
+        prec = self.func_args['prec']
+        return super().check_results(*args, rtol=self.ERR_THRESHOLD[prec])
 
 
 if __name__ == "__main__":