 # Viviane Potocnik <vivianep@iis.ee.ethz.ch>

 import numpy as np
-import argparse
-import pathlib
-import json5
 import sys
 import os
+import re
 import pyflexfloat as ff

 sys.path.append(os.path.join(os.path.dirname(__file__), "../../../../util/sim/"))
 import data_utils  # noqa: E402
-from data_utils import emit_license, format_scalar_definition, \
+from data_utils import DataGen, format_array_declaration, format_struct_definition, \
     format_array_definition, format_ifdef_wrapper  # noqa: E402


 np.random.seed(42)

+
 class GemmDataGen(DataGen):

     # AXI splits bursts crossing 4KB address boundaries. To minimize
@@ -41,51 +40,59 @@ def exact_golden_model(self, alpha, a, b, beta, c):
                     result[m][n] += a[m][k] * b[k][n]
         return result

-    def validate_config(self, prec, implementation, parallelize_m, parallelize_k, m_tiles, n_tiles, k_tiles, ta,
-                        tb, M, N, K, beta, **kwargs):
+    def infer_implementation(self, gemm_fp):
+        # gemm_fp: "gemm_fp64_opt"
+        # create a regex with fp_<type>_<implementation>
+        prec, impl = re.search(r'gemm_fp(\d+)_(\w+)', gemm_fp).group(1, 2)
+        return (int(prec) / 8), impl
+
+    def validate_config(self, gemm_fp, parallelize_m,
+                        parallelize_k, m_tiles, n_tiles, k_tiles, transa,
+                        transb, M, N, K, beta, **kwargs):
         frac_m = M / m_tiles
         frac_n = N / n_tiles

+        dtype, impl = self.infer_implementation(gemm_fp)
+
         assert (M % m_tiles) == 0, 'M is not an integer multiple of tile size'
         assert (N % n_tiles) == 0, 'N is not an integer multiple of tile size'
         assert (K % k_tiles) == 0, 'K is not an integer multiple of tile size'
         assert (frac_m % 8) == 0, 'frac_m is not an integer multiple of the number of cores per' \
-            ' cluster'
+                                  ' cluster'
         assert not (parallelize_m and parallelize_k), 'Cannot parallelize K and M simultaneously'
-        assert not ta, 'SIMD kernels don\'t support transposed A matrix'
-        assert (prec == "FP64") or (implementation == 'BASELINE') or (implementation == 'NAIVE') \
-            or tb, 'Optimized SIMD kernels only support transposed B matrix'
-        assert not tb or n_tiles == 1, 'Tiling in the N dimension supported only if B is' \
-            ' not transposed'
-        assert not tb or k_tiles == 1, 'Tiling in the K dimension supported only if B is' \
-            ' not transposed'
-        assert (implementation == 'BASELINE') or (implementation == 'NAIVE') or frac_n >= 8, \
-            'N dimension of tile size must be greater or equal to the unrolling factor (8) ' \
-            'when using optimized kernels'
+        assert not transa, 'SIMD kernels don\'t support transposed A matrix'
+        assert (dtype == 8) or (impl == 'baseline') or (impl == 'naive') \
+            or transb, 'Optimized SIMD kernels only support transposed B matrix'
+        assert not transb or n_tiles == 1, 'Tiling in the N dimension not supported' \
+            ' if B is transposed'
+        assert not transb or k_tiles == 1, 'Tiling in the K dimension not supported' \
+            ' if B is transposed'
+        assert (impl == 'baseline') or (impl == 'naive') or frac_n >= 8, \
+            'N dimension of tile size must be greater or equal to the unrolling factor (8) ' \
+            'when using optimized kernels'
         assert beta == 0 or beta == 1, 'Only values of 0 or 1 supported for beta'
-        assert not (prec == "FP64" and implementation == "BASELINE"), 'No baseline implemented' \
-            ' for FP64 (switch to NAIVE)'
-        assert not (((prec == "FP64") or (prec == "FP32")) and implementation == "OPT_EX"), \
+        assert not (dtype == 8 and impl == "baseline"), 'No baseline implemented' \
+            ' for FP64 (switch to NAIVE)'
+        assert not (((dtype == 8) or (dtype == 4)) and impl == "opt_ex"), \
             'Expanding GEMM kernels' \
             ' not supported for FP64 and FP32'
-        assert not (((prec == "FP16") or (prec == "FP8")) and implementation == "NAIVE"), \
+        assert not (((dtype == 2) or (dtype == 1)) and impl == "naive"), \
             'FP16 and FP8 not supported' \
             ' in naive implementation'
-        assert not (prec == "FP8" and implementation == "OPT"), 'FP8 not supported in' \
-            ' optimized implementation' \
-            ' (switch to OPT_EX)'
-
+        assert not (dtype == 1 and impl == "opt"), 'FP8 not supported in' \
+            ' optimized implementation' \
+            ' (switch to OPT_EX)'

     def emit_header(self, **kwargs):
         header = [super().emit_header()]

         # Validate parameters
         self.validate_config(**kwargs)

-        # Generate random input matrices
-        prec = kwargs['prec']
         M, N, K = kwargs['M'], kwargs['N'], kwargs['K']

+        prec, _ = self.infer_implementation(kwargs['gemm_fp'])
+
         ff_desc = data_utils.ff_desc_from_precision_t(prec)
         ctype = data_utils.ctype_from_precision_t(prec)

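For intuition, a minimal standalone sketch of what the regex in infer_implementation yields for kernel names of the form gemm_fp<width>_<impl>; the first name below comes from the comment in the patch, the second is only an assumed example:

import re

def infer_implementation(gemm_fp):
    # Split e.g. "gemm_fp64_opt" into a byte width and an implementation tag.
    prec, impl = re.search(r'gemm_fp(\d+)_(\w+)', gemm_fp).group(1, 2)
    return (int(prec) / 8), impl

print(infer_implementation("gemm_fp64_opt"))     # (8.0, 'opt')
print(infer_implementation("gemm_fp16_opt_ex"))  # (2.0, 'opt_ex'), assumed kernel name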
@@ -95,28 +102,34 @@ def emit_header(self, **kwargs):
         result = self.exact_golden_model(1, a, b, kwargs['beta'], c)

         # Store matrices in transposed form if requested
-        a = a.T if kwargs['ta'] else a
-        b = b.T if kwargs['tb'] else b
-
-        header += [format_scalar_definition('uint32_t', 'M', M)]
-        header += [format_scalar_definition('uint32_t', 'N', N)]
-        header += [format_scalar_definition('uint32_t', 'K', K)]
-        header += [format_scalar_definition('uint32_t', 'TA', int(kwargs['ta']))]
-        header += [format_scalar_definition('uint32_t', 'TB', int(kwargs['tb']))]
-        header += [format_scalar_definition('uint32_t', 'BETA', kwargs['beta'])]
-        header += [format_scalar_definition('uint32_t', 'dtype_size', prec)]
-        header += [format_scalar_definition('uint32_t', 'expand', int(kwargs['expand']))]
-        header += [format_scalar_definition('uint32_t', 'm_tiles', kwargs['m_tiles'])]
-        header += [format_scalar_definition('uint32_t', 'n_tiles', kwargs['n_tiles'])]
-        header += [format_scalar_definition('uint32_t', 'k_tiles', kwargs['k_tiles'])]
-        header += [format_scalar_definition('uint32_t', 'parallelize_m', kwargs['parallelize_m'])]
-        header += [format_scalar_definition('uint32_t', 'parallelize_k', kwargs['parallelize_k'])]
-        header += [format_scalar_definition('implementation_t', 'implementation', kwargs['implementation'])]
-        header += [format_array_definition(ctype, 'a', a.flatten(), alignment=self.BURST_ALIGNMENT,
+        a = a.T if kwargs['transa'] else a
+        b = b.T if kwargs['transb'] else b
+
+        a_uid = 'a'
+        b_uid = 'b'
+        c_uid = 'c'
+
+        cfg = {
+            'prec': prec,
+            **kwargs,
+            'a': a_uid,
+            'b': b_uid,
+            'c': c_uid,
+        }
+
+        a = a.flatten()
+        b = b.flatten()
+        c = c.flatten()
+
+        header += [format_array_declaration(ctype, a_uid, a.shape)]
+        header += [format_array_declaration(ctype, b_uid, b.shape)]
+        header += [format_array_declaration(ctype, c_uid, c.shape)]
+        header += [format_struct_definition('gemm_args_t', 'args', cfg)]
+        header += [format_array_definition(ctype, a_uid, a,
                                            section=kwargs['section'])]
-        header += [format_array_definition(ctype, 'b', b.flatten(), alignment=self.BURST_ALIGNMENT,
+        header += [format_array_definition(ctype, b_uid, b,
                                            section=kwargs['section'])]
-        header += [format_array_definition(ctype, 'c', c.flatten(), alignment=self.BURST_ALIGNMENT,
+        header += [format_array_definition(ctype, c_uid, c,
                                            section=kwargs['section'])]
         result_def = format_array_definition(ctype, 'result', result.flatten())
         header += [format_ifdef_wrapper('BIST', result_def)]
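For reference, the triple-loop exact_golden_model named in the hunk context implements the usual GEMM update, so the emitted result array can be cross-checked against a NumPy sketch; this assumes the standard alpha/beta semantics (emit_header calls the model with alpha fixed to 1) and is for intuition only, since the generator itself works on pyflexfloat arrays:

import numpy as np

def golden_model_np(a, b, beta, c):
    # result[m][n] = beta * c[m][n] + sum_k a[m][k] * b[k][n], i.e. alpha = 1
    return np.matmul(a, b) + beta * c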
@@ -125,5 +138,5 @@ def emit_header(self, **kwargs):
         return header


-if __name__ == '__main__':
-    main()
+if __name__ == "__main__":
+    sys.exit(GemmDataGen().main())
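To illustrate the new interface, a hypothetical parameter set of the kind validate_config and emit_header now consume; the key names come from the code above, and all values are purely illustrative:

params = {
    'M': 192, 'N': 16, 'K': 16,
    'beta': 0,
    'transa': False, 'transb': True,
    'm_tiles': 2, 'n_tiles': 1, 'k_tiles': 1,
    'parallelize_m': 0, 'parallelize_k': 0,
    'gemm_fp': 'gemm_fp64_opt',
    'section': '.data',  # assumed section name
}

# GemmDataGen().validate_config(**params) passes these checks; an invalid
# combination (e.g. transb=True with n_tiles=2) raises an AssertionError.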