Skip to content

Commit 0c2fb25

Browse files
authored
Fix illegal memory access (IMA) for split-kv kernel (#20)
1 parent 582eb8f commit 0c2fb25

File tree

1 file changed

+24
-13
lines changed

1 file changed

+24
-13
lines changed

csrc/flash_attn/flash_api.cpp

Lines changed: 24 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -209,7 +209,7 @@ inline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs, int num_n
209209
return 1;
210210
}
211211

212-
void set_params_splitkv(Flash_fwd_params &params, const int batch_size,
212+
std::tuple<at::Tensor, at::Tensor> set_params_splitkv(Flash_fwd_params &params, const int batch_size,
213213
const int num_heads, const int head_size, const int max_seqlen_k, const int max_seqlen_q,
214214
const int head_size_rounded, const float p_dropout,
215215
const int num_splits, cudaDeviceProp *dprops, struct c10::TensorOptions opts) {
@@ -221,19 +221,24 @@ void set_params_splitkv(Flash_fwd_params &params, const int batch_size,
221221
// In any case we don't expect seqlen_q to be larger than 64 for inference.
222222
const int num_m_blocks = (max_seqlen_q + 64 - 1) / 64;
223223
params.num_splits = num_splits;
224+
at::Tensor softmax_lse_accum;
225+
at::Tensor out_accum;
226+
224227
if (p_dropout == 0.0f) { // SplitKV is not implemented for dropout
225228
if (num_splits < 1) {
226229
// We multiply number of SMs by 2 to hard-code the fact that we're using 128 threads per block.
227230
params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount * 2, num_n_blocks, 128);
228231
}
229232
if (params.num_splits > 1) {
230-
at::Tensor softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
231-
at::Tensor out_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));
233+
softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
234+
out_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));
232235
params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
233236
params.oaccum_ptr = out_accum.data_ptr();
234237
}
235238
TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported");
236239
}
240+
241+
return std::make_tuple(softmax_lse_accum, out_accum);
237242
}
238243

239244
void set_params_alibi(Flash_fwd_params &params, c10::optional<at::Tensor> &alibi_slopes_, int batch_size, int num_heads){
@@ -394,10 +399,11 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
394399
softcap
395400
);
396401

397-
398-
set_params_splitkv(params, batch_size, num_heads,
399-
head_size, seqlen_k, seqlen_q,
400-
head_size_rounded, p_dropout, /*num_splits*/0, dprops, opts);
402+
// Keep references to these tensors to extend their lifetime
403+
at::Tensor softmax_lse_accum, out_accum;
404+
std::tie(softmax_lse_accum, out_accum) = set_params_splitkv(
405+
params, batch_size, num_heads, head_size, seqlen_k, seqlen_q,
406+
head_size_rounded, p_dropout, /*num_splits*/ 0, dprops, opts);
401407

402408
// number of times random will be generated per thread, to offset philox counter in thc random
403409
// state
@@ -642,11 +648,14 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
642648
params.v_batch_stride = v_padded.stride(0);
643649
}
644650
params.page_block_size = page_block_size;
651+
// Keep references to these tensors to extend their lifetime
652+
at::Tensor softmax_lse_accum, out_accum;
645653
if (seqlenq_ngroups_swapped) {
646654
// Only apply split-k for decoding
647-
set_params_splitkv(params, batch_size, num_heads,
648-
head_size, max_seqlen_k, max_seqlen_q,
649-
head_size_rounded, p_dropout, /*num_splits*/0, dprops, opts);
655+
std::tie(softmax_lse_accum, out_accum) =
656+
set_params_splitkv(params, batch_size, num_heads, head_size,
657+
max_seqlen_k, max_seqlen_q, head_size_rounded,
658+
p_dropout, /*num_splits*/ 0, dprops, opts);
650659
}
651660

652661
// number of times random will be generated per thread, to offset philox counter in thc random
@@ -936,9 +945,11 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
936945
params.cache_batch_idx = reinterpret_cast<int *>(cache_batch_idx.data_ptr());
937946
}
938947

939-
set_params_splitkv(params, batch_size, num_heads,
940-
head_size, seqlen_k, seqlen_q,
941-
head_size_rounded, /*dropout*/0.f, num_splits, dprops, opts);
948+
// Keep references to these tensors to extend their lifetime
949+
at::Tensor softmax_lse_accum, out_accum;
950+
std::tie(softmax_lse_accum, out_accum) = set_params_splitkv(
951+
params, batch_size, num_heads, head_size, seqlen_k, seqlen_q,
952+
head_size_rounded, /*dropout*/ 0.f, num_splits, dprops, opts);
942953

943954
if (paged_KV) {
944955
params.block_table = block_table.data_ptr<int>();

0 commit comments

Comments
 (0)