@@ -18,13 +18,13 @@ def timeit(fn, *args, **kwargs):
18
18
# Warmup
19
19
for _ in range (5 ):
20
20
fn (* args , ** kwargs )
21
-
21
+
22
22
# Benchmark using PyTorch Timer
23
23
t = benchmark .Timer (
24
24
stmt = 'fn(*args, **kwargs)' ,
25
25
globals = {'fn' : fn , 'args' : args , 'kwargs' : kwargs }
26
26
)
27
-
27
+
28
28
# Measure execution time
29
29
measurement = t .timeit (20 ) # Runs the function 20 times
30
30
# measurement = t.blocked_autorange(min_run_time=1)
@@ -38,14 +38,15 @@ def main():
38
38
).multi_processor_count
39
39
40
40
max_splits = 129
41
- check_all_splits = False
41
+ check_all_splits = True
42
42
43
43
causal = True
44
44
# causal = False
45
45
# dtype=torch.float16
46
46
dtype = torch .bfloat16
47
+ tp_degree = 1
47
48
48
- torch .manual_seed (42 )
49
+ torch .manual_seed (42 )
49
50
50
51
model_configs = [
51
52
# ("Gemma-2-2B", 8, 4, 256),
@@ -56,6 +57,7 @@ def main():
56
57
# ("Qwen-2.5-7B", 28, 4, 128),
57
58
# ("Llama-3.1-8B", 32, 8, 128),
58
59
("Llama-3.1-70B" , 64 , 8 , 128 ),
60
+ # ("Mistral Large", 96, 8, 128),
59
61
# ("Llama-3.1-405B", 128, 8, 128),
60
62
# ("Llama-3.2-1B", 32, 8, 64),
61
63
# ("Llama-3.2-3B", 24, 8, 128),
@@ -66,28 +68,32 @@ def main():
66
68
67
69
all_batch_configs .extend (itertools .product (
68
70
# [1024, 2048, 4096, 8192, 16384, 32768, 131072], # context_seqlen
69
- [4096 , 16384 , 65536 ], # context_seqlen
70
- # [131072], # context_seqlen
71
+ # [4096, 16384, 65536], # context_seqlen
72
+ [131072 ], # context_seqlen
71
73
# [i for i in range(1, (num_sms) + 1)], # num_requests
72
74
[1 , 4 , 8 , 16 ], # num_requests
73
75
# [1], # num_requests
74
- [1 , 4 , 8 , 16 ], # query_seqlen
75
- # [1], # query_seqlen
76
+ # [1, 4, 8, 16], # query_seqlen
77
+ [1 ], # query_seqlen
76
78
))
77
79
78
80
num_caches = max (reqs for _ , reqs , _ in all_batch_configs )
79
81
cache_seqlen = max (seqlen for seqlen , _ , _ in all_batch_configs )
80
82
81
83
for model_name , nheads_q , nheads_kv , headdim in model_configs :
84
+ assert nheads_kv % tp_degree == 0
85
+ print (f"***{ model_name } ***" )
86
+ print (f"QHEADS:{ nheads_q } , KVHEADS:{ nheads_kv } , HEADDIM:{ headdim } , TP:{ tp_degree } " )
87
+ nheads_q //= tp_degree
88
+ nheads_kv //= tp_degree
89
+
82
90
k_cache = torch .randn (
83
91
(num_caches , cache_seqlen , nheads_kv , headdim ), device = "cuda" , dtype = dtype
84
92
)
85
93
v_cache = torch .randn (
86
94
(num_caches , cache_seqlen , nheads_kv , headdim ), device = "cuda" , dtype = dtype
87
95
)
88
- print (f"***{ model_name } ***" )
89
- print (f"QHEADS:{ nheads_q } , KVHEADS:{ nheads_kv } , HEADDIM:{ headdim } " )
90
-
96
+
91
97
if check_all_splits is False :
92
98
print (f"{ 'CONTEXT' :<9} { 'BSZ' :<5} { 'QLEN' :<6} { 'FA2' :<10} { 'FA3' :<9} { 'RATIO' :<7} { 'GB/s' :<10} " )
93
99
@@ -139,7 +145,7 @@ def main():
139
145
cache_seqlens = cache_seqlens ,
140
146
cache_batch_idx = cache_idxs ,
141
147
causal = causal ,
142
- gqa_parallel = False ,
148
+ pack_gqa = False ,
143
149
num_splits = 1 ,
144
150
) * 1000. * 1000.
145
151
@@ -151,16 +157,16 @@ def main():
151
157
cache_seqlens = cache_seqlens ,
152
158
cache_batch_idx = cache_idxs ,
153
159
causal = causal ,
154
- gqa_parallel = True ,
160
+ pack_gqa = True ,
155
161
num_splits = 0 ,
156
- max_seqlen_k_hint = context_seqlen
162
+ # max_seqlen_k_hint=context_seqlen
157
163
) * 1000. * 1000.
158
164
159
165
if check_all_splits :
160
-
166
+
161
167
fa3_fastest_num_splits = 0
162
168
fa3_fastest_splitk_time = float ("inf" )
163
-
169
+
164
170
for num_splits in range (1 , max_splits ):
165
171
t = timeit (
166
172
flash_attn_interface .flash_attn_with_kvcache ,
@@ -170,7 +176,7 @@ def main():
170
176
cache_seqlens = cache_seqlens ,
171
177
cache_batch_idx = cache_idxs ,
172
178
causal = causal ,
173
- gqa_parallel = False ,
179
+ pack_gqa = False ,
174
180
num_splits = num_splits
175
181
) * 1000. * 1000.
176
182
@@ -181,7 +187,7 @@ def main():
181
187
cache_seqlens = cache_seqlens ,
182
188
cache_batch_idx = cache_idxs ,
183
189
causal = causal ,
184
- gqa_parallel = False ,
190
+ pack_gqa = False ,
185
191
num_splits = num_splits
186
192
)
187
193
@@ -192,7 +198,7 @@ def main():
192
198
cache_seqlens = cache_seqlens ,
193
199
cache_batch_idx = cache_idxs ,
194
200
causal = causal ,
195
- gqa_parallel = False ,
201
+ pack_gqa = False ,
196
202
num_splits = 1
197
203
)
198
204
@@ -220,7 +226,7 @@ def main():
220
226
cache_seqlens = cache_seqlens ,
221
227
cache_batch_idx = cache_idxs ,
222
228
causal = causal ,
223
- gqa_parallel = True ,
229
+ pack_gqa = True ,
224
230
num_splits = num_splits
225
231
) * 1000. * 1000.
226
232
@@ -231,7 +237,7 @@ def main():
231
237
cache_seqlens = cache_seqlens ,
232
238
cache_batch_idx = cache_idxs ,
233
239
causal = causal ,
234
- gqa_parallel = True ,
240
+ pack_gqa = True ,
235
241
num_splits = num_splits
236
242
)
237
243
@@ -242,7 +248,7 @@ def main():
242
248
cache_seqlens = cache_seqlens ,
243
249
cache_batch_idx = cache_idxs ,
244
250
causal = causal ,
245
- gqa_parallel = True ,
251
+ pack_gqa = True ,
246
252
num_splits = 1
247
253
)
248
254
@@ -257,7 +263,7 @@ def main():
257
263
if t < fa3_fastest_splitk_time_gqa :
258
264
fa3_fastest_splitk_time_gqa = t
259
265
fa3_fastest_num_splits_gqa = num_splits
260
-
266
+
261
267
efficiency = (num_work_tiles * fa3_fastest_num_splits_gqa )/ num_sms
262
268
heuristic_ratio = fa3_time_gqa_heuristic / fa3_fastest_splitk_time_gqa
263
269
# remeasure to smooth anomalies
@@ -271,11 +277,11 @@ def main():
271
277
cache_seqlens = cache_seqlens ,
272
278
cache_batch_idx = cache_idxs ,
273
279
causal = causal ,
274
- gqa_parallel = True ,
280
+ pack_gqa = True ,
275
281
# num_splits=num_splits_select,
276
282
# num_splits=1,
277
283
num_splits = 0 ,
278
- max_seqlen_k_hint = context_seqlen
284
+ # max_seqlen_k_hint=context_seqlen
279
285
) * 1000. * 1000.
280
286
281
287
fa3_fastest_splitk_time_gqa = timeit (
@@ -286,9 +292,9 @@ def main():
286
292
cache_seqlens = cache_seqlens ,
287
293
cache_batch_idx = cache_idxs ,
288
294
causal = causal ,
289
- gqa_parallel = True ,
295
+ pack_gqa = True ,
290
296
num_splits = fa3_fastest_num_splits_gqa
291
- ) * 1000. * 1000.
297
+ ) * 1000. * 1000.
292
298
293
299
if check_all_splits is True :
294
300
print (
@@ -308,7 +314,7 @@ def main():
308
314
# f"RATIO (FA2/3):{fa2_time_heuristic/fa3_time_gqa_heuristic:.2f}, "
309
315
f"RATIO:{ fa3_time_gqa_heuristic / fa3_fastest_splitk_time_gqa :.2f} , "
310
316
f"EFF:{ efficiency :.2f} , "
311
- f"GB/s:{ bytes_kv / fa3_time_gqa_heuristic * 1e-3 :.2f} "
317
+ f"GB/s:{ bytes_kv / fa3_time_gqa_heuristic * 1e-3 :.2f} "
312
318
)
313
319
314
320
if check_all_splits is False :
@@ -322,4 +328,4 @@ def main():
322
328
323
329
324
330
if __name__ == "__main__" :
325
- main ()
331
+ main ()
0 commit comments