chore: update flash attention kernels (#1518)
* chore: update flash attention kernels

* fmt

* remove unused kernels

* force f32

* correct stride
OlivierDehaene authored Jan 5, 2024
1 parent 3a7304c commit 8d1a57c
Showing 28 changed files with 1,086 additions and 465 deletions.
62 changes: 62 additions & 0 deletions candle-flash-attn/kernels/alibi.h
@@ -0,0 +1,62 @@
#include <cmath>

#include <cute/tensor.hpp>

#include <cutlass/cutlass.h>
#include <cutlass/array.h>

#include "utils.h"

namespace flash {

using namespace cute;

////////////////////////////////////////////////////////////////////////////////////////////////////

template <bool Is_causal, typename Engine, typename Layout>
inline __device__ void apply_alibi(Tensor<Engine, Layout> &tensor,
                                   const int col_idx_offset_,
                                   const int max_seqlen_k,
                                   const int row_idx_offset,
                                   const int max_seqlen_q,
                                   const int warp_row_stride,
                                   const float alibi_slope) {
    // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
    static_assert(Layout::rank == 2, "Only support 2D Tensor");
    const int lane_id = threadIdx.x % 32;
    const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
    if constexpr (Is_causal) {  // Simpler, we add the same bias vector to all rows
        #pragma unroll
        for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
            const int col_idx_base = col_idx_offset + nj * 8;
            #pragma unroll
            for (int j = 0; j < size<1, 0>(tensor); ++j) {
                const int col_idx = col_idx_base + j;
                #pragma unroll
                for (int mi = 0; mi < size<0>(tensor); ++mi) {
                    tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx;
                }
            }
        }
    } else {  // Bias depends on both row_idx and col_idx
        #pragma unroll
        for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
            const int row_idx_base = row_idx_offset + mi * warp_row_stride;
            #pragma unroll
            for (int i = 0; i < size<0, 0>(tensor); ++i) {
                const int row_idx = row_idx_base + i * 8;
                #pragma unroll
                for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
                    const int col_idx_base = col_idx_offset + nj * 8;
                    #pragma unroll
                    for (int j = 0; j < size<1, 0>(tensor); ++j) {
                        const int col_idx = col_idx_base + j;
                        tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);
                    }
                }
            }
        }
    }
}

} // namespace flash
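
For orientation (not part of the commit): the bias that `apply_alibi` applies maps onto a plain score matrix as in the host-side C++ sketch below. It assumes a dense `scores[row][col]` layout instead of the warp-tiled MMA layout the kernel iterates over; the causal branch adds the same per-column bias to every row (equivalent under softmax up to a per-row constant), while the non-causal branch penalizes the right-aligned distance between query row and key column.

#include <cmath>
#include <vector>

// Host-side reference for the ALiBi bias (illustrative sketch only).
void apply_alibi_reference(std::vector<std::vector<float>> &scores,
                           int seqlen_q, int seqlen_k,
                           float alibi_slope, bool is_causal) {
    for (int row = 0; row < seqlen_q; ++row) {
        for (int col = 0; col < seqlen_k; ++col) {
            if (is_causal) {
                // Same bias vector for every row, as in the Is_causal branch above.
                scores[row][col] += alibi_slope * col;
            } else {
                // Bias depends on the right-aligned distance between row and column.
                scores[row][col] -= alibi_slope *
                    std::abs(row + seqlen_k - seqlen_q - col);
            }
        }
    }
}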
13 changes: 9 additions & 4 deletions candle-flash-attn/kernels/block_info.h
@@ -14,9 +14,12 @@ struct BlockInfo {
template<typename Params>
__device__ BlockInfo(const Params &params, const int bidb)
: sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb])
, sum_s_k(!Varlen || params.cu_seqlens_k == nullptr ? -1 : params.cu_seqlens_k[bidb])
, sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb])
, actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q)
, actual_seqlen_k(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : params.cu_seqlens_k[bidb + 1] - sum_s_k)
// If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
// Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
, seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb]))
, actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
{
}

@@ -32,8 +35,10 @@ struct BlockInfo {

const int sum_s_q;
const int sum_s_k;
const uint32_t actual_seqlen_q;
const uint32_t actual_seqlen_k;
const int actual_seqlen_q;
// We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0.
const int seqlen_k_cache;
const int actual_seqlen_k;
};

////////////////////////////////////////////////////////////////////////////////////////////////////
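
A compact host-side restatement of the new length logic (illustrative, not part of the commit): `seqused_k`, when provided, gives the length directly; otherwise `cu_seqlens_k` is read either as cumulative offsets or as raw per-sequence lengths depending on `is_seqlens_k_cumulative`, and any appended `K_new` rows are counted on top.

// Illustrative sketch mirroring BlockInfo's derivation of actual_seqlen_k.
int actual_seqlen_k_sketch(const int *cu_seqlens_k, const int *seqused_k,
                           bool is_seqlens_k_cumulative, int seqlen_k_default,
                           int seqlen_knew, bool has_knew, int bidb) {
    if (seqused_k != nullptr) {
        return seqused_k[bidb];  // explicit per-sequence length wins
    }
    int seqlen_k_cache;
    if (cu_seqlens_k == nullptr) {
        seqlen_k_cache = seqlen_k_default;  // fixed-length batch
    } else if (is_seqlens_k_cumulative) {
        seqlen_k_cache = cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb];  // cumulative offsets
    } else {
        seqlen_k_cache = cu_seqlens_k[bidb];  // raw per-sequence lengths
    }
    return seqlen_k_cache + (has_knew ? seqlen_knew : 0);  // plus appended K_new rows
}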
54 changes: 42 additions & 12 deletions candle-flash-attn/kernels/flash.h
@@ -7,15 +7,6 @@
#include <cuda.h>
#include <vector>

// #ifdef OLD_GENERATOR_PATH
// #include <ATen/CUDAGeneratorImpl.h>
// #else
// #include <ATen/cuda/CUDAGeneratorImpl.h>
// #endif
//
// #include <ATen/cuda/CUDAGraphsUtils.cuh>


constexpr int TOTAL_DIM = 0;
constexpr int H_DIM = 1;
constexpr int D_DIM = 2;
@@ -53,6 +44,7 @@ struct Flash_fwd_params : public Qkv_params {

// The O matrix (output).
void * __restrict__ o_ptr;
void * __restrict__ oaccum_ptr;

// The stride between rows of O.
index_t o_batch_stride;
@@ -64,9 +56,10 @@ struct Flash_fwd_params : public Qkv_params {

// The pointer to the softmax sum.
void * __restrict__ softmax_lse_ptr;
void * __restrict__ softmax_lseaccum_ptr;

// The dimensions.
int b, seqlen_q, seqlen_k, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded;
int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim;

// The scaling factors for the kernel.
float scale_softmax;
@@ -76,8 +69,30 @@ struct Flash_fwd_params : public Qkv_params {
int * __restrict__ cu_seqlens_q;
int * __restrict__ cu_seqlens_k;

// If provided, the actual length of each k sequence.
int * __restrict__ seqused_k;

int *__restrict__ blockmask;

// The K_new and V_new matrices.
void * __restrict__ knew_ptr;
void * __restrict__ vnew_ptr;

// The strides between rows of the K_new and V_new matrices.
index_t knew_batch_stride;
index_t vnew_batch_stride;
index_t knew_row_stride;
index_t vnew_row_stride;
index_t knew_head_stride;
index_t vnew_head_stride;

// The cos and sin matrices for rotary embedding.
void * __restrict__ rotary_cos_ptr;
void * __restrict__ rotary_sin_ptr;

// The indices to index into the KV cache.
int *__restrict__ cache_batch_idx;

// The dropout probability (probability of keeping an activation).
float p_dropout;
// uint32_t p_dropout_in_uint;
@@ -88,11 +103,22 @@ struct Flash_fwd_params : public Qkv_params {
float rp_dropout;
float scale_softmax_rp_dropout;

// Random state.
// at::PhiloxCudaState philox_args;
// Local window size
int window_size_left, window_size_right;

bool is_bf16;
bool is_causal;

// If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
// Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
bool is_seqlens_k_cumulative;

bool is_rotary_interleaved;

int num_splits; // For split-KV version

void * __restrict__ alibi_slopes_ptr;
index_t alibi_slopes_batch_stride;
};

////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -132,10 +158,14 @@ struct Flash_bwd_params : public Flash_fwd_params {

// The pointer to the softmax d sum.
void *__restrict__ dsoftmax_sum;

bool deterministic;
index_t dq_accum_split_stride;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T, int Headdim> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
template<typename T, int Headdim> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);

template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream, const bool configure);
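
The new `window_size_left` / `window_size_right` fields describe local (sliding-window) attention. A minimal sketch of the intended mask semantics, assuming the right-aligned indexing flash-attention uses when seqlen_q != seqlen_k (illustrative, not the kernel's implementation):

// Can query row i attend to key column j under a local window? (illustrative)
// A negative window size means "unbounded" on that side; causal masking corresponds
// to an unbounded left window with window_size_right == 0.
bool in_local_window(int i, int j, int seqlen_q, int seqlen_k,
                     int window_size_left, int window_size_right) {
    const int diag = i + seqlen_k - seqlen_q;  // key column aligned with query row i
    const bool left_ok  = window_size_left  < 0 || j >= diag - window_size_left;
    const bool right_ok = window_size_right < 0 || j <= diag + window_size_right;
    return j >= 0 && j < seqlen_k && left_ok && right_ok;
}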
40 changes: 26 additions & 14 deletions candle-flash-attn/kernels/flash_api.cu
@@ -1,17 +1,15 @@
#include "flash_fwd_launch_template.h"

// void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
// FWD_HEADDIM_SWITCH(params.d, [&] {
// run_mha_fwd_<cutlass::half_t, kHeadDim>(params, stream);
// });
// }

void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
FP16_SWITCH(!params.is_bf16, [&] {
FWD_HEADDIM_SWITCH(params.d, [&] {
run_mha_fwd_<elem_type, kHeadDim>(params, stream);
});
});
void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) {
FP16_SWITCH(!params.is_bf16, [&] {
FWD_HEADDIM_SWITCH(params.d, [&] {
// if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0
run_mha_fwd_<elem_type, kHeadDim>(params, stream);
// } else {
// run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim>(params, stream);
// }
});
});
}

extern "C" void run_mha(
Expand All @@ -20,6 +18,7 @@ extern "C" void run_mha(
void *v_ptr,
void *o_ptr,
void *softmax_lse_ptr,
void *alibi_slopes_ptr,

int32_t *cu_seqlens_q_ptr,
int32_t *cu_seqlens_k_ptr,
Expand All @@ -28,6 +27,7 @@ extern "C" void run_mha(
uint32_t k_batch_stride,
uint32_t v_batch_stride,
uint32_t o_batch_stride,
uint32_t alibi_slopes_batch_stride,

uint32_t q_row_stride,
uint32_t k_row_stride,
Expand All @@ -51,8 +51,11 @@ extern "C" void run_mha(
uint32_t seqlen_q_rounded,
uint32_t seqlen_k_rounded,

int is_bf16,
int is_causal,
int is_bf16

int window_size_left,
int window_size_right
) {
Flash_fwd_params params;
// Reset the parameters
Expand All @@ -65,12 +68,14 @@ extern "C" void run_mha(
params.o_ptr = o_ptr;

params.softmax_lse_ptr = softmax_lse_ptr;
params.alibi_slopes_ptr = alibi_slopes_ptr;

// All strides are in elements, not bytes.
params.q_batch_stride = q_batch_stride;
params.k_batch_stride = k_batch_stride;
params.v_batch_stride = v_batch_stride;
params.o_batch_stride = o_batch_stride;
params.alibi_slopes_batch_stride = alibi_slopes_batch_stride;

params.q_row_stride = q_row_stride;
params.k_row_stride = k_row_stride;
Expand All @@ -92,7 +97,6 @@ extern "C" void run_mha(
params.seqlen_k_rounded = seqlen_k_rounded;
params.d = d;
params.d_rounded = d_rounded;
params.is_causal = is_causal;

// Set the different scale values.
params.scale_softmax = softmax_scale;
Expand All @@ -106,6 +110,14 @@ extern "C" void run_mha(
params.cu_seqlens_q = cu_seqlens_q_ptr;
params.cu_seqlens_k = cu_seqlens_k_ptr;
params.p_ptr = nullptr; // used for `return_softmax`.
params.seqused_k = nullptr;

params.is_causal = is_causal;
params.window_size_left = window_size_left;
params.window_size_right = window_size_right;

params.is_seqlens_k_cumulative = true;
params.num_splits = 1;

cudaStream_t stream = 0; // Use the default stream.
run_mha_fwd(params, stream);
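
`FP16_SWITCH` and `FWD_HEADDIM_SWITCH` in `run_mha_fwd` convert the runtime dtype flag and head dimension into compile-time template arguments by handing each case a generic lambda. A simplified sketch of that dispatch pattern (illustrative; the real macros live in the flash-attention headers and also enumerate the supported head dimensions):

// Simplified BOOL_SWITCH-style dispatch: a runtime condition selects which
// compile-time specialization of the lambda body gets instantiated and run.
#define BOOL_SWITCH_SKETCH(COND, CONST_NAME, ...)   \
    [&] {                                           \
        if (COND) {                                 \
            constexpr bool CONST_NAME = true;       \
            return __VA_ARGS__();                   \
        } else {                                    \
            constexpr bool CONST_NAME = false;      \
            return __VA_ARGS__();                   \
        }                                           \
    }()

// Hypothetical usage, in the spirit of run_mha_fwd above:
// BOOL_SWITCH_SKETCH(!params.is_bf16, kIsFp16, [&] {
//     using elem_type = std::conditional_t<kIsFp16, cutlass::half_t, cutlass::bfloat16_t>;
//     run_mha_fwd_<elem_type, kHeadDim>(params, stream);
// });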
13 changes: 2 additions & 11 deletions candle-flash-attn/kernels/flash_fwd_hdim128_bf16_sm80.cu
@@ -1,19 +1,10 @@
// Copyright (c) 2023, Tri Dao.

// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

// template<>
// void run_mha_fwd_<cutlass::bfloat16_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
// using elem_type = cutlass::bfloat16_t;
// if (params.p_dropout == 1.f) {
// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, false, false, elem_type>, false>(params, stream);
// } else {
// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, false, false, elem_type>, true>(params, stream);
// }
// }
template<>
void run_mha_fwd_<cutlass::bfloat16_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim128<cutlass::bfloat16_t>(params, stream);
}
}
26 changes: 2 additions & 24 deletions candle-flash-attn/kernels/flash_fwd_hdim128_fp16_sm80.cu
@@ -1,32 +1,10 @@
// Copyright (c) 2023, Tri Dao.

// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

// template<>
// void run_mha_fwd_<cutlass::half_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
// using elem_type = cutlass::half_t;
// if (params.p_dropout == 1.f) {
// // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k
// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, false, false, elem_type>, false>(params, stream);
// // run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, true, false, elem_type>, false>(params, stream);
// // run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, false, true, elem_type>, false>(params, stream);
// // run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, true, true, elem_type>, false>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, false, false, elem_type>, false>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<128, 64, 64, 4, false, false, elem_type>, false>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<128, 64, 128, 4, false, false, elem_type>, false>(params, stream);
// // 1st ones are good for H100, A100
// // 2nd one is good for A6000 bc we get slightly better occupancy
// } else {
// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, false, false, elem_type>, true>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, true, false, elem_type>, true>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, true, true, elem_type>, true>(params, stream);
// // 1st one is good for H100, A100, A6000
// }
// }

template<>
void run_mha_fwd_<cutlass::half_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim128<cutlass::half_t>(params, stream);
}
}
11 changes: 2 additions & 9 deletions candle-flash-attn/kernels/flash_fwd_hdim160_bf16_sm80.cu
@@ -1,17 +1,10 @@
// Copyright (c) 2023, Tri Dao.

// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

// template<>
// void run_mha_fwd_<cutlass::bfloat16_t, 160>(Flash_fwd_params &params, cudaStream_t stream) {
// using elem_type = cutlass::bfloat16_t;
// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 32, 4, false, false, elem_type>, Is_dropout>(params, stream);
// });
// }
template<>
void run_mha_fwd_<cutlass::bfloat16_t, 160>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim160<cutlass::bfloat16_t>(params, stream);
}
}
[Diffs for the remaining changed files are not shown.]
