diff --git a/candle-flash-attn/kernels/alibi.h b/candle-flash-attn/kernels/alibi.h
new file mode 100644
index 0000000000..1afb3687d3
--- /dev/null
+++ b/candle-flash-attn/kernels/alibi.h
@@ -0,0 +1,62 @@
+#include <cmath>
+
+#include <cute/tensor.hpp>
+
+#include <cutlass/cutlass.h>
+#include <cutlass/array.h>
+
+#include "utils.h"
+
+namespace flash {
+
+using namespace cute;
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <bool Is_causal, typename Engine, typename Layout>
+inline __device__ void apply_alibi(Tensor<Engine, Layout> &tensor,
+                                   const int col_idx_offset_,
+                                   const int max_seqlen_k,
+                                   const int row_idx_offset,
+                                   const int max_seqlen_q,
+                                   const int warp_row_stride,
+                                   const float alibi_slope) {
+    // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
+    static_assert(Layout::rank == 2, "Only support 2D Tensor");
+    const int lane_id = threadIdx.x % 32;
+    const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
+    if constexpr (Is_causal) {  // Simpler, we add the same bias vector to all rows
+        #pragma unroll
+        for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+            const int col_idx_base = col_idx_offset + nj * 8;
+            #pragma unroll
+            for (int j = 0; j < size<1, 0>(tensor); ++j) {
+                const int col_idx = col_idx_base + j;
+                #pragma unroll
+                for (int mi = 0; mi < size<0>(tensor); ++mi) {
+                    tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx;
+                }
+            }
+        }
+    } else {  // Bias depends on both row_idx and col_idx
+        #pragma unroll
+        for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
+            const int row_idx_base = row_idx_offset + mi * warp_row_stride;
+            #pragma unroll
+            for (int i = 0; i < size<0, 0>(tensor); ++i) {
+                const int row_idx = row_idx_base + i * 8;
+                #pragma unroll
+                for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+                    const int col_idx_base = col_idx_offset + nj * 8;
+                    #pragma unroll
+                    for (int j = 0; j < size<1, 0>(tensor); ++j) {
+                        const int col_idx = col_idx_base + j;
+                        tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);
+                    }
+                }
+            }
+        }
+    }
+}
+
+}  // namespace flash
diff --git a/candle-flash-attn/kernels/block_info.h b/candle-flash-attn/kernels/block_info.h
index 94251a41e4..65435e51a2 100644
--- a/candle-flash-attn/kernels/block_info.h
+++ b/candle-flash-attn/kernels/block_info.h
@@ -14,9 +14,12 @@ struct BlockInfo {
     template<typename Params>
     __device__ BlockInfo(const Params &params, const int bidb)
         : sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb])
-        , sum_s_k(!Varlen || params.cu_seqlens_k == nullptr ? -1 : params.cu_seqlens_k[bidb])
+        , sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb])
         , actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q)
+        // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
+        // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
+        , seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb]))
-        , actual_seqlen_k(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : params.cu_seqlens_k[bidb + 1] - sum_s_k)
+        , actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] : seqlen_k_cache + (params.knew_ptr == nullptr ?
0 : params.seqlen_knew)) { } @@ -32,8 +35,10 @@ struct BlockInfo { const int sum_s_q; const int sum_s_k; - const uint32_t actual_seqlen_q; - const uint32_t actual_seqlen_k; + const int actual_seqlen_q; + // We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0. + const int seqlen_k_cache; + const int actual_seqlen_k; }; //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/candle-flash-attn/kernels/flash.h b/candle-flash-attn/kernels/flash.h index be4ae0ca8f..80b517e901 100644 --- a/candle-flash-attn/kernels/flash.h +++ b/candle-flash-attn/kernels/flash.h @@ -7,15 +7,6 @@ #include #include -// #ifdef OLD_GENERATOR_PATH -// #include -// #else -// #include -// #endif -// -// #include - - constexpr int TOTAL_DIM = 0; constexpr int H_DIM = 1; constexpr int D_DIM = 2; @@ -53,6 +44,7 @@ struct Flash_fwd_params : public Qkv_params { // The O matrix (output). void * __restrict__ o_ptr; + void * __restrict__ oaccum_ptr; // The stride between rows of O. index_t o_batch_stride; @@ -64,9 +56,10 @@ struct Flash_fwd_params : public Qkv_params { // The pointer to the softmax sum. void * __restrict__ softmax_lse_ptr; + void * __restrict__ softmax_lseaccum_ptr; // The dimensions. - int b, seqlen_q, seqlen_k, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded; + int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim; // The scaling factors for the kernel. float scale_softmax; @@ -76,8 +69,30 @@ struct Flash_fwd_params : public Qkv_params { int * __restrict__ cu_seqlens_q; int * __restrict__ cu_seqlens_k; + // If provided, the actual length of each k sequence. + int * __restrict__ seqused_k; + int *__restrict__ blockmask; + // The K_new and V_new matrices. + void * __restrict__ knew_ptr; + void * __restrict__ vnew_ptr; + + // The stride between rows of the Q, K and V matrices. + index_t knew_batch_stride; + index_t vnew_batch_stride; + index_t knew_row_stride; + index_t vnew_row_stride; + index_t knew_head_stride; + index_t vnew_head_stride; + + // The cos and sin matrices for rotary embedding. + void * __restrict__ rotary_cos_ptr; + void * __restrict__ rotary_sin_ptr; + + // The indices to index into the KV cache. + int *__restrict__ cache_batch_idx; + // The dropout probability (probability of keeping an activation). float p_dropout; // uint32_t p_dropout_in_uint; @@ -88,11 +103,22 @@ struct Flash_fwd_params : public Qkv_params { float rp_dropout; float scale_softmax_rp_dropout; - // Random state. - // at::PhiloxCudaState philox_args; + // Local window size + int window_size_left, window_size_right; bool is_bf16; bool is_causal; + + // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]. + // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. + bool is_seqlens_k_cumulative; + + bool is_rotary_interleaved; + + int num_splits; // For split-KV version + + void * __restrict__ alibi_slopes_ptr; + index_t alibi_slopes_batch_stride; }; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -132,10 +158,14 @@ struct Flash_bwd_params : public Flash_fwd_params { // The pointer to the softmax d sum. 
void *__restrict__ dsoftmax_sum; + + bool deterministic; + index_t dq_accum_split_stride; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure); diff --git a/candle-flash-attn/kernels/flash_api.cu b/candle-flash-attn/kernels/flash_api.cu index 72991257aa..8113dbc742 100644 --- a/candle-flash-attn/kernels/flash_api.cu +++ b/candle-flash-attn/kernels/flash_api.cu @@ -1,17 +1,15 @@ #include "flash_fwd_launch_template.h" -// void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { -// FWD_HEADDIM_SWITCH(params.d, [&] { -// run_mha_fwd_(params, stream); -// }); -// } - -void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { - FP16_SWITCH(!params.is_bf16, [&] { - FWD_HEADDIM_SWITCH(params.d, [&] { - run_mha_fwd_(params, stream); - }); - }); +void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream, bool force_split_kernel=false) { + FP16_SWITCH(!params.is_bf16, [&] { + FWD_HEADDIM_SWITCH(params.d, [&] { +// if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0 + run_mha_fwd_(params, stream); +// } else { +// run_mha_fwd_splitkv_dispatch(params, stream); +// } + }); + }); } extern "C" void run_mha( @@ -20,6 +18,7 @@ extern "C" void run_mha( void *v_ptr, void *o_ptr, void *softmax_lse_ptr, + void *alibi_slopes_ptr, int32_t *cu_seqlens_q_ptr, int32_t *cu_seqlens_k_ptr, @@ -28,6 +27,7 @@ extern "C" void run_mha( uint32_t k_batch_stride, uint32_t v_batch_stride, uint32_t o_batch_stride, + uint32_t alibi_slopes_batch_stride, uint32_t q_row_stride, uint32_t k_row_stride, @@ -51,8 +51,11 @@ extern "C" void run_mha( uint32_t seqlen_q_rounded, uint32_t seqlen_k_rounded, + int is_bf16, int is_causal, - int is_bf16 + + int window_size_left, + int window_size_right ) { Flash_fwd_params params; // Reset the parameters @@ -65,12 +68,14 @@ extern "C" void run_mha( params.o_ptr = o_ptr; params.softmax_lse_ptr = softmax_lse_ptr; + params.alibi_slopes_ptr = alibi_slopes_ptr; // All stride are in elements, not bytes. params.q_batch_stride = q_batch_stride; params.k_batch_stride = k_batch_stride; params.v_batch_stride = v_batch_stride; params.o_batch_stride = o_batch_stride; + params.alibi_slopes_batch_stride = alibi_slopes_batch_stride; params.q_row_stride = q_row_stride; params.k_row_stride = k_row_stride; @@ -92,7 +97,6 @@ extern "C" void run_mha( params.seqlen_k_rounded = seqlen_k_rounded; params.d = d; params.d_rounded = d_rounded; - params.is_causal = is_causal; // Set the different scale values. params.scale_softmax = softmax_scale; @@ -106,6 +110,14 @@ extern "C" void run_mha( params.cu_seqlens_q = cu_seqlens_q_ptr; params.cu_seqlens_k = cu_seqlens_k_ptr; params.p_ptr = nullptr; // used for `return_softmax`. + params.seqused_k = nullptr; + + params.is_causal = is_causal; + params.window_size_left = window_size_left; + params.window_size_right = window_size_right; + + params.is_seqlens_k_cumulative = true; + params.num_splits = 1; cudaStream_t stream = 0; // Use the default stream. 
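    // Illustrative sketch, not part of the vendored flash-attn sources: the fields set above feed
    // the compile-time dispatch in flash_fwd_launch_template.h. A null alibi_slopes_ptr keeps the
    // Has_alibi path disabled, and the sliding-window ("local") path is only taken when
    // window_size_left >= 0 or window_size_right >= 0 and the call is not causal. A hypothetical
    // caller could therefore request the common cases roughly like this (values are assumptions,
    // not an API contract):
    //
    //     params.alibi_slopes_ptr  = nullptr;                 // no ALiBi bias
    //     params.window_size_left  = -1;                      // negative disables the left bound
    //     params.window_size_right = is_causal ? 0 : -1;      // causal means "no look-ahead"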
run_mha_fwd(params, stream); diff --git a/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_sm80.cu index 654400a749..6ffa4126e5 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_sm80.cu @@ -1,19 +1,10 @@ // Copyright (c) 2023, Tri Dao. - // Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" -// template<> -// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { -// using elem_type = cutlass::bfloat16_t; -// if (params.p_dropout == 1.f) { -// run_flash_fwd, false>(params, stream); -// } else { -// run_flash_fwd, true>(params, stream); -// } -// } template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim128(params, stream); -} \ No newline at end of file +} diff --git a/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_sm80.cu index 5b7254a918..19b005ad99 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_sm80.cu @@ -1,32 +1,10 @@ // Copyright (c) 2023, Tri Dao. - // Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" -// template<> -// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { -// using elem_type = cutlass::half_t; -// if (params.p_dropout == 1.f) { -// // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k -// run_flash_fwd, false>(params, stream); -// // run_flash_fwd, false>(params, stream); -// // run_flash_fwd, false>(params, stream); -// // run_flash_fwd, false>(params, stream); -// run_flash_fwd, false>(params, stream); -// run_flash_fwd, false>(params, stream); -// run_flash_fwd, false>(params, stream); -// // 1st ones are good for H100, A100 -// // 2nd one is good for A6000 bc we get slightly better occupancy -// } else { -// run_flash_fwd, true>(params, stream); -// run_flash_fwd, true>(params, stream); -// run_flash_fwd, true>(params, stream); -// // 1st one is good for H100, A100, A6000 -// } -// } - template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim128(params, stream); -} \ No newline at end of file +} diff --git a/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_sm80.cu index 6a9d60c391..f674f48185 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_sm80.cu @@ -1,17 +1,10 @@ // Copyright (c) 2023, Tri Dao. - // Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" #include "flash_fwd_launch_template.h" -// template<> -// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { -// using elem_type = cutlass::bfloat16_t; -// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { -// run_flash_fwd, Is_dropout>(params, stream); -// }); -// } template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim160(params, stream); -} \ No newline at end of file +} diff --git a/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_sm80.cu index 6c40a164d6..afd0a8a387 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_sm80.cu @@ -1,27 +1,10 @@ // Copyright (c) 2023, Tri Dao. - // Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" -// template<> -// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { -// using elem_type = cutlass::half_t; -// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { -// run_flash_fwd, Is_dropout>(params, stream); -// run_flash_fwd, Is_dropout>(params, stream); -// run_flash_fwd, Is_dropout>(params, stream); -// run_flash_fwd, Is_dropout>(params, stream); -// // run_flash_fwd>(params, stream); -// // run_flash_fwd>(params, stream); -// // run_flash_fwd>(params, stream); -// // run_flash_fwd>(params, stream); -// // run_flash_fwd>(params, stream); -// // For A6000, no-causal, 1st is fastest. causal, 4th is fastest. -// // For A100, H100, 1st is fastest. -// }); -// } template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim160(params, stream); -} \ No newline at end of file +} diff --git a/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_sm80.cu index d2f4cba715..aa91bdd66d 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_sm80.cu @@ -1,16 +1,10 @@ // Copyright (c) 2023, Tri Dao. - // Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" -// template<> -// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { -// using elem_type = cutlass::bfloat16_t; -// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { -// run_flash_fwd, Is_dropout>(params, stream); -// }); -// } -template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim192(params, stream); } diff --git a/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_sm80.cu index 2875c92660..37a965264a 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_sm80.cu @@ -1,27 +1,10 @@ // Copyright (c) 2023, Tri Dao. - // Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" #include "flash_fwd_launch_template.h" -// template<> -// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { -// using elem_type = cutlass::half_t; -// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { -// run_flash_fwd, Is_dropout>(params, stream); -// run_flash_fwd, Is_dropout>(params, stream); -// run_flash_fwd, Is_dropout>(params, stream); -// // This one is slightly faster for causal? -// // run_flash_fwd>(params, stream); -// // run_flash_fwd>(params, stream); -// // run_flash_fwd>(params, stream); -// // run_flash_fwd>(params, stream); -// // run_flash_fwd>(params, stream); -// }); -// // For A100 H100, 1st is faster with dropout, 3rd is faster without dropout -// // For A6000, 1st is faster when causal, 3rd is faster when not causal -// } template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim192(params, stream); -} \ No newline at end of file +} diff --git a/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_sm80.cu index 982fe7eade..167a0df2b0 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_sm80.cu @@ -1,9 +1,10 @@ // Copyright (c) 2023, Tri Dao. - // Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" -template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim224(params, stream); } diff --git a/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_sm80.cu index 4c083f7b66..58ffe75c30 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_sm80.cu @@ -1,9 +1,10 @@ // Copyright (c) 2023, Tri Dao. - // Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" -template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim224(params, stream); } diff --git a/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_sm80.cu index cb074a95ed..1b37014154 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_sm80.cu @@ -1,9 +1,10 @@ // Copyright (c) 2023, Tri Dao. - // Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" -template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim256(params, stream); } diff --git a/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_sm80.cu index ddf5e13229..9f35129c3a 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_sm80.cu @@ -1,9 +1,10 @@ // Copyright (c) 2023, Tri Dao. - // Splitting the different head dimensions to different files to speed up compilation. 
+// This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" -template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim256(params, stream); } diff --git a/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_sm80.cu index 81e359e16f..770de6fcf2 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_sm80.cu @@ -1,10 +1,10 @@ // Copyright (c) 2023, Tri Dao. - // Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim32(params, stream); -} \ No newline at end of file +} diff --git a/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu index 91e6331e90..8dbf8b94ae 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu @@ -1,23 +1,10 @@ // Copyright (c) 2023, Tri Dao. - // Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" -// template<> -// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { -// using elem_type = cutlass::half_t; -// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { -// run_flash_fwd, Is_dropout>(params, stream); -// // For dropout there might be a lot of register spilling? -// // These two are very slow due to register spilling -// // run_flash_fwd>(params, stream); -// // run_flash_fwd>(params, stream); -// // This one is slightly slower -// // run_flash_fwd>(params, stream); -// }); -// } template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim32(params, stream); -} \ No newline at end of file +} diff --git a/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_sm80.cu index fffcbebb5d..22eac8789d 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_sm80.cu @@ -1,19 +1,10 @@ // Copyright (c) 2023, Tri Dao. - // Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" -// template<> -// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { -// using elem_type = cutlass::bfloat16_t; -// if (params.p_dropout == 1.f) { -// run_flash_fwd, false>(params, stream); -// } else { -// run_flash_fwd, true>(params, stream); -// } -// } template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim64(params, stream); -} \ No newline at end of file +} diff --git a/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_sm80.cu index 01bd171672..e6da5dd2d8 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_sm80.cu @@ -1,26 +1,10 @@ // Copyright (c) 2023, Tri Dao. - // Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" #include "flash_fwd_launch_template.h" -// template<> -// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { -// using elem_type = cutlass::half_t; -// if (params.p_dropout == 1.f) { -// // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower -// // Using block size (64 x 256) is 27% slower for seqlen=2k -// // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling -// run_flash_fwd, false>(params, stream); -// run_flash_fwd, false>(params, stream); -// run_flash_fwd, false>(params, stream); -// } else { -// run_flash_fwd, true>(params, stream); -// run_flash_fwd, true>(params, stream); -// run_flash_fwd, true>(params, stream); -// } -// } template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim64(params, stream); -} \ No newline at end of file +} diff --git a/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_sm80.cu index b0b27db596..9c003540cd 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_sm80.cu @@ -1,17 +1,10 @@ // Copyright (c) 2023, Tri Dao. - // Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" -// template<> -// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { -// using elem_type = cutlass::bfloat16_t; -// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { -// run_flash_fwd, Is_dropout>(params, stream); -// }); -// } template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim96(params, stream); -} \ No newline at end of file +} diff --git a/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_sm80.cu index 820b63cbbf..8108696a0e 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_sm80.cu @@ -1,23 +1,10 @@ // Copyright (c) 2023, Tri Dao. - // Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" #include "flash_fwd_launch_template.h" -// template<> -// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { -// using elem_type = cutlass::half_t; -// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { -// run_flash_fwd, Is_dropout>(params, stream); -// run_flash_fwd, Is_dropout>(params, stream); -// // This 3rd one is good for H100, and A100, A6000 -// run_flash_fwd, Is_dropout>(params, stream); -// run_flash_fwd, Is_dropout>(params, stream); -// // These two are always slower -// // run_flash_fwd>(params, stream); -// // run_flash_fwd>(params, stream); -// }); -// } -template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim96(params, stream); -} \ No newline at end of file +} diff --git a/candle-flash-attn/kernels/flash_fwd_kernel.h b/candle-flash-attn/kernels/flash_fwd_kernel.h index 232dea0dbc..05f5f70126 100644 --- a/candle-flash-attn/kernels/flash_fwd_kernel.h +++ b/candle-flash-attn/kernels/flash_fwd_kernel.h @@ -4,20 +4,18 @@ #pragma once -#include #include -#include #include #include #include -#include #include "block_info.h" #include "kernel_traits.h" #include "utils.h" #include "softmax.h" -#include "philox.cuh" + +#include "alibi.h" namespace flash { @@ -25,49 +23,6 @@ using namespace cute; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -CUTE_HOST_DEVICE -auto -make_tiled_copy_A_warpcontiguousM(Copy_Atom const& copy_atom, - TiledMMA const& tiled_mma) { - using TileShape_MNK = typename TiledMMA::TiledShape_MNK; - using AtomShape_MNK = typename TiledMMA::AtomShape_MNK; - constexpr int AtomShape_M = decltype(size<0>(AtomShape_MNK{}))::value; - constexpr int kNWarps = decltype(size<0>(TileShape_MNK{}))::value / AtomShape_M; - constexpr int MMAStride_M = MMA_M * AtomShape_M; - auto t = make_tile(Layout, Int>, - Stride<_1, Int> >{}, - make_layout(size<2>(TileShape_MNK{}))); - // if (cute::thread0()) {printf("make_tiled_copy_A_warpcontiguousM "); print(t); printf("\n"); } - return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutA_TV(), t); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -CUTE_HOST_DEVICE -auto -make_tiled_copy_C_warpcontiguousM(Copy_Atom const& copy_atom, - TiledMMA const& tiled_mma) { - using TileShape_MNK = typename TiledMMA::TiledShape_MNK; - using AtomShape_MNK = typename TiledMMA::AtomShape_MNK; - constexpr int AtomShape_M = decltype(size<0>(AtomShape_MNK{}))::value; - constexpr int kNWarps = decltype(size<0>(TileShape_MNK{}))::value / AtomShape_M; - constexpr int MMAStride_M = MMA_M * AtomShape_M; - auto t = make_tile(Layout, Int>, - Stride<_1, Int> >{}, - // TODO: Shouldn't this be size<1>? 
- make_layout(size<2>(TileShape_MNK{}))); - // if (cute::thread0()) {printf("make_tiled_copy_C_warpcontiguousM "); print(t); printf("\n"); } - return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutC_TV(), t); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - template inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, Tensor1 &scores_sum, Tensor2 &acc_o, float softmax_scale_log2) { @@ -77,7 +32,7 @@ inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, T flash::reduce_sum(scores, scores_sum); } else { Tensor scores_max_prev = make_fragment_like(scores_max); - copy(scores_max, scores_max_prev); + cute::copy(scores_max, scores_max_prev); flash::template reduce_max(scores, scores_max); // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); @@ -103,23 +58,22 @@ inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, T template inline __device__ void write_softmax_to_gmem( - Tensor const &tOrP, Tensor &tPgP, TiledCopy gmem_thr_copy_P + Tensor const &tOrP, Tensor &tPgP, TiledCopy gmem_tiled_copy_P ) { // Reshape tOrP from (8, MMA_M, MMA_N) to (8, MMA_M * MMA_N) Layout l = tOrP.layout(); Tensor tPrP = make_tensor(tOrP.data(), make_layout(get<0>(l), make_layout(get<1>(l), get<2>(l)))); CUTE_STATIC_ASSERT_V(size<2>(tPgP) == _1{}); - // TODO(laurent): reactivate the following - // CUTE_STATIC_ASSERT_V(size<1>(tPrP) == size<1>(tPgP)); + CUTE_STATIC_ASSERT_V(size<1>(tPrP) == size<1>(tPgP)); #pragma unroll for (int mi = 0; mi < size<1>(tPrP); ++mi) { - copy(gmem_thr_copy_P, tPrP(_, mi), tPgP(_, mi, 0)); + cute::copy(gmem_tiled_copy_P, tPrP(_, mi), tPgP(_, mi, 0)); } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bidb, const int bidh, const int m_block) { using Element = typename Kernel_traits::Element; @@ -138,16 +92,65 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi constexpr int kNWarps = Kernel_traits::kNWarps; constexpr int MMA_M = kBlockM / decltype(size<0>(typename Kernel_traits::TiledMma::TiledShape_MNK{}))::value; - const BlockInfo binfo(params, bidb); - if (m_block * kBlockM >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return; + const BlockInfo binfo(params, bidb); + if (m_block * kBlockM >= binfo.actual_seqlen_q) return; + const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN); int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN); - if (Is_causal) { - n_block_max = std::min(n_block_max, cute::ceil_div((m_block + 1) * kBlockM, kBlockN)); + if (Is_causal || Is_local) { + n_block_max = std::min(n_block_max, + cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN)); // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { // printf("m_block = %d, n_block_max = %d\n", m_block, n_block_max); // } } + // We exit early and write 0 to gO and gLSE. This also covers the case where actual_seqlen_k == 0. + // Otherwise we might read OOB elements from gK and gV. 
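    // A worked example of the block-range arithmetic above (illustrative values, assumed here):
    // with kBlockM == kBlockN == 128, seqlen_q == seqlen_k == 1024 and a causal call
    // (window_size_right == 0), m_block == 2 gives
    //     n_block_max = min(ceil(1024 / 128), ceil(((2 + 1) * 128 + 0 + 0) / 128)) = min(8, 3) = 3,
    // so only K/V blocks 0..2 are visited. Adding a local window with window_size_left == 256
    // (and window_size_right == 0) at m_block == 5 gives
    //     n_block_min = max(0, (5 * 128 + 0 - 256) / 128) = 3,
    //     n_block_max = min(8, ceil((6 * 128 + 0 + 0) / 128)) = 6,
    // so only blocks 3..5 are touched. When the window (or an empty K sequence) pushes
    // n_block_max <= n_block_min, the early-exit path below writes zeros to O and skips the loop.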
+ if ((Is_causal || Is_local || !Is_even_MN) && n_block_max <= n_block_min) { + // Save seed and offset for backward. If we don't have this here, the 0-th thread block might + // exit early and no one saves the rng state. +// if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) { +// auto seeds = at::cuda::philox::unpack(params.philox_args); +// params.rng_state[0] = std::get<0>(seeds); +// params.rng_state[1] = std::get<1>(seeds); +// params.rng_state[0] = 0; +// params.rng_state[1] = 0; +// } + const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM; + Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), + Shape, Int>{}, + make_stride(params.o_row_stride, _1{})); + Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), + Shape>{}, Stride<_1>{}); + + typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); + Tensor tOgO = gmem_thr_copy_O.partition_D(gO); + Tensor tOrO = make_tensor(shape(tOgO)); + clear(tOrO); + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(gO), size<1>(gO))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_O.partition_D(cO); + Tensor tOpO = make_tensor(make_shape(size<2>(tOgO))); + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM + ); + #pragma unroll + for (int m = 0; m < size<1>(tOgO); ++m) { + const int row = get<0>(tOcO(0, m, 0)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSE(row) = INFINITY; } + } + return; + } + // if (tidx == 0) { printf("m_block = %d, n_block_min = %d, n_block_max = %d\n", m_block, n_block_min, n_block_max); } // We iterate over the blocks in reverse order. This is because the last block is the only one // that needs masking when we read K and V from global memory. 
Moreover, iterating in reverse @@ -185,8 +188,10 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); - auto gmem_thr_copy_QKV = typename Kernel_traits::GmemTiledCopyQKV{}.get_thread_slice(tidx); - auto gmem_thr_copy_P = typename Kernel_traits::GmemTiledCopyP{}.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; + auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopyP gmem_tiled_copy_P; + auto gmem_thr_copy_P = gmem_tiled_copy_P.get_thread_slice(tidx); Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); @@ -208,16 +213,18 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Copy Atom retiling // - auto smem_thr_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx); - // auto smem_thr_copy_Q = make_tiled_copy_A_warpcontiguousM(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx); + auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); // if (cute::thread0()) {smem_thr_copy_Q.print_all();} Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); // if (cute::thread0()) {print(tSsQ.layout()); printf("\n");} - auto smem_thr_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx); + auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); Tensor tSsK = smem_thr_copy_K.partition_S(sK); - auto smem_thr_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma).get_thread_slice(tidx); + auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma); + auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); // TODO: this might need to change if we change the mma instruction in SM70 @@ -268,8 +275,8 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi Tensor tQrQ = make_fragment_like(tQgQ); // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs - flash::copy(gmem_thr_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, - binfo.actual_seqlen_q - m_block * kBlockM); + flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, + binfo.actual_seqlen_q - m_block * kBlockM); if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); } // // Copy rmem to smem @@ -285,14 +292,14 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi __syncthreads(); Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M - copy(smem_thr_copy_Q, tSsQ, tSrQ_copy_view); + cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view); __syncthreads(); } int n_block = n_block_max - 1; // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. 
- flash::copy(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, - binfo.actual_seqlen_k - n_block * kBlockN); + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, + binfo.actual_seqlen_k - n_block * kBlockN); cute::cp_async_fence(); // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); } // __syncthreads(); @@ -302,7 +309,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi __syncthreads(); Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M - copy(smem_thr_copy_Q, tSsQ, tSrQ_copy_view); + cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view); } // auto seeds = at::cuda::philox::unpack(params.philox_args); @@ -313,13 +320,19 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi clear(acc_o); + float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax; + // For performance reason, we separate out two kinds of iterations: // those that need masking on S, and those that don't. // We need masking on S for the very last block when K and V has length not multiple of kBlockN. // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks. // We will have at least 1 "masking" iteration. - constexpr int n_masking_steps = Is_causal ? cute::ceil_div(kBlockM, kBlockN) : 1; + // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to + // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. + constexpr int n_masking_steps = (!Is_causal && !Is_local) + ? 1 + : ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1); #pragma unroll for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) @@ -330,28 +343,42 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Advance gV if (masking_step > 0) { tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); - flash::copy(gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); } else { // Clear the smem tiles to account for predicated off loads - flash::copy( - gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN + flash::copy( + gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN ); } cute::cp_async_fence(); flash::gemm( - acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K ); // if (cute::thread0()) { print(acc_s); } // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); - // if (cute::thread0()) { print(scores); } + // if (cute::thread0()) { print_tensor(scores); } // We don't put the masking before the matmul S = Q K^T because we don't clear sK // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul // can produce Inf / NaN. 
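    // For reference, the predicate implemented by flash::apply_mask_local (defined in softmax.h,
    // not shown in this diff) amounts roughly to masking score (row_idx, col_idx) with -INFINITY
    // whenever the key lies outside the window around the diagonal:
    //     col_idx >= min(seqlen_k, row_idx + 1 + seqlen_k - seqlen_q + window_size_right)
    //     || col_idx < max(0, row_idx + seqlen_k - seqlen_q - window_size_left)
    // With window_size_right == 0 and an unbounded left window this reduces to the usual causal
    // mask, which is why the causal branch below can reuse the same helper.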
- if (!Is_causal) { - if (!Is_even_N) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); } + + if (Has_alibi) { + flash::apply_alibi( + scores, + n_block * kBlockN, + binfo.actual_seqlen_k, + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, + kNWarps * 16, + alibi_slope + ); + } + + if (!Is_causal && !Is_local) { + if (!Is_even_MN) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); } } else { // Tensor caccS = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_N) -> (blk_m,blk_n) // Tensor taccScS = thr_mma.partition_C(caccS); // (MMA,MMA_M,MMA_N) @@ -364,20 +391,24 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Idk why it's get<1> and not get<0> of the stride. // if (cute::thread0()) { print(idx_row.layout()); print(stride<1>(idx_row)); printf("stride = %d \n", get<1>(stride<1>(idx_row))); } // I can't get the stride from idx_row - flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_k, - // m_block * kBlockM + get<0>(idx_row(0)), - m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, - kNWarps * 16); - // m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16); - // m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16); + flash::apply_mask_local( + scores, n_block * kBlockN, binfo.actual_seqlen_k, + // m_block * kBlockM + get<0>(idx_row(0)), + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, kNWarps * 16, + params.window_size_left, params.window_size_right + // m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16 + // m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16 + ); + // if (cute::thread0()) { print_tensor(scores); } } flash::cp_async_wait<0>(); __syncthreads(); - if (n_block > 0) { + if (n_block > n_block_min) { // Advance gK tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); - flash::copy(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. cute::cp_async_fence(); @@ -385,24 +416,24 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // TODO: when we have key_padding_mask we'll need to Check_inf masking_step == 0 - ? softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2) - : softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + ? softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2) + : softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); // Convert scores from fp32 to fp16/bf16 Tensor rP = flash::convert_type(scores); // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. 
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs(rP.layout())); - uint32_t block_row_idx = m_block * (kBlockM / 16) + tidx / 32; - uint32_t block_col_idx = n_block * (kBlockN / 32); + int block_row_idx = m_block * (kBlockM / 16) + tidx / 32; + int block_col_idx = n_block * (kBlockN / 32); if (Return_softmax) { Tensor tOrP_copy = make_fragment_like(tOrP); - copy(tOrP, tOrP_copy); + cute::copy(tOrP, tOrP_copy); flash::apply_dropout( tOrP_copy, params.p_dropout_in_uint8_t, seed, offset, block_row_idx, block_col_idx, kNWarps ); - flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P); + flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_tiled_copy_P); tPgP.data() = tPgP.data() + (-kBlockN); } if (Is_dropout) { @@ -411,37 +442,38 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi } // if (cute::thread0()) { print(tOrP); } - flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_thr_copy_V); + flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); // if (cute::thread0()) { print(scores); } // This check is at the end of the loop since we always have at least 1 iteration - if (n_masking_steps > 1 && n_block <= 0) { + if (n_masking_steps > 1 && n_block <= n_block_min) { --n_block; break; } } // These are the iterations where we don't need masking on S - for (; n_block >= 0; --n_block) { + for (; n_block >= n_block_min; --n_block) { Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) clear(acc_s); flash::cp_async_wait<0>(); __syncthreads(); // Advance gV tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); - flash::copy(gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); cute::cp_async_fence(); flash::gemm( - acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K ); flash::cp_async_wait<0>(); __syncthreads(); - if (n_block > 0) { + if (n_block > n_block_min) { // Advance gK tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); - flash::copy(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. 
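    // Sketch of the cp.async pipeline used here (the exact PTX is an implementation detail of
    // cute/cutlass): cute::cp_async_fence() commits the asynchronous gmem-to-smem loads issued so
    // far as one group, and flash::cp_async_wait<0>() later blocks until all committed groups
    // have landed in shared memory, e.g.
    //     issue loads for the next K tile; cute::cp_async_fence();   // commit the prefetch
    //     ... compute on the K tile already resident in smem ...
    //     flash::cp_async_wait<0>(); __syncthreads();                // prefetched tile now valid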
cute::cp_async_fence(); @@ -449,22 +481,44 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); - softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + + if (Has_alibi) { + flash::apply_alibi( + scores, + n_block * kBlockN, + binfo.actual_seqlen_k, + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, + kNWarps * 16, + alibi_slope + ); + } + + if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) { + flash::apply_mask_local( + scores, n_block * kBlockN, binfo.actual_seqlen_k, + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, kNWarps * 16, + params.window_size_left, params.window_size_right + ); + } + + softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); Tensor rP = flash::convert_type(scores); // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs(rP.layout())); - uint32_t block_row_idx = m_block * (kBlockM / 16) + tidx / 32; - uint32_t block_col_idx = n_block * (kBlockN / 32); + int block_row_idx = m_block * (kBlockM / 16) + tidx / 32; + int block_col_idx = n_block * (kBlockN / 32); if (Return_softmax) { Tensor tOrP_copy = make_fragment_like(tOrP); - copy(tOrP, tOrP_copy); + cute::copy(tOrP, tOrP_copy); flash::apply_dropout( tOrP_copy, params.p_dropout_in_uint8_t, seed, offset, block_row_idx, block_col_idx, kNWarps ); - flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P); + flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_tiled_copy_P); tPgP.data() = tPgP.data() + (-kBlockN); } if (Is_dropout) { @@ -472,7 +526,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi block_row_idx, block_col_idx, kNWarps); } - flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_thr_copy_V); + flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); } // Epilogue @@ -496,15 +550,15 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi Tensor rO = flash::convert_type(acc_o); Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) // Partition sO to match the accumulator partitioning - auto smem_thr_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx); - // auto smem_thr_copy_O = make_tiled_copy_C_warpcontiguousM(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx); + auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma); + auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx); Tensor taccOrO = smem_thr_copy_O.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N) // sO has the same size as sQ, so we don't need to sync here. 
if (Kernel_traits::Share_Q_K_smem) { __syncthreads(); } - copy(smem_thr_copy_O, taccOrO, taccOsO); + cute::copy(smem_tiled_copy_O, taccOrO, taccOsO); const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; @@ -515,14 +569,15 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), Shape>{}, Stride<_1>{}); - auto gmem_thr_copy_O = typename Kernel_traits::GmemTiledCopyO{}.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N) Tensor tOgO = gmem_thr_copy_O.partition_D(gO); __syncthreads(); Tensor tOrO = make_tensor(shape(tOgO)); - copy(gmem_thr_copy_O, tOsO, tOrO); + cute::copy(gmem_tiled_copy_O, tOsO, tOrO); Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) @@ -548,14 +603,15 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } } // Clear_OOB_K must be false since we don't want to write zeros to gmem - flash::copy( - gmem_thr_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM + flash::copy( + gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM ); } + //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn(const Params ¶ms) { const int m_block = blockIdx.x; // The block index for the batch. @@ -571,7 +627,7 @@ inline __device__ void compute_attn(const Params ¶ms) { // the attention matrix. This way, as long as we have the batch, head, and the location of // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern. - flash::compute_attn_1rowblock(params, bidb, bidh, m_block); + flash::compute_attn_1rowblock(params, bidb, bidh, m_block); } //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/candle-flash-attn/kernels/flash_fwd_launch_template.h b/candle-flash-attn/kernels/flash_fwd_launch_template.h index 398ce0779d..66ab6206db 100644 --- a/candle-flash-attn/kernels/flash_fwd_launch_template.h +++ b/candle-flash-attn/kernels/flash_fwd_launch_template.h @@ -4,15 +4,14 @@ #pragma once -// #include - #include "static_switch.h" #include "flash.h" #include "flash_fwd_kernel.h" -template +template __global__ void flash_fwd_kernel(Flash_fwd_params params) { - flash::compute_attn(params); + static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false + flash::compute_attn(params); } template @@ -26,35 +25,39 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; dim3 grid(num_m_block, params.b, params.h); - // We also use is_even_N to set Unpadded in the BlockInfo constructor, so we need to check - // for cu_seqlens_q as well. 
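    // The BOOL_SWITCH nest below comes from static_switch.h (not shown in this diff); in spirit it
    // lowers a runtime boolean into a constexpr template parameter, roughly:
    //     #define BOOL_SWITCH(COND, CONST_NAME, ...)                                            \
    //         [&] {                                                                             \
    //             if (COND) { constexpr static bool CONST_NAME = true;  return __VA_ARGS__(); } \
    //             else      { constexpr static bool CONST_NAME = false; return __VA_ARGS__(); } \
    //         }()
    // Each additional switch (Is_local, Has_alibi, ...) doubles the number of kernel
    // instantiations, which is why several combinations are pruned, as noted in the comments below.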
- const bool is_even_N = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0; + const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0; const bool is_even_K = params.d == Kernel_traits::kHeadDim; const bool return_softmax = params.p_ptr != nullptr; - BOOL_SWITCH(is_even_N, IsEvenNConst, [&] { + BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { - BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] { - // Will only return softmax if dropout, to reduce compilation time. - auto kernel = &flash_fwd_kernel; - // auto kernel = &flash_fwd_kernel; - // if (smem_size >= 48 * 1024) { - // C10_CUDA_CHECK(cudaFuncSetAttribute( - // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - // } - int ctas_per_sm; - cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); - // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); - kernel<<>>(params); - // C10_CUDA_KERNEL_LAUNCH_CHECK(); + BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] { + BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] { + BOOL_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] { + // Will only return softmax if dropout, to reduce compilation time. + // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. + // If return_softmax, set IsEvenMNConst to false to reduce number of templates + // If head dim > 128, set IsEvenMNConst to false to reduce number of templates + // If Is_local, set Is_causal to false + auto kernel = &flash_fwd_kernel; + // auto kernel = &flash_fwd_kernel; + // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout)); + // auto kernel = &flash_fwd_kernel; + // int ctas_per_sm; + // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); + // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); + kernel<<>>(params); + }); + }); }); }); }); } + template void run_mha_fwd_hdim32(Flash_fwd_params ¶ms, cudaStream_t stream) { - constexpr int Headdim = 32; + constexpr static int Headdim = 32; BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] { run_flash_fwd, Is_dropout, Is_causal>(params, stream); @@ -64,7 +67,7 @@ void run_mha_fwd_hdim32(Flash_fwd_params ¶ms, cudaStream_t stream) { template void run_mha_fwd_hdim64(Flash_fwd_params ¶ms, cudaStream_t stream) { - constexpr int Headdim = 64; + constexpr static int Headdim = 64; BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] { if constexpr(!Is_dropout) { @@ -86,7 +89,7 @@ void run_mha_fwd_hdim64(Flash_fwd_params ¶ms, cudaStream_t stream) { template void run_mha_fwd_hdim96(Flash_fwd_params ¶ms, cudaStream_t stream) { - constexpr int Headdim = 96; + constexpr static int Headdim = 96; // auto dprops = at::cuda::getCurrentDeviceProperties(); bool is_sm8x = true; // dprops->major == 8 && dprops->minor > 0; BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { @@ -112,7 +115,7 @@ void 
run_mha_fwd_hdim96(Flash_fwd_params ¶ms, cudaStream_t stream) { template void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) { - constexpr int Headdim = 128; + constexpr static int Headdim = 128; // auto dprops = at::cuda::getCurrentDeviceProperties(); bool is_sm8x = true; // dprops->major == 8 && dprops->minor > 0; BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { @@ -149,7 +152,7 @@ void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) { template void run_mha_fwd_hdim160(Flash_fwd_params ¶ms, cudaStream_t stream) { - constexpr int Headdim = 160; + constexpr static int Headdim = 160; // auto dprops = at::cuda::getCurrentDeviceProperties(); bool is_sm8x = true; // dprops->major == 8 && dprops->minor > 0; BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { @@ -179,7 +182,7 @@ void run_mha_fwd_hdim160(Flash_fwd_params ¶ms, cudaStream_t stream) { template void run_mha_fwd_hdim192(Flash_fwd_params ¶ms, cudaStream_t stream) { - constexpr int Headdim = 192; + constexpr static int Headdim = 192; BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] { if constexpr(!Is_dropout) { @@ -198,7 +201,7 @@ void run_mha_fwd_hdim192(Flash_fwd_params ¶ms, cudaStream_t stream) { template void run_mha_fwd_hdim224(Flash_fwd_params ¶ms, cudaStream_t stream) { - constexpr int Headdim = 224; + constexpr static int Headdim = 224; int device; cudaGetDevice(&device); int max_smem_per_block; @@ -224,7 +227,7 @@ void run_mha_fwd_hdim224(Flash_fwd_params ¶ms, cudaStream_t stream) { template void run_mha_fwd_hdim256(Flash_fwd_params ¶ms, cudaStream_t stream) { - constexpr int Headdim = 256; + constexpr static int Headdim = 256; int device; cudaGetDevice(&device); int max_smem_per_sm, max_smem_per_block; diff --git a/candle-flash-attn/kernels/kernel_traits.h b/candle-flash-attn/kernels/kernel_traits.h index 3468e4bffc..f000ff24da 100644 --- a/candle-flash-attn/kernels/kernel_traits.h +++ b/candle-flash-attn/kernels/kernel_traits.h @@ -91,17 +91,20 @@ struct Flash_fwd_kernel_traits : public Base { SmemLayoutAtomQ{}, Shape, Int>{})); + // This has to be kBlockN and not 8, otherwise we get wrong results for d=128 + using SmemLayoutAtomVtransposedNoSwizzle = Layout, Int>, + Stride<_1, Int>>; using SmemLayoutAtomVtransposed = decltype( - composition(Swizzle{}, - // This has to be kBlockN and not 8, otherwise we get wrong results for d=128 - Layout, Int>, - Stride<_1, Int>>{})); + composition(Swizzle{}, SmemLayoutAtomVtransposedNoSwizzle{})); using SmemLayoutVtransposed = decltype(tile_to_shape( SmemLayoutAtomVtransposed{}, Shape, Int>{})); // Maybe the VtransposeNoSwizzle just needs to have the right shape // And the strides don't matter? 
- using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn()); + using SmemLayoutVtransposedNoSwizzle = decltype(tile_to_shape( + SmemLayoutAtomVtransposedNoSwizzle{}, + Shape, Int>{})); + // using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn()); using SmemLayoutAtomO = decltype( composition(Swizzle{}, @@ -110,7 +113,8 @@ struct Flash_fwd_kernel_traits : public Base { using SmemLayoutO = decltype(tile_to_shape( SmemLayoutAtomO{}, Shape, Int>{})); - using SmemCopyAtomO = Copy_Atom; + using SmemCopyAtomO = Copy_Atom; + using SmemCopyAtomOaccum = Copy_Atom; static constexpr int kSmemQCount = size(SmemLayoutQ{}); static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2; @@ -138,11 +142,11 @@ struct Flash_fwd_kernel_traits : public Base { DefaultCopy >; using GmemTiledCopyQKV = decltype( - make_tiled_copy(Copy_Atom{}, + make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per read using GmemTiledCopyO = decltype( - make_tiled_copy(Copy_Atom{}, + make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per store static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad; @@ -151,10 +155,30 @@ struct Flash_fwd_kernel_traits : public Base { Stride, _1>>; using GmemTiledCopyP = decltype( - make_tiled_copy(Copy_Atom{}, + make_tiled_copy(Copy_Atom{}, GmemLayoutAtomP{}, Layout>{})); // Val layout, 8 vals per store + using GmemLayoutAtomOaccum = std::conditional_t< + kBlockKSmem == 32, + Layout, // Thread layout, 8 threads per row + Stride< _8, _1>>, + Layout, // Thread layout, 16 threads per row + Stride< _16, _1>> + >; + using GmemTiledCopyOaccum = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomOaccum{}, + Layout>{})); // Val layout, 4 vals per store + using GmemLayoutAtomRotcossin = GmemLayoutAtom; + using GmemTiledCopyRotcossin = decltype( + make_tiled_copy(Copy_Atom, Element>{}, + GmemLayoutAtomRotcossin{}, + Layout>{})); // Val layout, 4 vals per load + using GmemTiledCopyRotcossinCont = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomRotcossin{}, + Layout>{})); // Val layout, 8 vals per load }; // Is_V_in_regs is an option to reduce smem usage, but will increase register pressue. @@ -223,16 +247,19 @@ struct Flash_bwd_kernel_traits : public Base { SmemLayoutAtomKV{}, make_shape(Int{}, Int{}))); + using SmemLayoutAtomKtransposedNoSwizzle = Layout, Int>, + Stride<_1, Int>>; using SmemLayoutAtomKtransposed = decltype( - composition(Swizzle{}, - Layout, Int>, - Stride<_1, Int>>{})); + composition(Swizzle{}, SmemLayoutAtomKtransposedNoSwizzle{})); using SmemLayoutKtransposed = decltype(tile_to_shape( SmemLayoutAtomKtransposed{}, make_shape(Int{}, Int{}))); // Maybe the KtransposeNoSwizzle just needs to have the right shape // And the strides don't matter? - using SmemLayoutKtransposedNoSwizzle = decltype(SmemLayoutKtransposed{}.layout_fn()); + using SmemLayoutKtransposedNoSwizzle = decltype(tile_to_shape( + SmemLayoutAtomKtransposedNoSwizzle{}, + make_shape(Int{}, Int{}))); + // using SmemLayoutKtransposedNoSwizzle = decltype(SmemLayoutKtransposed{}.layout_fn()); // TODO: generalize to other values of kBlockN // TODO: what should be the Swizzle here? 
3 is faster than 1, and 1 is faster than 2 @@ -250,24 +277,30 @@ struct Flash_bwd_kernel_traits : public Base { using SmemLayoutPdS = decltype(tile_to_shape( SmemLayoutAtomPdS{}, make_shape(Int{}, Int{}))); + using SmemLayoutAtomPdStransposedNoSwizzle = Layout, Int>, + Stride<_1, Int>>; using SmemLayoutAtomPdStransposed = decltype( - composition(Swizzle{}, - Layout, Int>, - Stride<_1, Int>>{})); + composition(Swizzle{}, SmemLayoutAtomPdStransposedNoSwizzle{})); using SmemLayoutPdStransposed = decltype(tile_to_shape( SmemLayoutAtomPdStransposed{}, make_shape(Int{}, Int{}))); - using SmemLayoutPdStransposedNoSwizzle = decltype(SmemLayoutPdStransposed{}.layout_fn()); + using SmemLayoutPdStransposedNoSwizzle = decltype(tile_to_shape( + SmemLayoutAtomPdStransposedNoSwizzle{}, + make_shape(Int{}, Int{}))); + // using SmemLayoutPdStransposedNoSwizzle = decltype(SmemLayoutPdStransposed{}.layout_fn()); using SmemCopyAtomPdS = Copy_Atom; + using SmemLayoutAtomQdOtransposedNoSwizzle = Layout, Int>, + Stride<_1, Int>>; using SmemLayoutAtomQdOtransposed = decltype( - composition(Swizzle{}, - Layout, Int>, - Stride<_1, Int>>{})); + composition(Swizzle{}, SmemLayoutAtomQdOtransposedNoSwizzle{})); using SmemLayoutQdOtransposed = decltype(tile_to_shape( SmemLayoutAtomQdOtransposed{}, make_shape(Int{}, Int{}))); - using SmemLayoutQdOtransposedNoSwizzle = decltype(SmemLayoutQdOtransposed{}.layout_fn()); + using SmemLayoutQdOtransposedNoSwizzle = decltype(tile_to_shape( + SmemLayoutAtomQdOtransposedNoSwizzle{}, + make_shape(Int{}, Int{}))); + // using SmemLayoutQdOtransposedNoSwizzle = decltype(SmemLayoutQdOtransposed{}.layout_fn()); using SmemLayoutAtomdKV = decltype( composition(Swizzle{}, @@ -292,13 +325,11 @@ struct Flash_bwd_kernel_traits : public Base { static constexpr int kSmemdSCount = size(SmemLayoutPdS{}); static constexpr int kSmemPCount = size(SmemLayoutPdS{}); static constexpr int kSmemdQCount = size(SmemLayoutdQ{}); - static constexpr int kSmemdPsumCount = kBlockM; static constexpr int kSmemQdOSize = kSmemQdOCount * sizeof(Element); static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element); static constexpr int kSmemdSSize = kSmemdSCount * sizeof(Element); static constexpr int kSmemPSize = kSmemPCount * sizeof(Element); static constexpr int kSmemdQSize = kSmemdQCount * sizeof(Element); - static constexpr int kSmemdPsumSize = kSmemdPsumCount * sizeof(ElementAccum); static constexpr int kSmemSize = kSmemQdOSize + (!Is_V_in_regs ? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize) diff --git a/candle-flash-attn/kernels/kernel_traits_sm90.h b/candle-flash-attn/kernels/kernel_traits_sm90.h new file mode 100644 index 0000000000..e07f383904 --- /dev/null +++ b/candle-flash-attn/kernels/kernel_traits_sm90.h @@ -0,0 +1,159 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. 
+ ******************************************************************************/ + +#pragma once + +#include "cute/algorithm/copy.hpp" + +#include "cutlass/cutlass.h" +#include "cutlass/layout/layout.h" +#include + +using namespace cute; + +template +struct Flash_kernel_traits_sm90 { + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + using Element = elem_type; + static constexpr bool Has_cp_async = true; +#else + using Element = cutlass::half_t; + static constexpr bool Has_cp_async = false; +#endif + + using ElementAccum = float; + using index_t = uint32_t; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + using MMA_Atom_Arch = std::conditional_t< + std::is_same_v, + MMA_Atom, + MMA_Atom + >; + using ValLayoutMNK = Layout>; +#else + using MMA_Atom_Arch = MMA_Atom; + using ValLayoutMNK = Layout>; +#endif + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 + using SmemCopyAtom = Copy_Atom; + using SmemCopyAtomTransposed = Copy_Atom; +#else + using SmemCopyAtom = Copy_Atom; + using SmemCopyAtomTransposed = Copy_Atom; +#endif +}; + +template > +struct Flash_fwd_kernel_traits : public Base { + using Element = typename Base::Element; + using ElementAccum = typename Base::ElementAccum; + using index_t = typename Base::index_t; + static constexpr bool Has_cp_async = Base::Has_cp_async; + using SmemCopyAtom = typename Base::SmemCopyAtom; + using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; + + static constexpr bool Share_Q_K_smem = Share_Q_K_smem_; + static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem; + + // The number of threads. + static constexpr int kNWarps = kNWarps_; + static constexpr int kNThreads = kNWarps * 32; + + static constexpr int kBlockM = kBlockM_; + static constexpr int kBlockN = kBlockN_; + static constexpr int kHeadDim = kHeadDim_; + static_assert(kHeadDim % 32 == 0); + static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; + static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); + static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; + + using TiledMma = TiledMMA< + typename Base::MMA_Atom_Arch, + Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group + typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM + + using SmemLayoutAtomQ = decltype( + composition(Swizzle{}, + // This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128 + Layout>, + Stride, _1>>{})); + using SmemLayoutQ = decltype(tile_to_shape( + SmemLayoutAtomQ{}, + Shape, Int>{})); + + using SmemLayoutKV = decltype(tile_to_shape( + SmemLayoutAtomQ{}, + Shape, Int>{})); + + using SmemLayoutAtomVtransposed = decltype( + composition(Swizzle{}, + // This has to be kBlockN and not 8, otherwise we get wrong results for d=128 + Layout, Int>, + Stride<_1, Int>>{})); + using SmemLayoutVtransposed = decltype(tile_to_shape( + SmemLayoutAtomVtransposed{}, + Shape, Int>{})); + // Maybe the VtransposeNoSwizzle just needs to have the right shape + // And the strides don't matter? 
+ using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn()); + + using SmemLayoutAtomO = decltype( + composition(Swizzle{}, + Layout, Int>, + Stride, _1>>{})); + using SmemLayoutO = decltype(tile_to_shape( + SmemLayoutAtomO{}, + Shape, Int>{})); + using SmemCopyAtomO = Copy_Atom; + + static constexpr int kSmemQCount = size(SmemLayoutQ{}); + static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2; + static constexpr int kSmemQSize = kSmemQCount * sizeof(Element); + static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element); + static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize; + + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); + // Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts. + // For example, for d=128, smem is split into 2 "pages", each page takes care of columns + // 0-63 and 64-127. If we have 16 threads per row for gmem read, when we write to smem, + // thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page, + // to the same banks. + static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; + static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); + using GmemLayoutAtom = Layout, Int>, + Stride, _1>>; + + // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading + // from the same address by the same threadblock. This is slightly faster. + using Gmem_copy_struct = std::conditional_t< + Has_cp_async, + SM80_CP_ASYNC_CACHEGLOBAL, + DefaultCopy + >; + using GmemTiledCopyQKV = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per read + using GmemTiledCopyO = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per store + static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad; + static_assert(kNThreads % kGmemThreadsPerRowP == 0, "kNThreads must be a multiple of kGmemThreadsPerRowP"); + using GmemLayoutAtomP = Layout, Int>, + Stride, _1>>; + + using GmemTiledCopyP = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomP{}, + Layout>{})); // Val layout, 8 vals per store + +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/candle-flash-attn/kernels/softmax.h b/candle-flash-attn/kernels/softmax.h index 3e9a7b4597..09a93f145d 100644 --- a/candle-flash-attn/kernels/softmax.h +++ b/candle-flash-attn/kernels/softmax.h @@ -8,8 +8,7 @@ #include -#include -#include +#include #include "philox.cuh" #include "utils.h" @@ -117,15 +116,18 @@ inline __device__ void max_scale_exp2_sum(Tensor &tensor, Tens } template -inline __device__ void apply_mask(Tensor &tensor, const uint32_t max_seqlen_k) { +inline __device__ void apply_mask(Tensor &tensor, const int max_seqlen_k, + const int col_idx_offset_ = 0) { // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) static_assert(Layout::rank == 2, "Only support 2D Tensor"); - const uint32_t lane_id = threadIdx.x % 32; + const int lane_id = threadIdx.x % 32; + const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; #pragma unroll for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; #pragma unroll for (int j = 0; 
j < size<1, 0>(tensor); ++j) { - const uint32_t col_idx = nj * 8 + j + (lane_id % 4) * 2; + const int col_idx = col_idx_base + j; if (col_idx >= max_seqlen_k) { // Without the "make_coord" we get wrong results #pragma unroll @@ -137,30 +139,30 @@ inline __device__ void apply_mask(Tensor &tensor, const uint32_t } } -template -inline __device__ void apply_mask_causal(Tensor &tensor, const uint32_t col_idx_offset_, - const uint32_t max_seqlen_k, const uint32_t row_idx_offset_, - const uint32_t warp_row_stride) { +template +inline __device__ void apply_mask_local(Tensor &tensor, const int col_idx_offset_, + const int max_seqlen_k, const int row_idx_offset, + const int max_seqlen_q, const int warp_row_stride, + const int window_size_left, const int window_size_right) { // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) static_assert(Layout::rank == 2, "Only support 2D Tensor"); - const uint32_t lane_id = threadIdx.x % 32; - // const uint32_t row_idx_offset = row_idx_offset_ + lane_id / 4; - const uint32_t row_idx_offset = row_idx_offset_; - const uint32_t col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; + const int lane_id = threadIdx.x % 32; + const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; #pragma unroll for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { - const uint32_t row_idx_base = row_idx_offset + mi * warp_row_stride; + const int row_idx_base = row_idx_offset + mi * warp_row_stride; #pragma unroll for (int i = 0; i < size<0, 0>(tensor); ++i) { - const uint32_t row_idx = row_idx_base + i * 8; - const uint32_t col_idx_limit = std::min(max_seqlen_k, row_idx + 1); + const int row_idx = row_idx_base + i * 8; + const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left); + const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right); #pragma unroll for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { - const uint32_t col_idx_base = col_idx_offset + nj * 8; + const int col_idx_base = col_idx_offset + nj * 8; #pragma unroll for (int j = 0; j < size<1, 0>(tensor); ++j) { - const uint32_t col_idx = col_idx_base + j; - if (col_idx >= col_idx_limit) { + const int col_idx = col_idx_base + j; + if (col_idx >= col_idx_limit_right || (HasWSLeft && col_idx < col_idx_limit_left)) { tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; } } @@ -174,10 +176,19 @@ inline __device__ void apply_mask_causal(Tensor &tensor, const u } } +template +inline __device__ void apply_mask_causal(Tensor &tensor, const int col_idx_offset_, + const int max_seqlen_k, const int row_idx_offset, + const int max_seqlen_q, const int warp_row_stride) { + // Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0 + apply_mask_local(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset, + max_seqlen_q, warp_row_stride, -1, 0); +} + template inline __device__ void apply_mask_causal_w_idx( Tensor &tensor, Tensor const &idx_rowcol, - const uint32_t col_idx_offset_, const uint32_t max_seqlen_k, const uint32_t row_idx_offset_) + const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset) { // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) static_assert(Layout0::rank == 2, "Only support 2D Tensor"); @@ -186,7 +197,7 @@ inline __device__ void apply_mask_causal_w_idx( CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol)); #pragma unroll for (int mi = 0; mi < size<0>(tensor); ++mi) { - const uint32_t col_idx_limit = std::min(max_seqlen_k, 1 + 
row_idx_offset_ + get<0>(idx_rowcol(mi, 0))); + const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset + get<0>(idx_rowcol(mi, 0))); #pragma unroll for (int ni = 0; ni < size<1, 1>(tensor); ++ni) { if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) { @@ -204,8 +215,8 @@ inline __device__ void apply_mask_causal_w_idx( template inline __device__ void apply_dropout(Tensor &tensor, uint8_t p_dropout_in_uint8_t, unsigned long long seed, unsigned long long offset, - uint32_t block_row_start, uint32_t block_col_start, - uint32_t block_row_stride) { + int block_row_start, int block_col_start, + int block_row_stride) { // tensor has shape (8, MMA_M, MMA_N / 2) using T = typename Engine::value_type; auto encode_dropout = [](bool keep, T val) { diff --git a/candle-flash-attn/kernels/utils.h b/candle-flash-attn/kernels/utils.h index 2221a2faf3..6fb39dc473 100644 --- a/candle-flash-attn/kernels/utils.h +++ b/candle-flash-attn/kernels/utils.h @@ -87,46 +87,6 @@ inline __device__ uint32_t convert_relu2(const float2 x) { //////////////////////////////////////////////////////////////////////////////////////////////////// -template -inline __device__ float2 half2_unpack(uint32_t a); - -template <> -inline __device__ float2 half2_unpack<__half>(uint32_t a) { - return __half22float2(reinterpret_cast<__half2 (&)>(a)); -} - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 -template <> -inline __device__ float2 half2_unpack<__nv_bfloat16>(uint32_t a) { - return __bfloat1622float2(reinterpret_cast<__nv_bfloat162 (&)>(a)); -} -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// Convert two half2's or bf162's into float, then take their dot product. -template -inline __device__ float hfma2_to_float(const uint32_t a, const uint32_t b) { - float2 af = flash::half2_unpack(a); - float2 bf = flash::half2_unpack(b); - return af.x * bf.x + af.y * bf.y; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// Converted two vectors of 8 half's or bf16's into float, then take their dot product. -template -inline __device__ float hmulsum8(const uint4 a, const uint4 b) { - float sum; - sum = flash::hfma2_to_float(a.x, b.x); - sum += flash::hfma2_to_float(a.y, b.y); - sum += flash::hfma2_to_float(a.z, b.z); - sum += flash::hfma2_to_float(a.w, b.w); - return sum; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - template struct MaxOp { __device__ inline T operator()(T const & x, T const & y) { return x > y ? 
x : y; } @@ -173,10 +133,12 @@ static __device__ inline T run(T x, Operator &op) { template + typename TiledMma, typename TiledCopyA, typename TiledCopyB, + typename ThrCopyA, typename ThrCopyB> inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA, Tensor4 const& tCsB, TiledMma tiled_mma, - TiledCopy0 smem_thr_copy_A, TiledCopy1 smem_thr_copy_B) { + TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B, + ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B) { CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K @@ -184,13 +146,13 @@ inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N - if (!A_in_regs) { copy(smem_thr_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); } - if (!B_in_regs) { copy(smem_thr_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); } + if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); } + if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); } #pragma unroll for (int i = 0; i < size<2>(tCrA); ++i) { if (i < size<2>(tCrA) - 1) { - if (!A_in_regs) { copy(smem_thr_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); } - if (!B_in_regs) { copy(smem_thr_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); } + if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); } + if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); } } cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc); } @@ -199,19 +161,20 @@ inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 //////////////////////////////////////////////////////////////////////////////////////////////////// template + typename TiledMma, typename TiledCopy, typename ThrCopy> inline __device__ void gemm_A_in_regs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB, - TiledMma tiled_mma, TiledCopy smem_thr_copy_B) { + TiledMma tiled_mma, TiledCopy smem_tiled_copy_B, + ThrCopy smem_thr_copy_B) { CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N - copy(smem_thr_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); + cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); #pragma unroll for (int i = 0; i < size<2>(tCrA); ++i) { if (i < size<2>(tCrA) - 1) { - copy(smem_thr_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); + cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); } cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc); } @@ -225,7 +188,10 @@ inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) { static_assert(decltype(size<0>(acc_layout))::value == 4); static_assert(decltype(rank(acc_layout))::value == 3); auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N) - return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l))); + 
// TD [2023-08-13]: Idk why but get<0, 1>(l) doesn't work for Cutlass 3.2, I'm getting + // "int_tuple.hpp(74): error: conversion to inaccessible base class" + // return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l))); + return make_layout(make_layout(get<1>(get<0>(l)), get<1>(l)), make_layout(get<0>(get<0>(l)), get<2>(l))); }; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -241,9 +207,13 @@ inline __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) { static_assert(mma_shape_K == 8 || mma_shape_K == 16); constexpr int MMA_N_divisor = mma_shape_K == 8 ? 1 : 2; auto l = logical_divide(rowcol_layout, Shape>>{}); // ((2, MMA_M), (2, (2, MMA_N / 2))) - return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)), - get<0, 1>(l), - get<1, 1, 1>(l)); + // TD [2023-08-13]: Same error as above on Cutlass 3.2 + // return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)), + // get<0, 1>(l), + // get<1, 1, 1>(l)); + return make_layout(make_layout(get<0>(get<1>(l)), get<0>(get<0>(l)), get<0>(get<1>(get<1>(l)))), + get<1>(get<0>(l)), + get<1>(get<1>(get<1>(l)))); }; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -319,9 +289,9 @@ void cp_async_wait() { template -inline __device__ void copy(TiledCopy thr_copy, Tensor const &S, +inline __device__ void copy(TiledCopy tiled_copy, Tensor const &S, Tensor &D, Tensor const &identity_MN, - Tensor const &predicate_K, int max_MN=0) { + Tensor const &predicate_K, const int max_MN=0) { CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA @@ -335,13 +305,13 @@ inline __device__ void copy(TiledCopy thr_copy, Tensor const & #pragma unroll for (int k = 0; k < size<2>(S); ++k) { if (Is_even_K || predicate_K(k)) { - copy(thr_copy, S(_, m, k), D(_, m, k)); + cute::copy(tiled_copy, S(_, m, k), D(_, m, k)); } else if (Clear_OOB_K) { - clear(D(_, m, k)); + cute::clear(D(_, m, k)); } } } else if (Clear_OOB_MN) { - clear(D(_, m, _)); + cute::clear(D(_, m, _)); } } // TD [2023-04-13]: Strange that the code below can cause race condition. 
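Editorial note on the masking changes above: `apply_mask_local` generalizes the old `apply_mask_causal` to sliding-window (local) attention. For every query row it computes a left and a right column limit, both shifted by `max_seqlen_k - max_seqlen_q` so that the last query lines up with the last key, and writes `-INFINITY` outside that band; causal masking is recovered by calling it with an unlimited left window and `window_size_right = 0`. Below is a minimal CPU-side sketch of the same index arithmetic, in plain Rust (not part of this PR; a negative window size stands in for "no limit on that side", matching the -1 convention used by the Rust launcher further down in this diff):

```rust
/// CPU reference for the per-element decision made by `apply_mask_local`.
/// The kernel walks MMA fragment coordinates instead of a dense matrix, but
/// each (row_idx, col_idx) element is kept or masked by the same comparison.
fn is_masked(
    row_idx: i64,
    col_idx: i64,
    max_seqlen_q: i64,
    max_seqlen_k: i64,
    window_size_left: i64,  // < 0 means no left limit
    window_size_right: i64, // < 0 means no right limit
) -> bool {
    // Shift that aligns the last query with the last key when seqlen_q != seqlen_k.
    let shift = max_seqlen_k - max_seqlen_q;
    let col_idx_limit_right = if window_size_right < 0 {
        max_seqlen_k
    } else {
        max_seqlen_k.min(row_idx + 1 + shift + window_size_right)
    };
    let col_idx_limit_left = if window_size_left < 0 {
        0
    } else {
        (row_idx + shift - window_size_left).max(0)
    };
    col_idx >= col_idx_limit_right || col_idx < col_idx_limit_left
}

fn main() {
    // Causal mask: no left limit, right window of 0.
    assert!(is_masked(0, 1, 4, 4, -1, 0)); // query 0 must not see key 1
    assert!(!is_masked(3, 0, 4, 4, -1, 0)); // the last query sees the first key
    // Local attention: one key to the left, one to the right.
    assert!(is_masked(2, 0, 4, 4, 1, 1));
    assert!(!is_masked(2, 3, 4, 4, 1, 1));
}
```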
@@ -350,7 +320,7 @@ inline __device__ void copy(TiledCopy thr_copy, Tensor const & // #pragma unroll // for (int m = 0; m < size<1>(S); ++m) { // if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { - // copy(thr_copy, S(_, m, _), D(_, m, _)); + // copy(tiled_copy, S(_, m, _), D(_, m, _)); // } else if (Clear_OOB_MN) { // clear(D(_, m, _)); // } @@ -362,7 +332,7 @@ inline __device__ void copy(TiledCopy thr_copy, Tensor const & // #pragma unroll // for (int m = 0; m < size<1>(S); ++m) { // if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { - // copy(thr_copy, S(_, m, k), D(_, m, k)); + // copy(tiled_copy, S(_, m, k), D(_, m, k)); // } else if (Clear_OOB_MN) { // clear(D(_, m, k)); // } diff --git a/candle-flash-attn/src/ffi.rs b/candle-flash-attn/src/ffi.rs index 90f34e434f..ca65520be5 100644 --- a/candle-flash-attn/src/ffi.rs +++ b/candle-flash-attn/src/ffi.rs @@ -7,6 +7,8 @@ extern "C" { v_ptr: *const c_void, o_ptr: *const c_void, softmax_lse_ptr: *const c_void, + alibi_slopes_ptr: *const c_void, + cu_seqlens_q_ptr: *const i32, cu_seqlens_k_ptr: *const i32, @@ -14,6 +16,7 @@ extern "C" { k_batch_stride: u32, v_batch_stride: u32, o_batch_stride: u32, + alibi_slopes_batch_stride: u32, q_row_stride: u32, k_row_stride: u32, @@ -37,8 +40,11 @@ extern "C" { seqlen_q_rounded: u32, seqlen_k_rounded: u32, - is_causal: c_int, is_bf16: c_int, + is_causal: c_int, + + window_size_left: c_int, + window_size_right: c_int, ); } diff --git a/candle-flash-attn/src/lib.rs b/candle-flash-attn/src/lib.rs index 3395bd0d66..21a06b5ecf 100644 --- a/candle-flash-attn/src/lib.rs +++ b/candle-flash-attn/src/lib.rs @@ -3,12 +3,14 @@ mod ffi; use candle::backend::BackendStorage; use candle::cuda_backend::cudarc::driver::DevicePtr; use candle::cuda_backend::WrapErr; -use candle::{CpuStorage, Layout, Result, Shape, Tensor}; +use candle::{CpuStorage, DType, Layout, Result, Shape, Tensor}; use half::{bf16, f16}; pub struct FlashAttn { pub softmax_scale: f32, - pub causal: bool, + pub alibi_slopes: Option, + pub window_size_left: Option, + pub window_size_right: Option, } fn round_multiple(x: usize, m: usize) -> usize { @@ -85,6 +87,51 @@ impl FlashAttn { candle::bail!("number of k/v heads {num_heads_k} must divide number of heads in query {num_heads}") } + let alibi_slopes_ptr = if let Some(alibi_slopes) = &self.alibi_slopes { + if alibi_slopes.dtype() != DType::F32 { + candle::bail!( + "DType mismatch alibi_slopes {:?}, expected {:?}", + alibi_slopes.dtype(), + DType::F32 + ); + } + + let (alibi_slopes, alibi_slopes_layout) = alibi_slopes.storage_and_layout(); + + if num_heads != alibi_slopes_layout.shape().dims1()? 
{ + candle::bail!( + "shape mismatch alibi_slopes {:?}, expected {:?}", + alibi_slopes_layout.shape(), + (num_heads) + ); + } + + let alibi_slopes = match &*alibi_slopes { + candle::Storage::Cuda(c) => c.as_cuda_slice::()?, + _ => candle::bail!("alibi_slopes must be a cuda tensor"), + }; + + let alibi_slopes = alibi_slopes.slice(alibi_slopes_layout.start_offset()..); + + *alibi_slopes.device_ptr() as *const core::ffi::c_void + } else { + std::ptr::null() + }; + + // if window_size_left > self.max_seqlen_k or None => -1 + let mut window_size_left = self + .window_size_left + .filter(|v| v <= &seqlen_k) + .map(|v| v as i32) + .unwrap_or(-1); + + // if window_size_right > self.max_seqlen_k or None => -1 + let mut window_size_right = self + .window_size_right + .filter(|v| v <= &seqlen_k) + .map(|v| v as i32) + .unwrap_or(-1); + let head_size = round_multiple(head_size_og, 8); let head_size_rounded = round_multiple(head_size, 32); let seqlen_q_rounded = round_multiple(seqlen_q, 128); @@ -94,9 +141,22 @@ impl FlashAttn { let dst = unsafe { dev.alloc::(elem_count) }.w()?; let softmax_lse = dev.alloc_zeros::(b_sz * num_heads * seqlen_q).w()?; - let causal = if self.causal { 1 } else { 0 }; let is_bf16 = if is_bf16 { 1 } else { 0 }; + // Causal is the special case where window_size_right == 0 and window_size_left < 0. + // Local is the more general case where window_size_right >= 0 or window_size_left >= 0. + let is_causal = if window_size_left < 0 && window_size_right == 0 { + 1 + } else { + 0 + }; + if window_size_left < 0 && window_size_right >= 0 { + window_size_left = seqlen_k as i32; + } + if window_size_left >= 0 && window_size_right < 0 { + window_size_right = seqlen_k as i32; + } + unsafe { let q_ptr = *q.device_ptr() as *const core::ffi::c_void; let k_ptr = *k.device_ptr() as *const core::ffi::c_void; @@ -109,12 +169,14 @@ impl FlashAttn { v_ptr, dst_ptr, softmax_lse_ptr, + /* alibi_slopes_ptr */ alibi_slopes_ptr, /* cu_seqlens_q_ptr */ std::ptr::null(), /* cu_seqlens_k_ptr */ std::ptr::null(), /* q_batch_stride */ q_stride[0] as u32, /* k_batch_stride */ k_stride[0] as u32, /* v_batch_stride */ v_stride[0] as u32, /* o_batch_stride */ o_stride[0] as u32, + /* alibi_slopes_batch_stride */ 0, /* q_row_stride */ q_stride[q_rank - 3] as u32, /* k_row_stride */ k_stride[k_rank - 3] as u32, /* v_row_stride */ v_stride[v_rank - 3] as u32, @@ -133,8 +195,10 @@ impl FlashAttn { /* seqlen_k */ seqlen_k as u32, /* seqlen_q_rounded */ seqlen_q_rounded as u32, /* seqlen_k_rounded */ seqlen_k_rounded as u32, - /* is_causal */ causal, /* is_bf16 */ is_bf16, + /* is_causal */ is_causal, + /* window_size_left */ window_size_left, + /* window_size_right */ window_size_right, ) } @@ -197,20 +261,137 @@ pub fn flash_attn( softmax_scale: f32, causal: bool, ) -> Result { + let window_size_left = None; + let window_size_right = if causal { Some(0) } else { None }; + let op = FlashAttn { softmax_scale, - causal, + alibi_slopes: None, + window_size_left, + window_size_right, }; q.apply_op3(k, v, op) } -struct FlashAttnVarLen { +/// Flash-attention v2 layer. +/// +/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`. +/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads +/// than q, the number of heads in k and v has to be divisible by the number of heads in q. +/// +/// # Arguments +/// +/// * `q` - Query tensor with shape `(batch, seq_len_q, num_heads_q, head_size)`. 
+/// * `k` - Key tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`. +/// * `v` - Value tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`. +/// * `window_size_left` - Limit left attention to value tokens. +/// * `window_size_right` - Limit right attention to value tokens. +/// +/// # Causal mask +/// +/// `window_size_left=None` with `window_size_right=Some(0)` applies a causal mask to the result +/// of `Q @ K^T` +/// +/// The resulting tensor has dimensions `(batch, seq_len_q, num_heads_q, head_size)`. +pub fn flash_attn_windowed( + q: &Tensor, + k: &Tensor, + v: &Tensor, + softmax_scale: f32, + window_size_left: Option, + window_size_right: Option, +) -> Result { + let op = FlashAttn { + softmax_scale, + alibi_slopes: None, + window_size_left, + window_size_right, + }; + q.apply_op3(k, v, op) +} + +/// Flash-attention v2 layer. +/// +/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`. +/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads +/// than q, the number of heads in k and v has to be divisible by the number of heads in q. +/// +/// # Arguments +/// +/// * `q` - Query tensor with shape `(batch, seq_len_q, num_heads_q, head_size)`. +/// * `k` - Key tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`. +/// * `v` - Value tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`. +/// * `alibi_slopes` - Alibi slopes tensor with shape `(num_heads_q)`. +/// +/// The resulting tensor has dimensions `(batch, seq_len_q, num_heads_q, head_size)`. +pub fn flash_attn_alibi( + q: &Tensor, + k: &Tensor, + v: &Tensor, + alibi_slopes: &Tensor, softmax_scale: f32, causal: bool, - max_seqlen_q: usize, - max_seqlen_k: usize, - seqlens_q: Tensor, - seqlens_k: Tensor, +) -> Result { + let window_size_left = None; + let window_size_right = if causal { Some(0) } else { None }; + + let op = FlashAttn { + softmax_scale, + alibi_slopes: Some(alibi_slopes.clone()), + window_size_left, + window_size_right, + }; + q.apply_op3(k, v, op) +} + +/// Flash-attention v2 layer. +/// +/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`. +/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads +/// than q, the number of heads in k and v has to be divisible by the number of heads in q. +/// +/// # Arguments +/// +/// * `q` - Query tensor with shape `(batch, seq_len_q, num_heads_q, head_size)`. +/// * `k` - Key tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`. +/// * `v` - Value tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`. +/// * `alibi_slopes` - Alibi slopes tensor with shape `(num_heads_q)`. +/// * `window_size_left` - Limit left attention to value tokens. +/// * `window_size_right` - Limit right attention to value tokens. +/// +/// # Causal mask +/// +/// `window_size_left=None` with `window_size_right=Some(0)` applies a causal mask to the result +/// of `Q @ K^T` +/// +/// The resulting tensor has dimensions `(batch, seq_len_q, num_heads_q, head_size)`. 
+pub fn flash_attn_alibi_windowed( + q: &Tensor, + k: &Tensor, + v: &Tensor, + alibi_slopes: &Tensor, + softmax_scale: f32, + window_size_left: Option, + window_size_right: Option, +) -> Result { + let op = FlashAttn { + softmax_scale, + alibi_slopes: Some(alibi_slopes.clone()), + window_size_left, + window_size_right, + }; + q.apply_op3(k, v, op) +} + +struct FlashAttnVarLen { + pub softmax_scale: f32, + pub max_seqlen_q: usize, + pub max_seqlen_k: usize, + pub seqlens_q: Tensor, + pub seqlens_k: Tensor, + pub alibi_slopes: Option, + pub window_size_left: Option, + pub window_size_right: Option, } impl FlashAttnVarLen { @@ -311,7 +492,54 @@ impl FlashAttnVarLen { if nseqlens_k != nseqlens_q { candle::bail!("seqlens_q and seqlens_k should have the same number of elements {nseqlens_q} <> {nseqlens_k}") } + let batch_size = nseqlens_q - 1; + + let alibi_slopes_ptr = if let Some(alibi_slopes) = &self.alibi_slopes { + if alibi_slopes.dtype() != DType::F32 { + candle::bail!( + "DType mismatch alibi_slopes {:?}, expected {:?}", + alibi_slopes.dtype(), + DType::F32 + ); + } + + let (alibi_slopes, alibi_slopes_layout) = alibi_slopes.storage_and_layout(); + + if num_heads != alibi_slopes_layout.shape().dims1()? { + candle::bail!( + "shape mismatch alibi_slopes {:?}, expected {:?}", + alibi_slopes_layout.shape(), + (num_heads) + ); + } + + let alibi_slopes = match &*alibi_slopes { + candle::Storage::Cuda(c) => c.as_cuda_slice::()?, + _ => candle::bail!("alibi_slopes must be a cuda tensor"), + }; + + let alibi_slopes = alibi_slopes.slice(alibi_slopes_layout.start_offset()..); + + *alibi_slopes.device_ptr() as *const core::ffi::c_void + } else { + std::ptr::null() + }; + + // if window_size_left > self.max_seqlen_k or None => -1 + let mut window_size_left = self + .window_size_left + .filter(|v| v <= &self.max_seqlen_k) + .map(|v| v as i32) + .unwrap_or(-1); + + // if window_size_right > self.max_seqlen_k or None => -1 + let mut window_size_right = self + .window_size_right + .filter(|v| v <= &self.max_seqlen_k) + .map(|v| v as i32) + .unwrap_or(-1); + let head_size = round_multiple(head_size_og, 8); let head_size_rounded = round_multiple(head_size, 32); let seqlen_q_rounded = round_multiple(self.max_seqlen_q, 128); @@ -323,9 +551,22 @@ impl FlashAttnVarLen { .alloc_zeros::(batch_size * num_heads * self.max_seqlen_q) .w()?; - let causal = if self.causal { 1 } else { 0 }; let is_bf16 = if is_bf16 { 1 } else { 0 }; + // Causal is the special case where window_size_right == 0 and window_size_left < 0. + // Local is the more general case where window_size_right >= 0 or window_size_left >= 0. 
+ let is_causal = if window_size_left < 0 && window_size_right == 0 { + 1 + } else { + 0 + }; + if window_size_left < 0 && window_size_right >= 0 { + window_size_left = self.max_seqlen_k as i32; + } + if window_size_left >= 0 && window_size_right < 0 { + window_size_right = self.max_seqlen_k as i32; + } + unsafe { let q_ptr = *q.device_ptr() as *const core::ffi::c_void; let k_ptr = *k.device_ptr() as *const core::ffi::c_void; @@ -340,12 +581,14 @@ impl FlashAttnVarLen { v_ptr, dst_ptr, softmax_lse_ptr, + /* alibi_slopes_ptr */ alibi_slopes_ptr, /* cu_seqlens_q_ptr */ seqlens_q_ptr, /* cu_seqlens_k_ptr */ seqlens_k_ptr, /* q_batch_stride */ 0, /* k_batch_stride */ 0, /* v_batch_stride */ 0, /* o_batch_stride */ 0, + /* alibi_slopes_batch_stride */ 0, /* q_row_stride */ q_stride[q_rank - 3] as u32, /* k_row_stride */ k_stride[k_rank - 3] as u32, /* v_row_stride */ v_stride[v_rank - 3] as u32, @@ -364,8 +607,10 @@ impl FlashAttnVarLen { /* seqlen_k */ self.max_seqlen_k as u32, /* seqlen_q_rounded */ seqlen_q_rounded as u32, /* seqlen_k_rounded */ seqlen_k_rounded as u32, - /* is_causal */ causal, /* is_bf16 */ is_bf16, + /* is_causal */ is_causal, + /* window_size_left */ window_size_left, + /* window_size_right */ window_size_right, ) } @@ -439,14 +684,177 @@ pub fn flash_attn_varlen( max_seqlen_k: usize, softmax_scale: f32, causal: bool, +) -> Result { + let window_size_left = None; + let window_size_right = if causal { Some(0) } else { None }; + + let op = FlashAttnVarLen { + softmax_scale, + max_seqlen_q, + max_seqlen_k, + seqlens_q: seqlens_q.clone(), + seqlens_k: seqlens_k.clone(), + alibi_slopes: None, + window_size_left, + window_size_right, + }; + q.apply_op3(k, v, op) +} + +#[allow(clippy::too_many_arguments)] +/// Flash-attention v2 layer with variable-length batching. +/// +/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`. +/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads +/// than q, the number of heads in k and v has to be divisible by the number of heads in q. +/// +/// # Arguments +/// +/// * `q` - Query tensor with shape `(total_q, num_heads_q, head_size)`. +/// * `k` - Key tensor with shape `(total_kv, num_heads_kv, head_size)`. +/// * `v` - Value tensor with shape `(total_kv, num_heads_kv, head_size)`. +/// * `seqlens_q` - The cumulative lengths of the sequences in the batch, used to index in q. +/// * `seqlens_k` - The cumulative lengths of the sequences in the batch, used to index in k and v. +/// * `max_seqlen_q` - The maximum query sequence length for q in the batch. +/// * `max_seqlen_k` - The maximum query sequence length for k and v in the batch. +/// * `window_size_left` - Limit left attention to value tokens. +/// * `window_size_right` - Limit right attention to value tokens. +/// +/// `seqlens_q` and `seqlens_k` contain `batch_size + 1` elements, typically `0`, `seqlen_1`, +/// `seqlen_1 + seqlen_2`, etc. +/// +/// The resulting tensor has dimensions `(total_q, num_heads_q, head_size)`. 
+/// +/// # Causal mask +/// +/// `window_size_left=None` with `window_size_right=Some(0)` applies a causal mask to the result +/// of `Q @ K^T` +pub fn flash_attn_varlen_windowed( + q: &Tensor, + k: &Tensor, + v: &Tensor, + seqlens_q: &Tensor, + seqlens_k: &Tensor, + max_seqlen_q: usize, + max_seqlen_k: usize, + softmax_scale: f32, + window_size_left: Option, + window_size_right: Option, +) -> Result { + let op = FlashAttnVarLen { + softmax_scale, + max_seqlen_q, + max_seqlen_k, + seqlens_q: seqlens_q.clone(), + seqlens_k: seqlens_k.clone(), + alibi_slopes: None, + window_size_left, + window_size_right, + }; + q.apply_op3(k, v, op) +} + +#[allow(clippy::too_many_arguments)] +/// Flash-attention v2 layer with variable-length batching. +/// +/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`. +/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads +/// than q, the number of heads in k and v has to be divisible by the number of heads in q. +/// +/// # Arguments +/// +/// * `q` - Query tensor with shape `(total_q, num_heads_q, head_size)`. +/// * `k` - Key tensor with shape `(total_kv, num_heads_kv, head_size)`. +/// * `v` - Value tensor with shape `(total_kv, num_heads_kv, head_size)`. +/// * `alibi_slopes` - Alibi slopes tensor with shape `(num_heads_q)`. +/// * `seqlens_q` - The cumulative lengths of the sequences in the batch, used to index in q. +/// * `seqlens_k` - The cumulative lengths of the sequences in the batch, used to index in k and v. +/// * `max_seqlen_q` - The maximum query sequence length for q in the batch. +/// * `max_seqlen_k` - The maximum query sequence length for k and v in the batch. +/// +/// `seqlens_q` and `seqlens_k` contain `batch_size + 1` elements, typically `0`, `seqlen_1`, +/// `seqlen_1 + seqlen_2`, etc. +/// +/// The resulting tensor has dimensions `(total_q, num_heads_q, head_size)`. +pub fn flash_attn_varlen_alibi( + q: &Tensor, + k: &Tensor, + v: &Tensor, + alibi_slopes: &Tensor, + seqlens_q: &Tensor, + seqlens_k: &Tensor, + max_seqlen_q: usize, + max_seqlen_k: usize, + softmax_scale: f32, + causal: bool, +) -> Result { + let window_size_left = None; + let window_size_right = if causal { Some(0) } else { None }; + + let op = FlashAttnVarLen { + softmax_scale, + max_seqlen_q, + max_seqlen_k, + seqlens_q: seqlens_q.clone(), + seqlens_k: seqlens_k.clone(), + alibi_slopes: Some(alibi_slopes.clone()), + window_size_left, + window_size_right, + }; + q.apply_op3(k, v, op) +} + +#[allow(clippy::too_many_arguments)] +/// Flash-attention v2 layer with variable-length batching. +/// +/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`. +/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads +/// than q, the number of heads in k and v has to be divisible by the number of heads in q. +/// +/// # Arguments +/// +/// * `q` - Query tensor with shape `(total_q, num_heads_q, head_size)`. +/// * `k` - Key tensor with shape `(total_kv, num_heads_kv, head_size)`. +/// * `v` - Value tensor with shape `(total_kv, num_heads_kv, head_size)`. +/// * `alibi_slopes` - Alibi slopes tensor with shape `(num_heads_q)`. +/// * `seqlens_q` - The cumulative lengths of the sequences in the batch, used to index in q. +/// * `seqlens_k` - The cumulative lengths of the sequences in the batch, used to index in k and v. +/// * `max_seqlen_q` - The maximum query sequence length for q in the batch. 
+/// * `max_seqlen_k` - The maximum query sequence length for k and v in the batch. +/// * `window_size_left` - Limit left attention to value tokens. +/// * `window_size_right` - Limit right attention to value tokens. +/// +/// `seqlens_q` and `seqlens_k` contain `batch_size + 1` elements, typically `0`, `seqlen_1`, +/// `seqlen_1 + seqlen_2`, etc. +/// +/// The resulting tensor has dimensions `(total_q, num_heads_q, head_size)`. +/// +/// # Causal mask +/// +/// `window_size_left=None` with `window_size_right=Some(0)` applies a causal mask to the result +/// of `Q @ K^T` +pub fn flash_attn_varlen_alibi_windowed( + q: &Tensor, + k: &Tensor, + v: &Tensor, + alibi_slopes: &Tensor, + seqlens_q: &Tensor, + seqlens_k: &Tensor, + max_seqlen_q: usize, + max_seqlen_k: usize, + softmax_scale: f32, + window_size_left: Option, + window_size_right: Option, ) -> Result { let op = FlashAttnVarLen { softmax_scale, - causal, max_seqlen_q, max_seqlen_k, seqlens_q: seqlens_q.clone(), seqlens_k: seqlens_k.clone(), + alibi_slopes: Some(alibi_slopes.clone()), + window_size_left, + window_size_right, }; q.apply_op3(k, v, op) }
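On the host side, `FlashAttn` and `FlashAttnVarLen` now derive `is_causal` and the final window sizes from the two `Option<usize>` fields instead of a `causal: bool`. A condensed sketch of that normalization, factored into a standalone helper that does not exist in the crate (the rules are the ones visible in the diff above):

```rust
/// `None`, or a window larger than the key sequence length, is encoded as -1.
/// Causal attention is the special case `left < 0 && right == 0`; a one-sided
/// window then gets the other side widened to `seqlen_k` before the kernel launch.
fn normalize_window(
    window_size_left: Option<usize>,
    window_size_right: Option<usize>,
    seqlen_k: usize,
) -> (i32, i32, bool) {
    let mut left = window_size_left
        .filter(|v| *v <= seqlen_k)
        .map(|v| v as i32)
        .unwrap_or(-1);
    let mut right = window_size_right
        .filter(|v| *v <= seqlen_k)
        .map(|v| v as i32)
        .unwrap_or(-1);
    let is_causal = left < 0 && right == 0;
    if left < 0 && right >= 0 {
        left = seqlen_k as i32;
    }
    if left >= 0 && right < 0 {
        right = seqlen_k as i32;
    }
    (left, right, is_causal)
}

fn main() {
    // `flash_attn(.., causal = true)` becomes (None, Some(0)).
    assert_eq!(normalize_window(None, Some(0), 1024), (1024, 0, true));
    // No windows at all: fully bidirectional attention.
    assert_eq!(normalize_window(None, None, 1024), (-1, -1, false));
    // Sliding window of 256 keys to the left, nothing to the right.
    assert_eq!(normalize_window(Some(256), Some(0), 1024), (256, 0, false));
}
```

Note that `is_causal` is decided before the one-sided expansion, so `(None, Some(0))` is still reported as causal even though the left window passed to the kernel becomes `seqlen_k`.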
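For callers, the new public entry points (`flash_attn_windowed`, `flash_attn_alibi`, `flash_attn_alibi_windowed`, and their `_varlen_` counterparts) only add window and ALiBi arguments on top of the existing `flash_attn` signature. A hypothetical usage sketch, assuming a CUDA build and the `candle` tensor constructors (`Device::new_cuda`, `Tensor::randn`, `Tensor::new`, `to_dtype`), none of which appear in this PR:

```rust
use candle::{DType, Device, Result, Tensor};
use candle_flash_attn::flash_attn_alibi_windowed;

fn main() -> Result<()> {
    let dev = Device::new_cuda(0)?;
    let (b, seq_len, num_heads, head_size) = (1, 512, 8, 64);
    // q/k/v have shape (batch, seq_len, num_heads, head_size) and must be f16 or bf16.
    let q = Tensor::randn(0f32, 1f32, (b, seq_len, num_heads, head_size), &dev)?
        .to_dtype(DType::F16)?;
    let k = q.clone();
    let v = q.clone();
    // One f32 slope per query head, as enforced by the alibi_slopes checks above.
    let slopes: Vec<f32> = (0..num_heads).map(|h| 2f32.powi(-(h as i32 + 1))).collect();
    let alibi_slopes = Tensor::new(slopes.as_slice(), &dev)?;
    let softmax_scale = 1.0 / (head_size as f32).sqrt();
    // Each query attends to itself, at most 256 keys to its left and none to its right.
    let out = flash_attn_alibi_windowed(&q, &k, &v, &alibi_slopes, softmax_scale, Some(256), Some(0))?;
    assert_eq!(out.dims(), &[b, seq_len, num_heads, head_size]);
    Ok(())
}
```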