Skip to content

Commit d97720a

Browse files
Sync up to 06e34f6
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
1 parent 175ebb2 commit d97720a

24 files changed

+350
-249
lines changed

hopper/benchmark_attn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,7 @@ def run(*args, **kwargs):
355355
m1 = time_fwd(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, causal=causal, window_size=window_size, sink_token_length=sink_token_length, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3')
356356
# pytorch_profiler(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, page_table=page_table, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa)
357357
else:
358-
m1 = time_fwd(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, None, None, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3')
358+
m1 = time_fwd(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3')
359359
# pytorch_profiler(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits)
360360
time_f[(causal, headdim, batch_size, seqlen), "Flash3"] = m1.mean
361361
if dtype != torch.float8_e4m3fn and headdim == headdim_v:
@@ -364,7 +364,7 @@ def run(*args, **kwargs):
364364
_, m1b = benchmark_backward(flash_attn_func_v3, q, k, v, causal=causal, window_size=window_size, sink_token_length=sink_token_length, softcap=softcap, deterministic=deterministic,
365365
repeats=repeats, verbose=False, desc='Fav3')
366366
else:
367-
_, m1b = benchmark_backward(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, None, None, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, deterministic=deterministic,
367+
_, m1b = benchmark_backward(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, deterministic=deterministic,
368368
repeats=repeats, verbose=False, desc='Fav3')
369369
time_b[(causal, headdim, batch_size, seqlen), "Flash3"] = m1b.mean
370370
# time.sleep(1)

hopper/benchmark_split_kv.py

Lines changed: 35 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,13 @@ def timeit(fn, *args, **kwargs):
1818
# Warmup
1919
for _ in range(5):
2020
fn(*args, **kwargs)
21-
21+
2222
# Benchmark using PyTorch Timer
2323
t = benchmark.Timer(
2424
stmt='fn(*args, **kwargs)',
2525
globals={'fn': fn, 'args': args, 'kwargs': kwargs}
2626
)
27-
27+
2828
# Measure execution time
2929
measurement = t.timeit(20) # Runs the function 20 times
3030
# measurement = t.blocked_autorange(min_run_time=1)
@@ -38,14 +38,15 @@ def main():
3838
).multi_processor_count
3939

4040
max_splits = 129
41-
check_all_splits = False
41+
check_all_splits = True
4242

4343
causal = True
4444
# causal = False
4545
# dtype=torch.float16
4646
dtype=torch.bfloat16
47+
tp_degree = 1
4748

48-
torch.manual_seed(42)
49+
torch.manual_seed(42)
4950

5051
model_configs = [
5152
# ("Gemma-2-2B", 8, 4, 256),
@@ -56,6 +57,7 @@ def main():
5657
# ("Qwen-2.5-7B", 28, 4, 128),
5758
# ("Llama-3.1-8B", 32, 8, 128),
5859
("Llama-3.1-70B", 64, 8, 128),
60+
# ("Mistral Large", 96, 8, 128),
5961
# ("Llama-3.1-405B", 128, 8, 128),
6062
# ("Llama-3.2-1B", 32, 8, 64),
6163
# ("Llama-3.2-3B", 24, 8, 128),
@@ -66,28 +68,32 @@ def main():
6668

6769
all_batch_configs.extend(itertools.product(
6870
# [1024, 2048, 4096, 8192, 16384, 32768, 131072], # context_seqlen
69-
[4096, 16384, 65536], # context_seqlen
70-
# [131072], # context_seqlen
71+
# [4096, 16384, 65536], # context_seqlen
72+
[131072], # context_seqlen
7173
# [i for i in range(1, (num_sms) + 1)], # num_requests
7274
[1, 4, 8, 16], # num_requests
7375
# [1], # num_requests
74-
[1, 4, 8, 16], # query_seqlen
75-
# [1], # query_seqlen
76+
# [1, 4, 8, 16], # query_seqlen
77+
[1], # query_seqlen
7678
))
7779

7880
num_caches = max(reqs for _, reqs, _ in all_batch_configs)
7981
cache_seqlen = max(seqlen for seqlen, _, _ in all_batch_configs)
8082

8183
for model_name, nheads_q, nheads_kv, headdim in model_configs:
84+
assert nheads_kv % tp_degree == 0
85+
print(f"***{model_name}***")
86+
print(f"QHEADS:{nheads_q}, KVHEADS:{nheads_kv}, HEADDIM:{headdim}, TP:{tp_degree}")
87+
nheads_q //= tp_degree
88+
nheads_kv //= tp_degree
89+
8290
k_cache = torch.randn(
8391
(num_caches, cache_seqlen, nheads_kv, headdim), device="cuda", dtype=dtype
8492
)
8593
v_cache = torch.randn(
8694
(num_caches, cache_seqlen, nheads_kv, headdim), device="cuda", dtype=dtype
8795
)
88-
print(f"***{model_name}***")
89-
print(f"QHEADS:{nheads_q}, KVHEADS:{nheads_kv}, HEADDIM:{headdim}")
90-
96+
9197
if check_all_splits is False:
9298
print(f"{'CONTEXT':<9}{'BSZ':<5}{'QLEN':<6}{'FA2':<10}{'FA3':<9}{'RATIO':<7}{'GB/s':<10}")
9399

@@ -139,7 +145,7 @@ def main():
139145
cache_seqlens=cache_seqlens,
140146
cache_batch_idx=cache_idxs,
141147
causal=causal,
142-
gqa_parallel=False,
148+
pack_gqa=False,
143149
num_splits=1,
144150
) * 1000. * 1000.
145151

@@ -151,16 +157,16 @@ def main():
151157
cache_seqlens=cache_seqlens,
152158
cache_batch_idx=cache_idxs,
153159
causal=causal,
154-
gqa_parallel=True,
160+
pack_gqa=True,
155161
num_splits=0,
156-
max_seqlen_k_hint=context_seqlen
162+
# max_seqlen_k_hint=context_seqlen
157163
) * 1000. * 1000.
158164

159165
if check_all_splits:
160-
166+
161167
fa3_fastest_num_splits = 0
162168
fa3_fastest_splitk_time = float("inf")
163-
169+
164170
for num_splits in range(1, max_splits):
165171
t = timeit(
166172
flash_attn_interface.flash_attn_with_kvcache,
@@ -170,7 +176,7 @@ def main():
170176
cache_seqlens=cache_seqlens,
171177
cache_batch_idx=cache_idxs,
172178
causal=causal,
173-
gqa_parallel=False,
179+
pack_gqa=False,
174180
num_splits=num_splits
175181
) * 1000. * 1000.
176182

@@ -181,7 +187,7 @@ def main():
181187
cache_seqlens=cache_seqlens,
182188
cache_batch_idx=cache_idxs,
183189
causal=causal,
184-
gqa_parallel=False,
190+
pack_gqa=False,
185191
num_splits=num_splits
186192
)
187193

@@ -192,7 +198,7 @@ def main():
192198
cache_seqlens=cache_seqlens,
193199
cache_batch_idx=cache_idxs,
194200
causal=causal,
195-
gqa_parallel=False,
201+
pack_gqa=False,
196202
num_splits=1
197203
)
198204

