@@ -43,14 +43,16 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) {
43
43
const bool return_softmax = params.p_ptr != nullptr ;
44
44
BOOL_SWITCH (is_even_MN, IsEvenMNConst, [&] {
45
45
BOOL_SWITCH (is_even_K, IsEvenKConst, [&] {
46
- BOOL_SWITCH (params.window_size_left >= 0 || params.window_size_right >= 0 , Is_local, [&] {
46
+ BOOL_SWITCH (( params.window_size_left >= 0 || params.window_size_right >= 0 ) && !Is_causal , Is_local, [&] {
47
47
BOOL_SWITCH (return_softmax, ReturnSoftmaxConst, [&] {
48
48
// Will only return softmax if dropout, to reduce compilation time.
49
49
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
50
50
// If return_softmax, set IsEvenMNConst to false to reduce number of templates
51
51
// If head dim > 128, set IsEvenMNConst to false to reduce number of templates
52
52
// If Is_local, set Is_causal to false
53
- auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal && !Is_local, Is_local, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128 , IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
53
+ auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128 , IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
54
+ // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
55
+ // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
54
56
if (smem_size >= 48 * 1024 ) {
55
57
C10_CUDA_CHECK (cudaFuncSetAttribute (
56
58
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
@@ -79,13 +81,13 @@ void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) {
79
81
BOOL_SWITCH (params.is_causal , Is_causal, [&] {
80
82
BOOL_SWITCH (is_even_MN, IsEvenMNConst, [&] {
81
83
BOOL_SWITCH (is_even_K, IsEvenKConst, [&] {
82
- BOOL_SWITCH (params.window_size_left >= 0 || params.window_size_right >= 0 , Is_local, [&] {
84
+ BOOL_SWITCH (( params.window_size_left >= 0 || params.window_size_right >= 0 ) && !Is_causal , Is_local, [&] {
83
85
BOOL_SWITCH (params.num_splits > 1 , Split, [&] {
84
86
BOOL_SWITCH (params.knew_ptr != nullptr , Append_KV, [&] {
85
87
// If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
86
88
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
87
89
// If Is_local, set Is_causal to false
88
- auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal && !Is_local, Is_local , IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128 , IsEvenKConst, Split, Append_KV>;
90
+ auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, Is_local && !Is_causal , IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128 , IsEvenKConst, Split, Append_KV>;
89
91
// auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split, Append_KV>;
90
92
// auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
91
93
if (smem_size >= 48 * 1024 ) {
0 commit comments