Skip to content

Commit 1df7825

Browse files
collucaViviane Potocnik
authored and
Viviane Potocnik
committed
flashattention_2: Add low-precision implementations
flashattention_2: Correct bug in datagen to support more sizes flashattention_2: Add mcycle calls for benchmarking flashattention_2: Add TCASAI experiments flashattention_2: Correct bug reallocating V^t on every iteration
1 parent 2357ab3 commit 1df7825

File tree

9 files changed

+354
-34
lines changed

9 files changed

+354
-34
lines changed

sw/dnn/flashattention_2/scripts/datagen.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
format_struct_definition, format_array_definition, \
2424
format_array_declaration # noqa: E402
2525

26+
np.random.seed(42)
2627
np.random.seed(42)
2728
torch.manual_seed(42)
2829

@@ -92,6 +93,7 @@ def exact_golden_model(Q, K, V, B_r, B_c):
9293
def exact_flexfloat_golden_model(Q, K, V, B_r, B_c, desc):
9394
# Get layer dimensions
9495
N = Q.shape[0]
96+
d = Q.shape[1]
9597
# Calculate tiling parameters
9698
T_r = N // B_r
9799
T_c = N // B_c
@@ -111,15 +113,16 @@ def exact_flexfloat_golden_model(Q, K, V, B_r, B_c, desc):
111113
start_col = j * B_c
112114
end_col = start_col + B_c
113115
K_t_j = K_t[:, start_col:end_col]
114-
V_j = V[start_col:end_col, ]
116+
V_j = V[start_col:end_col,]
115117
# Compute O tile update
116118
S_ij = ff.array(np.zeros((B_r, B_c)), desc)
117119
S_ij = gemm.datagen.GemmDataGen().exact_golden_model(1, Q_i, K_t_j, 0, S_ij)
118120
m_i_prev = m_i
119121
m_i = np.maximum(m_i_prev, np.max(S_ij, 1, keepdims=True))
120122
shifted_exp = np.exp((m_i_prev - m_i).astype(np.float32))
121123
P_ij = np.exp((S_ij - m_i).astype(np.float32))
122-
PxV = gemm.datagen.GemmDataGen().exact_golden_model(1, P_ij, V_j, 0, S_ij)
124+
PxV = ff.array(np.zeros((B_r, d)), desc)
125+
PxV = gemm.datagen.GemmDataGen().exact_golden_model(1, P_ij, V_j, 0, PxV)
123126
row_sum = np.sum(P_ij.astype(np.float32), 1, keepdims=True)
124127
if j == 0:
125128
l_i = row_sum
@@ -144,6 +147,7 @@ def validate_config(N, d, B_r, B_c, dtype, baseline, gemm_impl):
144147
assert (N % B_c) == 0, 'N is not an integer multiple of B_c'
145148
assert (B_r % 8) == 0, 'B_r must be an integer multiple of the number of cores in a cluster'
146149
assert dtype != 'FP64', 'FP64 precision is not supported yet'
150+
assert dtype != 'FP64', 'FP64 precision is not supported yet'
147151