@@ -220,7 +226,7 @@ def main():
220226
cache_seqlens=cache_seqlens,
221227
cache_batch_idx=cache_idxs,
222228
causal=causal,
223-
gqa_parallel=True,
229+
pack_gqa=True,
224230
num_splits=num_splits
225231
) * 1000. * 1000.
226232

@@ -231,7 +237,7 @@ def main():
231237
cache_seqlens=cache_seqlens,
232238
cache_batch_idx=cache_idxs,
233239
causal=causal,
234-
gqa_parallel=True,
240+
pack_gqa=True,
235241
num_splits=num_splits
236242
)
237243

@@ -242,7 +248,7 @@ def main():
242248
cache_seqlens=cache_seqlens,
243249
cache_batch_idx=cache_idxs,
244250
causal=causal,
245-
gqa_parallel=True,
251+
pack_gqa=True,
246252
num_splits=1
247253
)
248254

@@ -257,7 +263,7 @@ def main():
257263
if t < fa3_fastest_splitk_time_gqa:
258264
fa3_fastest_splitk_time_gqa = t
259265
fa3_fastest_num_splits_gqa = num_splits
260-
266+
261267
efficiency = (num_work_tiles * fa3_fastest_num_splits_gqa)/num_sms
262268
heuristic_ratio = fa3_time_gqa_heuristic/fa3_fastest_splitk_time_gqa
263269
# remeasure to smooth anomalies
@@ -271,11 +277,11 @@ def main():
271277
cache_seqlens=cache_seqlens,
272278
cache_batch_idx=cache_idxs,
273279
causal=causal,
274-
gqa_parallel=True,
280+
pack_gqa=True,
275281
# num_splits=num_splits_select,
276282
# num_splits=1,
277283
num_splits=0,
278-
max_seqlen_k_hint=context_seqlen
284+
# max_seqlen_k_hint=context_seqlen
279285
) * 1000. * 1000.
280286

