Commit 6a2f112

sw: Validate GEMM, Layernorm and FA-2 tile footprints in TCDM
1 parent 6fcea3a

3 files changed: +24 -7 lines changed

sw/blas/gemm/scripts/datagen.py

Lines changed: 12 additions & 0 deletions
@@ -51,9 +51,21 @@ def validate_config(self, gemm_fp, parallelize_m,
                         transb, M, N, K, beta, **kwargs):
         frac_m = M / m_tiles
         frac_n = N / n_tiles
+        frac_k = K / k_tiles

         dtype, impl = self.infer_implementation(gemm_fp)

+        # Calculate total TCDM occupation
+        # Note: doesn't account for double buffering
+        prec = data_utils.size_from_precision_t(dtype)
+        a_size = frac_m * frac_k * prec
+        b_size = frac_k * frac_n * prec
+        c_size = frac_m * frac_n * prec
+        total_size = a_size
+        total_size += b_size
+        total_size += c_size
+        data_utils.validate_tcdm_footprint(total_size)
+
         assert (M % m_tiles) == 0, 'M is not an integer multiple of tile size'
         assert (N % n_tiles) == 0, 'N is not an integer multiple of tile size'
         assert (K % k_tiles) == 0, 'K is not an integer multiple of tile size'
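
For reference, the check added above just sums the per-tile footprints of the A, B and C operands. The following standalone sketch restates that arithmetic outside the generator class; the precision-to-byte mapping and the name `gemm_tile_footprint` are illustrative assumptions, not repository code:

```python
# Illustrative sketch of the GEMM tile footprint summed above.
# PREC_BYTES and gemm_tile_footprint are assumptions, not repository code.
PREC_BYTES = {'FP64': 8, 'FP32': 4, 'FP16': 2, 'FP8': 1}

def gemm_tile_footprint(M, N, K, m_tiles, n_tiles, k_tiles, dtype='FP64'):
    """Bytes of TCDM occupied by one A, B and C tile (no double buffering)."""
    frac_m, frac_n, frac_k = M / m_tiles, N / n_tiles, K / k_tiles
    prec = PREC_BYTES[dtype]
    a_size = frac_m * frac_k * prec
    b_size = frac_k * frac_n * prec
    c_size = frac_m * frac_n * prec
    return a_size + b_size + c_size

# A 192x192x192 FP64 GEMM split 2x2x2 needs 3 * 96 * 96 * 8 B = 216 KiB per
# tile set, which would not fit in a 112 KiB TCDM heap.
print(gemm_tile_footprint(192, 192, 192, 2, 2, 2))
```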

sw/dnn/flashattention_2/scripts/datagen.py

Lines changed: 1 addition & 7 deletions
@@ -14,7 +14,6 @@
 import os
 import torch
 import pyflexfloat as ff
-import humanize

 sys.path.append(os.path.join(os.path.dirname(__file__), "../../../../util/sim/"))
 sys.path.append(os.path.join(os.path.dirname(__file__), "../../../blas/"))
@@ -31,8 +30,6 @@
 # the occurrence of these splits the data should be aligned to 4KB
 BURST_ALIGNMENT = 4096

-# Maximum available size in TCDM (in bytes)
-L1_HEAP_SIZE = 112 * 1024

 def torch_golden_model(Q, K, V):
     return torch.nn.functional.scaled_dot_product_attention(Q, K, V)
@@ -169,10 +166,7 @@ def validate_config(L, S, d, B_r, B_c, dtype, baseline, gemm_impl):
     total_size += o_fa_size
     total_size += m_i_size * 2  # m_i and m_i_prev
     total_size += l_i_size
-    assert total_size < L1_HEAP_SIZE, \
-        f'Total heap space required {humanize.naturalsize(total_size, binary=True)} exceeds ' \
-        f'limit of {humanize.naturalsize(L1_HEAP_SIZE, binary=True)}'
-    print(f'Total heap space required {humanize.naturalsize(total_size, binary=True)}')
+    data_utils.validate_tcdm_footprint(total_size)

     # Q*K^t
     gemm.datagen.GemmDataGen().validate_config(
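
The inline check against `L1_HEAP_SIZE` removed here is what the shared `data_utils.validate_tcdm_footprint` call now replaces across all three generators. A minimal sketch of such a helper, assuming the 112 KiB limit and the humanize-formatted message shown in the removed lines; the actual implementation in `data_utils` may differ:

```python
import humanize

# Assumed limit, taken from the L1_HEAP_SIZE constant removed above (112 KiB).
L1_HEAP_SIZE = 112 * 1024

def validate_tcdm_footprint(total_size, heap_size=L1_HEAP_SIZE):
    """Fail data generation if a tile footprint does not fit in the TCDM heap."""
    assert total_size < heap_size, \
        f'Total heap space required {humanize.naturalsize(total_size, binary=True)} ' \
        f'exceeds limit of {humanize.naturalsize(heap_size, binary=True)}'
    print(f'Total heap space required {humanize.naturalsize(total_size, binary=True)}')
```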

sw/dnn/layernorm/scripts/datagen.py

Lines changed: 11 additions & 0 deletions
@@ -49,6 +49,17 @@ def golden_model_torch(ifmap, eps, shape):


 def validate_config(**kwargs):
+    # Aliases
+    batch_size = kwargs['input_dim']['batch_size']
+    seq_len = kwargs['input_dim']['seq_len']
+    embeddings = kwargs['input_dim']['embeddings']
+
+    # Calculate total TCDM occupation
+    prec = data_utils.size_from_precision_t(kwargs['prec'])
+    tiled_seq_len = seq_len / kwargs['n_tiles']
+    total_size = batch_size * tiled_seq_len * embeddings * prec
+    data_utils.validate_tcdm_footprint(total_size)
+
     assert kwargs['input_dim']['seq_len'] % kwargs['n_tiles'] == 0, 'Input dimension is not' \
                                                                     ' an integer multiple of' \
                                                                     ' tile size'
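
For LayerNorm the footprint is just the tiled input activation: batch_size * (seq_len / n_tiles) * embeddings * prec bytes. A quick numeric sanity check with hypothetical dimensions (the 112 KiB limit is assumed from the constant removed in the FA-2 file):

```python
# Hypothetical LayerNorm config: batch 1, seq_len 512 split into 16 tiles,
# 768 embeddings, FP32 (4 bytes per element).
batch_size, seq_len, n_tiles, embeddings, prec = 1, 512, 16, 768, 4

tiled_seq_len = seq_len / n_tiles                            # 32 rows per tile
total_size = batch_size * tiled_seq_len * embeddings * prec  # 98304 B = 96 KiB
assert total_size < 112 * 1024                               # fits in a 112 KiB TCDM heap
print(total_size)
```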
