
Commit 8701243

move sparse_attn to new files
1 parent: 0127b55

37 files changed: +864, -846 lines

csrc/flash_attn/src/flash_fwd_kernel.h

Lines changed: 0 additions & 696 deletions (large diff not rendered)
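
The 696 deleted lines are the body of flash::compute_sparse_attn, which this commit moves out of flash_fwd_kernel.h into a new sparse kernel header (per the commit message and the sparse launcher removed below). A hypothetical sketch of the interface that new header would expose; the file name, the include, and everything beyond the template parameters used by the launcher are assumptions:

    // flash_fwd_sparse_kernel.h -- hypothetical interface sketch only; the real
    // body is the 696 lines removed from flash_fwd_kernel.h in this commit.
    #pragma once

    #include "flash_fwd_kernel.h"  // assumption: reuses the dense-kernel building blocks

    namespace flash {

    // Template parameters mirror the flash_fwd_sparse_kernel wrapper deleted from
    // flash_fwd_launch_template.h below.
    template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local,
             bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap,
             bool Return_softmax, typename Params>
    inline __device__ void compute_sparse_attn(const Params &params);

    }  // namespace flash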

csrc/flash_attn/src/flash_fwd_launch_template.h

Lines changed: 0 additions & 117 deletions
@@ -26,15 +26,6 @@
 template<typename Kernel_traits, __VA_ARGS__> \
 __global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params)
 
-DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_sparse_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax) {
-    #if defined(ARCH_SUPPORTS_FLASH)
-        static_assert(!(Is_causal && Is_local)); // Enforce constraints
-        flash::compute_sparse_attn<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Return_softmax>(params);
-    #else
-        FLASH_UNSUPPORTED_ARCH
-    #endif
-}
-
 DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax) {
     #if defined(ARCH_SUPPORTS_FLASH)
         static_assert(!(Is_causal && Is_local)); // Enforce constraints
@@ -57,50 +48,6 @@ DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_combine_kernel, int kBlockM, int L
     flash::combine_attn_seqk_parallel<Kernel_traits, kBlockM, Log_max_splits, Is_even_K>(params);
 }
 
-template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
-void run_flash_sparse_fwd(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr size_t smem_size = Kernel_traits::kSmemSize;
-    // printf("smem_size = %d\n", smem_size);
-
-    // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH.
-    // https://github.com/kokkos/kokkos-kernels/issues/349
-    // https://github.com/HazyResearch/flash-attention/issues/21
-
-    const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
-    dim3 grid(num_m_block, params.b, params.h);
-    const bool is_even_K = params.d == Kernel_traits::kHeadDim;
-    const bool return_softmax = params.p_ptr != nullptr;
-    EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
-        BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
-            ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
-                SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
-                    constexpr bool IsEvenMNConst = false;
-                    constexpr bool Is_local = false;
-                    // Will only return softmax if dropout, to reduce compilation time.
-                    // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
-                    // If return_softmax, set IsEvenMNConst to false to reduce number of templates
-                    // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
-                    // If Is_local, set Is_causal to false
-                    auto kernel = &flash_fwd_sparse_kernel<Kernel_traits, Is_dropout && !Is_softcap, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, ReturnSoftmaxConst && Is_dropout && !Is_softcap>;
-                    // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
-                    // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
-                    // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
-                    if (smem_size >= 48 * 1024) {
-                        C10_CUDA_CHECK(cudaFuncSetAttribute(
-                            kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
-                    }
-                    // int ctas_per_sm;
-                    // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
-                    //     &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
-                    // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
-                    kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
-                    C10_CUDA_KERNEL_LAUNCH_CHECK();
-                });
-            });
-        });
-    });
-}
-
 template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
 void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr size_t smem_size = Kernel_traits::kSmemSize;
@@ -407,67 +354,3 @@ void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
         // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
     });
 }
-
-template<typename T, bool Is_causal>
-void run_mha_fwd_sparse_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr static int Headdim = 32;
-    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        run_flash_sparse_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-    });
-}
-
-template<typename T, bool Is_causal>
-void run_mha_fwd_sparse_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr static int Headdim = 64;
-    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        run_flash_sparse_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-    });
-}
-
-template<typename T, bool Is_causal>
-void run_mha_fwd_sparse_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr static int Headdim = 96;
-    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        run_flash_sparse_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-    });
-}
-
-template<typename T, bool Is_causal>
-void run_mha_fwd_sparse_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr static int Headdim = 128;
-    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        run_flash_sparse_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-    });
-}
-
-template<typename T, bool Is_causal>
-void run_mha_fwd_sparse_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr static int Headdim = 160;
-    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        run_flash_sparse_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-    });
-}
-
-template<typename T, bool Is_causal>
-void run_mha_fwd_sparse_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr static int Headdim = 192;
-    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        run_flash_sparse_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-    });
-}
-
-template<typename T, bool Is_causal>
-void run_mha_fwd_sparse_hdim224(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr static int Headdim = 224;
-    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        run_flash_sparse_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-    });
-}
-
-template<typename T, bool Is_causal>
-void run_mha_fwd_sparse_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr static int Headdim = 256;
-    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        run_flash_sparse_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-    });
-}
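
The blocks deleted above (the flash_fwd_sparse_kernel wrapper, run_flash_sparse_fwd, and the run_mha_fwd_sparse_hdim* launchers) presumably move into the new flash_fwd_sparse_launch_template.h that every auto-generated .cu file below now includes. A hedged outline of that header, with bodies elided because they are exactly the code removed above; the include list is an assumption:

    // flash_fwd_sparse_launch_template.h -- hypothetical outline
    #pragma once

    #include "flash_fwd_launch_template.h"  // assumption: for DEFINE_FLASH_FORWARD_KERNEL, the *_SWITCH macros, and Flash_fwd_kernel_traits
    #include "flash_fwd_sparse_kernel.h"    // assumption: for flash::compute_sparse_attn

    // Kernel wrapper, launcher, and per-head-dim entry points; bodies as deleted above.
    DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_sparse_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax);

    template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
    void run_flash_sparse_fwd(Flash_fwd_params &params, cudaStream_t stream);

    // One of these per head dimension: 32, 64, 96, 128, 160, 192, 224, 256.
    template<typename T, bool Is_causal>
    void run_mha_fwd_sparse_hdim128(Flash_fwd_params &params, cudaStream_t stream);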

csrc/flash_attn/src/flash_fwd_sparse_hdim128_bf16_causal_sm80.cu

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
-#include "flash_fwd_launch_template.h"
+#include "flash_fwd_sparse_launch_template.h"
 
 template<>
 void run_mha_fwd_sparse_<cutlass::bfloat16_t, 128, true>(Flash_fwd_params &params, cudaStream_t stream) {
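
Each of the auto-generated translation units below changes only its include line. A hedged sketch of what one full file then looks like; the forwarding body is an assumption that follows the pattern of the dense per-head-dim files, and the file's first line (a copyright header) is not shown in the diff:

    // Splitting the different head dimensions to different files to speed up compilation.
    // This file is auto-generated. See "generate_kernels.py"

    #include "flash_fwd_sparse_launch_template.h"

    template<>
    void run_mha_fwd_sparse_<cutlass::bfloat16_t, 128, true>(Flash_fwd_params &params, cudaStream_t stream) {
        // assumed body: forward to the head-dim-128 launcher moved into the sparse launch template
        run_mha_fwd_sparse_hdim128<cutlass::bfloat16_t, true>(params, stream);
    }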

csrc/flash_attn/src/flash_fwd_sparse_hdim128_bf16_sm80.cu

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
-#include "flash_fwd_launch_template.h"
+#include "flash_fwd_sparse_launch_template.h"
 
 template<>
 void run_mha_fwd_sparse_<cutlass::bfloat16_t, 128, false>(Flash_fwd_params &params, cudaStream_t stream) {

csrc/flash_attn/src/flash_fwd_sparse_hdim128_fp16_causal_sm80.cu

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
-#include "flash_fwd_launch_template.h"
+#include "flash_fwd_sparse_launch_template.h"
 
 template<>
 void run_mha_fwd_sparse_<cutlass::half_t, 128, true>(Flash_fwd_params &params, cudaStream_t stream) {

csrc/flash_attn/src/flash_fwd_sparse_hdim128_fp16_sm80.cu

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
-#include "flash_fwd_launch_template.h"
+#include "flash_fwd_sparse_launch_template.h"
 
 template<>
 void run_mha_fwd_sparse_<cutlass::half_t, 128, false>(Flash_fwd_params &params, cudaStream_t stream) {

csrc/flash_attn/src/flash_fwd_sparse_hdim160_bf16_causal_sm80.cu

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
-#include "flash_fwd_launch_template.h"
+#include "flash_fwd_sparse_launch_template.h"
 
 template<>
 void run_mha_fwd_sparse_<cutlass::bfloat16_t, 160, true>(Flash_fwd_params &params, cudaStream_t stream) {

csrc/flash_attn/src/flash_fwd_sparse_hdim160_bf16_sm80.cu

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
-#include "flash_fwd_launch_template.h"
+#include "flash_fwd_sparse_launch_template.h"
 
 template<>
 void run_mha_fwd_sparse_<cutlass::bfloat16_t, 160, false>(Flash_fwd_params &params, cudaStream_t stream) {

csrc/flash_attn/src/flash_fwd_sparse_hdim160_fp16_causal_sm80.cu

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
-#include "flash_fwd_launch_template.h"
+#include "flash_fwd_sparse_launch_template.h"
 
 template<>
 void run_mha_fwd_sparse_<cutlass::half_t, 160, true>(Flash_fwd_params &params, cudaStream_t stream) {

csrc/flash_attn/src/flash_fwd_sparse_hdim160_fp16_sm80.cu

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
-#include "flash_fwd_launch_template.h"
+#include "flash_fwd_sparse_launch_template.h"
 
 template<>
 void run_mha_fwd_sparse_<cutlass::half_t, 160, false>(Flash_fwd_params &params, cudaStream_t stream) {

csrc/flash_attn/src/flash_fwd_sparse_hdim192_bf16_causal_sm80.cu

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
-#include "flash_fwd_launch_template.h"
+#include "flash_fwd_sparse_launch_template.h"
 
 template<>
 void run_mha_fwd_sparse_<cutlass::bfloat16_t, 192, true>(Flash_fwd_params &params, cudaStream_t stream) {

csrc/flash_attn/src/flash_fwd_sparse_hdim192_bf16_sm80.cu

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
-#include "flash_fwd_launch_template.h"
+#include "flash_fwd_sparse_launch_template.h"
 
 template<>
 void run_mha_fwd_sparse_<cutlass::bfloat16_t, 192, false>(Flash_fwd_params &params, cudaStream_t stream) {

csrc/flash_attn/src/flash_fwd_sparse_hdim192_fp16_causal_sm80.cu

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
-#include "flash_fwd_launch_template.h"
+#include "flash_fwd_sparse_launch_template.h"
 
 template<>
 void run_mha_fwd_sparse_<cutlass::half_t, 192, true>(Flash_fwd_params &params, cudaStream_t stream) {

csrc/flash_attn/src/flash_fwd_sparse_hdim192_fp16_sm80.cu

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
-#include "flash_fwd_launch_template.h"
+#include "flash_fwd_sparse_launch_template.h"
 
 template<>
 void run_mha_fwd_sparse_<cutlass::half_t, 192, false>(Flash_fwd_params &params, cudaStream_t stream) {

csrc/flash_attn/src/flash_fwd_sparse_hdim224_bf16_causal_sm80.cu

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
-#include "flash_fwd_launch_template.h"
+#include "flash_fwd_sparse_launch_template.h"
 
 template<>
 void run_mha_fwd_sparse_<cutlass::bfloat16_t, 224, true>(Flash_fwd_params &params, cudaStream_t stream) {

csrc/flash_attn/src/flash_fwd_sparse_hdim224_bf16_sm80.cu

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
-#include "flash_fwd_launch_template.h"
+#include "flash_fwd_sparse_launch_template.h"
 
 template<>
 void run_mha_fwd_sparse_<cutlass::bfloat16_t, 224, false>(Flash_fwd_params &params, cudaStream_t stream) {

csrc/flash_attn/src/flash_fwd_sparse_hdim224_fp16_causal_sm80.cu

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
-#include "flash_fwd_launch_template.h"
+#include "flash_fwd_sparse_launch_template.h"
 
 template<>
 void run_mha_fwd_sparse_<cutlass::half_t, 224, true>(Flash_fwd_params &params, cudaStream_t stream) {

csrc/flash_attn/src/flash_fwd_sparse_hdim224_fp16_sm80.cu

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
-#include "flash_fwd_launch_template.h"
+#include "flash_fwd_sparse_launch_template.h"
 
 template<>
 void run_mha_fwd_sparse_<cutlass::half_t, 224, false>(Flash_fwd_params &params, cudaStream_t stream) {

csrc/flash_attn/src/flash_fwd_sparse_hdim256_bf16_causal_sm80.cu

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
-#include "flash_fwd_launch_template.h"
+#include "flash_fwd_sparse_launch_template.h"
 
 template<>
 void run_mha_fwd_sparse_<cutlass::bfloat16_t, 256, true>(Flash_fwd_params &params, cudaStream_t stream) {

csrc/flash_attn/src/flash_fwd_sparse_hdim256_bf16_sm80.cu

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
-#include "flash_fwd_launch_template.h"
+#include "flash_fwd_sparse_launch_template.h"
 
 template<>
 void run_mha_fwd_sparse_<cutlass::bfloat16_t, 256, false>(Flash_fwd_params &params, cudaStream_t stream) {

csrc/flash_attn/src/flash_fwd_sparse_hdim256_fp16_causal_sm80.cu

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
-#include "flash_fwd_launch_template.h"
+#include "flash_fwd_sparse_launch_template.h"
 
 template<>
 void run_mha_fwd_sparse_<cutlass::half_t, 256, true>(Flash_fwd_params &params, cudaStream_t stream) {

csrc/flash_attn/src/flash_fwd_sparse_hdim256_fp16_sm80.cu

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
-#include "flash_fwd_launch_template.h"
+#include "flash_fwd_sparse_launch_template.h"
 
 template<>
 void run_mha_fwd_sparse_<cutlass::half_t, 256, false>(Flash_fwd_params &params, cudaStream_t stream) {

csrc/flash_attn/src/flash_fwd_sparse_hdim32_bf16_causal_sm80.cu

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
-#include "flash_fwd_launch_template.h"
+#include "flash_fwd_sparse_launch_template.h"
 
 template<>
 void run_mha_fwd_sparse_<cutlass::bfloat16_t, 32, true>(Flash_fwd_params &params, cudaStream_t stream) {

csrc/flash_attn/src/flash_fwd_sparse_hdim32_bf16_sm80.cu

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
-#include "flash_fwd_launch_template.h"
+#include "flash_fwd_sparse_launch_template.h"
 
 template<>
 void run_mha_fwd_sparse_<cutlass::bfloat16_t, 32, false>(Flash_fwd_params &params, cudaStream_t stream) {

csrc/flash_attn/src/flash_fwd_sparse_hdim32_fp16_causal_sm80.cu

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
-#include "flash_fwd_launch_template.h"
+#include "flash_fwd_sparse_launch_template.h"
 
 template<>
 void run_mha_fwd_sparse_<cutlass::half_t, 32, true>(Flash_fwd_params &params, cudaStream_t stream) {

csrc/flash_attn/src/flash_fwd_sparse_hdim32_fp16_sm80.cu

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
-#include "flash_fwd_launch_template.h"
+#include "flash_fwd_sparse_launch_template.h"
 
 template<>
 void run_mha_fwd_sparse_<cutlass::half_t, 32, false>(Flash_fwd_params &params, cudaStream_t stream) {

csrc/flash_attn/src/flash_fwd_sparse_hdim64_bf16_causal_sm80.cu

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
-#include "flash_fwd_launch_template.h"
+#include "flash_fwd_sparse_launch_template.h"
 
 template<>
 void run_mha_fwd_sparse_<cutlass::bfloat16_t, 64, true>(Flash_fwd_params &params, cudaStream_t stream) {

csrc/flash_attn/src/flash_fwd_sparse_hdim64_bf16_sm80.cu

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
-#include "flash_fwd_launch_template.h"
+#include "flash_fwd_sparse_launch_template.h"
 
 template<>
 void run_mha_fwd_sparse_<cutlass::bfloat16_t, 64, false>(Flash_fwd_params &params, cudaStream_t stream) {

csrc/flash_attn/src/flash_fwd_sparse_hdim64_fp16_causal_sm80.cu

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
-#include "flash_fwd_launch_template.h"
+#include "flash_fwd_sparse_launch_template.h"
 
 template<>
 void run_mha_fwd_sparse_<cutlass::half_t, 64, true>(Flash_fwd_params &params, cudaStream_t stream) {

csrc/flash_attn/src/flash_fwd_sparse_hdim64_fp16_sm80.cu

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
-#include "flash_fwd_launch_template.h"
+#include "flash_fwd_sparse_launch_template.h"
 
 template<>
 void run_mha_fwd_sparse_<cutlass::half_t, 64, false>(Flash_fwd_params &params, cudaStream_t stream) {

csrc/flash_attn/src/flash_fwd_sparse_hdim96_bf16_causal_sm80.cu

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
-#include "flash_fwd_launch_template.h"
+#include "flash_fwd_sparse_launch_template.h"
 
 template<>
 void run_mha_fwd_sparse_<cutlass::bfloat16_t, 96, true>(Flash_fwd_params &params, cudaStream_t stream) {