148152
# Q*K^t
149153
gemm.datagen.GemmDataGen().validate_config(
@@ -224,6 +228,8 @@ def emit_header(section, params):
224228
data_str += [format_array_definition(ctype, v_uid, V)]
225229
# result_def = format_array_definition(ctype, 'golden', output)
226230
# data_str += [format_ifdef_wrapper('BIST', result_def)]
231+
# result_def = format_array_definition(ctype, 'golden', output)
232+
# data_str += [format_ifdef_wrapper('BIST', result_def)]
227233
data_str = '\n\n'.join(data_str)
228234

229235
return data_str

sw/dnn/flashattention_2/src/flashattention_2_fp32.h

Lines changed: 27 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -61,14 +61,22 @@ static inline void flashattention_2_fp32(flashattention_2_layer_t layer) {
6161
tcdm_ptr += m_i_prev_size;
6262
float *l_i = tcdm_ptr;
6363
tcdm_ptr += l_i_size;
64+
65+
// allocate space for V^t when using optimized kernels
66+
float *V_t;
67+
if (!baseline) {
68+
V_t = tcdm_ptr;
69+
tcdm_ptr += B_c * d * sizeof(float);
70+
}
71+
6472
float shifted_exp;
6573
float row_sum;
6674

75+
snrt_mcycle();
76+
6777
// Iterate row blocks of Q
68-
uint32_t start_loop_outer = snrt_mcycle();
6978
for (int t_r = 0; t_r < T_r; t_r++) {
7079
// DMA copy Q row block to TCDM
71-
uint32_t start_dma = snrt_mcycle();
7280
if (snrt_is_dm_core()) {
7381
snrt_dma_load_2d_tile(Q_fa, // dst
7482
Q_l3, // src
@@ -81,10 +89,10 @@ static inline void flashattention_2_fp32(flashattention_2_layer_t layer) {
8189
);
8290
snrt_dma_wait_all();
8391
}
84-
uint32_t end_dma = snrt_mcycle();
85-
8692
snrt_cluster_hw_barrier();
8793

94+
snrt_mcycle();
95+
8896
// Initialize m_i, m_i_prev, l_i, row_sum
8997
uint32_t rows_per_core = B_r / num_cores;
9098
uint32_t start_row = rows_per_core * compute_id;
@@ -99,16 +107,12 @@ static inline void flashattention_2_fp32(flashattention_2_layer_t layer) {
99107

100108
snrt_cluster_hw_barrier();
101109

102-
snrt_cluster_hw_barrier();
110+
snrt_mcycle();
103111

104112
// Iterate column blocks of K (corresponding to row blocks of V)
105-
uint32_t start_loop_inner = snrt_mcycle();
106113
for (int t_c = 0; t_c < T_c; t_c++) {
107-
snrt_cluster_hw_barrier();
108-
109114
// DMA copy K column block (B_c, d) and V row block (B_c, d) to
110115
// TCDM. Both K and V are stored in (N, d) form in memory
111-
uint32_t start_dma = snrt_mcycle();
112116
if (!snrt_is_compute_core()) {
113117
snrt_dma_load_2d_tile(K_fa, // dst
114118
K_l3, // src
@@ -130,22 +134,22 @@ static inline void flashattention_2_fp32(flashattention_2_layer_t layer) {
130134
);
131135
snrt_dma_wait_all();
132136
}
133-
uint32_t end_dma = snrt_mcycle();
134-
135137
snrt_cluster_hw_barrier();
136138

139+
snrt_mcycle();
140+
137141
// Calculate O tile from Q, K and V tiles
138142
if (snrt_is_compute_core()) {
139143
// Matrix multiplication between row block of Q and transposed
140144
// column block of K to calculate a tile of S: S = Q * K^T.
141145
// The S tile is of form (B_r, B_c)
142-
uint32_t start_gemm = snrt_mcycle();
143146
sc_st_gemm(dtype, 1, 0, 1, B_r, B_c, d, 1, Q_fa, d, K_fa, d, 0,
144147
S_fa, B_c, gemm_implementation);
145-
uint32_t end_gemm = snrt_mcycle();
146148

147149
snrt_cluster_hw_barrier();
148150

151+
snrt_mcycle();
152+
149153
// Iterate over the rows of the S row block, distributing
150154
// the rows to the cores
151155
for (int row_idx = start_row; row_idx < end_row; row_idx++) {
@@ -188,7 +192,7 @@ static inline void flashattention_2_fp32(flashattention_2_layer_t layer) {
188192

189193
snrt_cluster_hw_barrier();
190194

191-
snrt_cluster_hw_barrier();
195+
snrt_mcycle();
192196

193197
// Calculate O tile (O_ij) of size (B_r, d).
194198
// The P tile is of size (B_r, B_c) and V of size (B_c, d)
@@ -207,10 +211,6 @@ static inline void flashattention_2_fp32(flashattention_2_layer_t layer) {
207211
// operation. We must transpose V in advance, so
208212
// we can compute P*(V^t)^t with the optimized GEMM.
209213

210-
// Allocate space for V^t
211-
float *V_t = tcdm_ptr;
212-
tcdm_ptr += B_c * d * sizeof(float);
213-
214214
// Compute V^t
215215
transpose_kernel(FP32, V_fa, V_t, B_c, d, baseline);
216216

@@ -225,19 +225,16 @@ static inline void flashattention_2_fp32(flashattention_2_layer_t layer) {
225225
sc_st_gemm(dtype, 0, 0, 1, B_r, d, B_c, 1, P_fa, B_c, V_t,
226226
B_c, beta, O_fa, d, gemm_implementation);
227227
}
228-
229-
uint32_t end_stats = snrt_mcycle();
230-
231-
snrt_cluster_hw_barrier();
232228
} else {
233229
snrt_cluster_hw_barrier();
234230
snrt_cluster_hw_barrier();
235-
snrt_cluster_hw_barrier();
236-
snrt_cluster_hw_barrier();
231+
snrt_mcycle();
232+
snrt_mcycle();
237233
}
238-
} // end of T_c loop
234+
snrt_cluster_hw_barrier();
239235

240-
snrt_cluster_hw_barrier();
236+
snrt_mcycle();
237+
} // end of T_c loop
241238

242239
// Rescaling for last t_c iteration
243240
// O_i = diag(l_i_Tc)^-1 * O_i
@@ -248,15 +245,12 @@ static inline void flashattention_2_fp32(flashattention_2_layer_t layer) {
248245
}
249246
}
250247
}
251-
252248
snrt_fpu_fence();
253-
254249
snrt_cluster_hw_barrier();
255250

256-
snrt_cluster_hw_barrier();
251+
snrt_mcycle();
257252

258253
// Write back O row block (B_r, d) to DRAM
259-
uint32_t start_dma_write_back = snrt_mcycle();
260254
if (snrt_is_dm_core()) {
261255
snrt_dma_store_2d_tile(O_l3, // dst
262256
O_fa, // src
@@ -269,10 +263,11 @@ static inline void flashattention_2_fp32(flashattention_2_layer_t layer) {
269263
);
270264
snrt_dma_wait_all();
271265
}
272-
uint32_t end_dma_write_back = snrt_mcycle();
266+
snrt_cluster_hw_barrier();
267+
268+
snrt_mcycle();
273269

274270
} // end of T_r loop
275-
uint32_t end_loop_outer = snrt_mcycle();
276271

277272
snrt_cluster_hw_barrier();
278273
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
// Copyright 2023 ETH Zurich and University of Bologna.
2+
// Licensed under the Apache License, Version 2.0, see LICENSE for details.
3+
// SPDX-License-Identifier: Apache-2.0
4+
5+
{
6+
N: 64,
7+
d: 128,
8+
B_r: 16,
9+
B_c: 64,
10+
dtype: "FP32",
11+
baseline: false
12+
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
// Copyright 2023 ETH Zurich and University of Bologna.
2+
// Licensed under the Apache License, Version 2.0, see LICENSE for details.
3+
// SPDX-License-Identifier: Apache-2.0
4+
5+
{
6+
N: 64,
7+
d: 256,
8+
B_r: 16,
9+
B_c: 64,
10+
dtype: "FP32",
11+
baseline: false
12+
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
// Copyright 2023 ETH Zurich and University of Bologna.
2+
// Licensed under the Apache License, Version 2.0, see LICENSE for details.
3+
// SPDX-License-Identifier: Apache-2.0
4+
5+
{
6+
N: 64,
7+
d: 64,
8+
B_r: 16,
9+
B_c: 64,
10+
dtype: "FP32",
11+
baseline: false
12+
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
// Copyright 2023 ETH Zurich and University of Bologna.
2+
// Licensed under the Apache License, Version 2.0, see LICENSE for details.
3+
// SPDX-License-Identifier: Apache-2.0
4+
5+
{
6+
N: 64,
7+
d: 80,
8+
B_r: 16,
9+
B_c: 64,
10+
dtype: "FP32",
11+
baseline: false
12+
}
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
#!/usr/bin/env python3
2+
# Copyright 2024 ETH Zurich and University of Bologna.
3+
# Licensed under the Apache License, Version 2.0, see LICENSE for details.
4+
# SPDX-License-Identifier: Apache-2.0
5+
#
6+
# Luca Colagrande <colluca@iis.ee.ethz.ch>
7+
8+
from pathlib import Path
9+
import sys
10+
import subprocess
11+
import functools
12+
import json
13+
14+
sys.path.append(str(Path(__file__).parent / '../../../../../util/'))
15+
16+
ROI_SPEC = Path.cwd() / 'roi.json.tpl'
17+
MODELS = {
18+
'vit-base': {'N': 192, 'd': 64},
19+
'vit-large': {'N': 192, 'd': 64},
20+
'vit-huge': {'N': 192, 'd': 80},
21+
**{f'gpt-3-xl-forward-{N}': {'N': N, 'd': 128} for N in [128, 256, 512, 1024, 2048]},
22+
**{f'gpt-j-forward-{N}': {'N': N, 'd': 256} for N in [128, 256, 512, 1024, 2048]},
23+
}
24+
25+
26+
class Simulation():
27+
28+
def __init__(self, sim_dir):
29+
"""Initializes a simulation object from the run directory."""
30+
self.sim_dir = sim_dir
31+
32+
@functools.cached_property
33+
def performance_data(self):
34+
"""Returns all performance data logged during simulation."""
35+
roi_json = Path(self.sim_dir) / 'logs' / 'roi.json'
36+
with open(roi_json, 'r') as f:
37+
return json.load(f)
38+
39+
def get_metric(self, thread, region, metric, label_idx=0):
40+
"""Get a specific performance metric from a certain simulation run.
41+
42+
Args:
43+
data: All performance metric data as returned by
44+
`get_performance_data()`.
45+
thread: The thread to extract the metric from.
46+
region: The region to extract the metric from. Can be an integer
47+
index or the label assigned to the region. In case of multiple
48+
regions with the same label (as e.g. in a loop) you can get
49+
the n-th occurrence by passing a value to `label_idx`.
50+
metric: The name of the metric to extract.
51+
label_idx: See description for `region`.
52+
"""
53+
# Retrieve region index if supplied `region` argument is a region label.
54+
reg_idx = None
55+
if isinstance(region, str):
56+
cnt = 0
57+
for i, reg in enumerate(self.performance_data[thread]):
58+
if reg['label'] == region:
59+
if cnt == label_idx:
60+
reg_idx = i
61+
break
62+
else:
63+
cnt += 1
64+
elif isinstance(region, int):
65+
reg_idx = region
66+
else:
67+
raise ValueError('region argument must be of type int or str')
68+
# Get metric
69+
return self.performance_data[thread][reg_idx]['attrs'][metric]
70+
71+
def build_visual_trace(self):
72+
"""Build the visual trace of the simulation."""
73+
subprocess.run(['make', '-C', '../../../../../', 'visual-trace',
74+
f'SIM_DIR={self.sim_dir}',
75+
f'ROI_SPEC={ROI_SPEC}', '-j'], check=True)
76+
77+
78+
def load_simulation(model):
79+
"""Returns the simulation object for a given model."""
80+
return Simulation(Path.cwd() / f'runs/flashattention_2-fp32-opt-{model}')
81+
82+
83+
def get_total_runtime(sim, model):
84+
# Parameters
85+
N = MODELS[model]['N']
86+
Br = 16
87+
Bc = 64
88+
89+
# Derived parameters
90+
Tr = N / Br
91+
Tc = N / Bc
92+
93+
# Calculate total runtime
94+
tc_iter_time = sim.get_metric('hart_8', 'copy K & V', 'cycles') + \
95+
sim.get_metric('hart_0', 'QxKt', 'cycles') + \
96+
sim.get_metric('hart_0', 'softmax', 'cycles') + \
97+
sim.get_metric('hart_0', 'PxV', 'cycles')
98+
tc_loop_time = tc_iter_time * Tc
99+
tr_iter_time = sim.get_metric('hart_8', 'copy Q', 'cycles') + \
100+
sim.get_metric('hart_0', 'init', 'cycles') + \
101+
tc_loop_time + \
102+
sim.get_metric('hart_0', 'rescale', 'cycles') + \
103+
sim.get_metric('hart_0', 'rescale', 'cycles')
104+
total_time = tr_iter_time * Tr
105+
return total_time
106+
107+
108+
def main():
109+
110+
sim = load_simulation('vit-base')
111+
sim.build_visual_trace()
112+
113+
for model in MODELS:
114+
print(f'{model}:')
115+
total_time = get_total_runtime(sim, model)
116+
print(f'\tTotal time: {total_time / 10e9}s')
117+
118+
119+
if __name__ == '__main__':
120+
main()

0 commit comments

Comments
 (0)