@@ -209,7 +209,7 @@ inline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs, int num_n
     return 1;
 }
 
-void set_params_splitkv(Flash_fwd_params &params, const int batch_size,
+std::tuple<at::Tensor, at::Tensor> set_params_splitkv(Flash_fwd_params &params, const int batch_size,
     const int num_heads, const int head_size, const int max_seqlen_k, const int max_seqlen_q,
     const int head_size_rounded, const float p_dropout,
     const int num_splits, cudaDeviceProp *dprops, struct c10::TensorOptions opts) {
@@ -221,19 +221,24 @@ void set_params_splitkv(Flash_fwd_params &params, const int batch_size,
     // In any case we don't expect seqlen_q to be larger than 64 for inference.
     const int num_m_blocks = (max_seqlen_q + 64 - 1) / 64;
     params.num_splits = num_splits;
+    at::Tensor softmax_lse_accum;
+    at::Tensor out_accum;
+
     if (p_dropout == 0.0f) {  // SplitKV is not implemented for dropout
         if (num_splits < 1) {
             // We multiply number of SMs by 2 to hard-code the fact that we're using 128 threads per block.
             params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount * 2, num_n_blocks, 128);
         }
         if (params.num_splits > 1) {
-            at::Tensor softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
-            at::Tensor out_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));
+            softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
+            out_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));
             params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
             params.oaccum_ptr = out_accum.data_ptr();
         }
         TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported");
     }
+
+    return std::make_tuple(softmax_lse_accum, out_accum);
 }
 
 void set_params_alibi(Flash_fwd_params &params, c10::optional<at::Tensor> &alibi_slopes_, int batch_size, int num_heads){
@@ -394,10 +399,11 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
                      softcap
     );
 
-
-    set_params_splitkv(params, batch_size, num_heads,
-                       head_size, seqlen_k, seqlen_q,
-                       head_size_rounded, p_dropout, /*num_splits*/ 0, dprops, opts);
+    // Keep references to these tensors to extend their lifetime
+    at::Tensor softmax_lse_accum, out_accum;
+    std::tie(softmax_lse_accum, out_accum) = set_params_splitkv(
+        params, batch_size, num_heads, head_size, seqlen_k, seqlen_q,
+        head_size_rounded, p_dropout, /*num_splits*/ 0, dprops, opts);
 
     // number of times random will be generated per thread, to offset philox counter in thc random
     // state
@@ -642,11 +648,14 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
         params.v_batch_stride = v_padded.stride(0);
     }
     params.page_block_size = page_block_size;
+    // Keep references to these tensors to extend their lifetime
+    at::Tensor softmax_lse_accum, out_accum;
     if (seqlenq_ngroups_swapped) {
         // Only apply split-k for decoding
-        set_params_splitkv(params, batch_size, num_heads,
-                           head_size, max_seqlen_k, max_seqlen_q,
-                           head_size_rounded, p_dropout, /*num_splits*/ 0, dprops, opts);
+        std::tie(softmax_lse_accum, out_accum) =
+            set_params_splitkv(params, batch_size, num_heads, head_size,
+                               max_seqlen_k, max_seqlen_q, head_size_rounded,
+                               p_dropout, /*num_splits*/ 0, dprops, opts);
     }
 
     // number of times random will be generated per thread, to offset philox counter in thc random
@@ -936,9 +945,11 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
         params.cache_batch_idx = reinterpret_cast<int *>(cache_batch_idx.data_ptr());
     }
 
-    set_params_splitkv(params, batch_size, num_heads,
-                       head_size, seqlen_k, seqlen_q,
-                       head_size_rounded, /*dropout*/ 0.f, num_splits, dprops, opts);
+    // Keep references to these tensors to extend their lifetime
+    at::Tensor softmax_lse_accum, out_accum;
+    std::tie(softmax_lse_accum, out_accum) = set_params_splitkv(
+        params, batch_size, num_heads, head_size, seqlen_k, seqlen_q,
+        head_size_rounded, /*dropout*/ 0.f, num_splits, dprops, opts);
 
     if (paged_KV) {
         params.block_table = block_table.data_ptr<int>();
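The pattern behind this change: set_params_splitkv allocates the split-KV accumulator tensors and stores only their raw data_ptr() values into params. As locals, those tensors were destroyed when the helper returned, leaving dangling pointers for the later kernel launch; returning them as a tuple lets each caller keep them in scope. The following is a minimal standalone sketch of that ownership pattern, not code from the commit; the Params, set_scratch, and accum names are illustrative only.

// Sketch: a helper fills a params struct with a raw pointer into a scratch
// tensor and returns the tensor so the caller owns it for as long as needed.
#include <tuple>
#include <torch/torch.h>

struct Params {
    void *accum_ptr = nullptr;  // raw pointer consumed later, e.g. by a kernel launch
};

std::tuple<at::Tensor> set_scratch(Params &params, int64_t n, at::TensorOptions opts) {
    at::Tensor accum = torch::empty({n}, opts.dtype(at::kFloat));
    params.accum_ptr = accum.data_ptr();  // only valid while `accum` is alive
    return std::make_tuple(accum);        // hand ownership back to the caller
}

int main() {
    Params params;
    // Holding the returned tensor in the caller's scope extends its lifetime
    // past the helper, so params.accum_ptr does not dangle when it is used.
    at::Tensor accum;
    std::tie(accum) = set_scratch(params, /*n=*/128, at::TensorOptions().device(at::kCPU));
    return 0;
}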