281287
fa3_fastest_splitk_time_gqa = timeit(
@@ -286,9 +292,9 @@ def main():
286292
cache_seqlens=cache_seqlens,
287293
cache_batch_idx=cache_idxs,
288294
causal=causal,
289-
gqa_parallel=True,
295+
pack_gqa=True,
290296
num_splits=fa3_fastest_num_splits_gqa
291-
) * 1000. * 1000.
297+
) * 1000. * 1000.
292298

293299
if check_all_splits is True:
294300
print(
@@ -308,7 +314,7 @@ def main():
308314
# f"RATIO (FA2/3):{fa2_time_heuristic/fa3_time_gqa_heuristic:.2f}, "
309315
f"RATIO:{fa3_time_gqa_heuristic/fa3_fastest_splitk_time_gqa:.2f}, "
310316
f"EFF:{efficiency:.2f}, "
311-
f"GB/s:{bytes_kv/fa3_time_gqa_heuristic * 1e-3:.2f}"
317+
f"GB/s:{bytes_kv/fa3_time_gqa_heuristic * 1e-3:.2f}"
312318
)
313319

314320
if check_all_splits is False:
@@ -322,4 +328,4 @@ def main():
322328

323329

324330
if __name__ == "__main__":
325-
main()
331+
main()

hopper/block.h

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
/******************************************************************************
2+
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
3+
******************************************************************************/
4+
5+
#pragma once
6+
7+
namespace flash {
8+
9+
template <class SeqlenInfo_t, int kBlockM, int kBlockN, bool Is_causal, bool Is_local, bool PackGQA=false, bool Split=false>
10+
struct BlockMN {
11+
12+
static
13+
CUTLASS_DEVICE
14+
cute::tuple<int, int> get_n_block_min_max(
15+
SeqlenInfo_t const& seqlen_info,
16+
int const m_block, int const bidb, int const split_idx, int const num_splits,
17+
int const window_size_left, int const window_size_right,
18+
cutlass::FastDivmod const& qhead_per_khead_divmod) {
19+
20+
int const seqlen_k = seqlen_info.seqlen_k;
21+
int const seqlen_q = seqlen_info.seqlen_q;
22+
int n_block_max = cute::ceil_div(seqlen_k, kBlockN);
23+
if constexpr (Is_causal || Is_local) {
24+
int m_idx_max = (m_block + 1) * kBlockM;
25+
// TODO: check off-by-1 error
26+
if (PackGQA) { m_idx_max = qhead_per_khead_divmod.divide(m_idx_max - 1) + 1 ; }
27+
n_block_max = std::min(n_block_max,
28+
cute::ceil_div(m_idx_max + seqlen_k - seqlen_q + window_size_right, kBlockN));
29+
}
30+
int n_block_min = 0;
31+
if constexpr (Is_local) {
32+
int m_idx_min = m_block * kBlockM;
33+
if (PackGQA) { m_idx_min = qhead_per_khead_divmod.divide(m_idx_min); }
34+
n_block_min = std::max(int(0), (m_idx_min + seqlen_k - seqlen_q - window_size_left) / kBlockN);
35+
}
36+
// if (threadIdx.x == 128) { printf("Inside, bid.x = %d, bid.y = %d, bid.z = %d, split_idx = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.x, blockIdx.y, blockIdx.z, split_idx, n_block_min, n_block_max); }
37+
if constexpr (Split) {
38+
int num_n_blocks_per_split = n_block_max <= n_block_min ? 0 : cute::ceil_div(n_block_max - n_block_min, num_splits);
39+
n_block_min = n_block_min + split_idx * num_n_blocks_per_split;
40+
n_block_max = std::min(n_block_min + num_n_blocks_per_split, n_block_max);
41+
}
42+
// if (threadIdx.x == 128) { printf("After split, inside, bid.y = %d, bid.z = %d, split_idx = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.y, blockIdx.z, split_idx, n_block_min, n_block_max); }
43+
return {n_block_min, n_block_max};
44+
}
45+
46+
static
47+
CUTLASS_DEVICE
48+
cute::tuple<int, int> get_n_block_k_new_min_max(
49+
SeqlenInfo_t const& seqlen_info,
50+
int const m_block, int const bidb, int const split_idx, int const num_splits,
51+
int const window_size_left, int const window_size_right,
52+
cutlass::FastDivmod const& qhead_per_khead_divmod) {
53+
54+
auto [n_block_min, n_block_max] = get_n_block_min_max(
55+
seqlen_info, m_block, bidb, split_idx, num_splits,
56+
window_size_left, window_size_right, qhead_per_khead_divmod);
57+
int const idx_k_new_min = std::max(n_block_min * kBlockN - seqlen_info.seqlen_k_og, 0);
58+
int const idx_k_new_max = std::min(n_block_max * kBlockN - seqlen_info.seqlen_k_og, seqlen_info.seqlen_k_new);
59+
int const n_block_new_min = idx_k_new_min / kBlockN;
60+
int const n_block_new_max = idx_k_new_max > idx_k_new_min ? cute::ceil_div(idx_k_new_max, kBlockN) : n_block_new_min;
61+
// if (threadIdx.x == 128 && m_block == 0) { printf("bidb = %d, seqlen_k_new = %d, seqlen_k_og = %d, n_block_min = %d, n_block_max = %d, idx_k_new_min = %d, idx_k_new_max = %d, n_block_new_min = %d, n_block_new_max = %d\n", bidb, seqlen_k_new, seqlen_k_og, n_block_min, n_block_max, idx_k_new_min, idx_k_new_max, n_block_new_min, n_block_new_max);}
62+
return {n_block_new_min, n_block_new_max};
63+
}
64+
65+
static
66+
CUTLASS_DEVICE
67+
cute::tuple<int, int> get_m_block_min_max(
68+
SeqlenInfo_t const& seqlen_info,
69+
int const n_block, int const bidb,
70+
int const window_size_left, int const window_size_right, int const sink_token_length) {
71+
72+
int const seqlen_q = seqlen_info.seqlen_q;
73+
int const seqlen_k = seqlen_info.seqlen_k;
74+
int m_block_max = cute::ceil_div(seqlen_q, kBlockM);
75+
if constexpr (Is_local) {
76+
if (n_block >= cute::ceil_div(sink_token_length, kBlockN)) {
77+
m_block_max = std::min(m_block_max, cute::ceil_div((n_block + 1) * kBlockN + seqlen_q - seqlen_k + window_size_left, kBlockM));
78+
}
79+
}
80+
int m_block_min = 0;
81+
if constexpr (Is_causal || Is_local) {
82+
m_block_min = std::max(m_block_min, (n_block * kBlockN + seqlen_q - seqlen_k - window_size_right) / kBlockM);
83+
}
84+
return {m_block_min, m_block_max};
85+
}
86+
87+
};
88+
89+
} // namespace flash

hopper/epilogue_bwd.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ struct CollectiveEpilogueBwd {
238238
Tensor tdKVsdK = gmem_thr_copy_dKV.partition_S(sdK); // (TMA, TMA_M, TMA_K)
239239
Tensor tdKVrdV = make_fragment_like(tdKVgdV);
240240
Tensor tdKVrdK = make_fragment_like(tdKVgdK);
241-
Tensor cdKV = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{})); // (BLK_M,BLK_K) -> (blk_m,blk_k)
241+
Tensor cdKV = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{})); // (BLK_N,BLK_K) -> (blk_n,blk_k)
242242
// Repeat the partitioning with identity layouts
243243
Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);
244244
Tensor tdKVpdKV = make_tensor<bool>(make_shape(size<2>(tdKVgdV)));

0 commit comments

Comments
 (0)