
Commit e6a37dd

Merge branch 'Dao-AILab:main' into main
2 parents 6c787d5 + 23b77c8 commit e6a37dd


5 files changed: +15 −13 lines


benchmarks/benchmark_causal.py

+1 −1

@@ -7,7 +7,7 @@
 from einops import rearrange, repeat

 # from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler
-from src.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler
+from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler
 from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
 # # from flash_attn.triton.fused_attention import attention as attention
 # from flash_attn.flash_attn_triton import flash_attn_qkvpacked_func

csrc/flash_attn/src/flash_bwd_launch_template.h

+6 −6

@@ -60,15 +60,15 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream,
     const bool is_even_K = params.d == Kernel_traits::kHeadDim;
     constexpr int smem_size_dq_dk_dv = Kernel_traits::kSmemSize1colblock;
     // printf("smem_size_dq_dk_dv = %d\n", smem_size_dq_dk_dv);
-    BOOL_SWITCH(params.is_causal, IsCausalConst, [&] {
+    BOOL_SWITCH(params.is_causal, Is_causal, [&] {
         BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
             BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
-                BOOL_SWITCH(params.window_size_left >= 0 || params.window_size_right >= 0, Is_local, [&] {
+                BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] {
                     // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
                     // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
                     // If Is_local, set Is_causal to false
-                    auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst && !Is_local, Is_local, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst>;
-                    // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMNConst, true>;
+                    auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst>;
+                    // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenMNConst, true>;
                     if (smem_size_dq_dk_dv >= 48 * 1024) {
                         C10_CUDA_CHECK(cudaFuncSetAttribute(
                             kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));

@@ -104,11 +104,11 @@ void run_flash_bwd_seqq_parallel(Flash_bwd_params &params, cudaStream_t stream,
     const bool is_even_K = params.d == Kernel_traits::kHeadDim;
     constexpr int smem_size_dq_dk_dv = Kernel_traits::kSmemSize1rowblock;
     // printf("smem_size_dq_dk_dv = %d\n", smem_size_dq_dk_dv);
-    BOOL_SWITCH(params.is_causal, IsCausalConst, [&] {
+    BOOL_SWITCH(params.is_causal, Is_causal, [&] {
         BOOL_SWITCH(is_even_N, IsEvenNConst, [&] {
             BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
                 // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
-                auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenNConst && IsEvenKConst, IsEvenKConst>;
+                auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenNConst && IsEvenKConst, IsEvenKConst>;
                 // auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, false, false, IsEvenNConst, IsEvenKConst>;
                 if (smem_size_dq_dk_dv >= 48 * 1024) {
                     C10_CUDA_CHECK(cudaFuncSetAttribute(
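Note on the hunks above: both backward launchers dispatch through BOOL_SWITCH, which turns a runtime boolean into a compile-time constant that can then be used as a kernel template argument; this commit renames that constant from IsCausalConst to Is_causal and folds !params.is_causal into the Is_local switch. A minimal, self-contained sketch of the dispatch pattern, assumed to mirror the repository's static_switch.h (the launch<>() kernel stand-in and run() wrapper below are illustrative, not library code):

// Sketch of the BOOL_SWITCH dispatch pattern (assumed to follow static_switch.h).
#include <cstdio>

// Turns a runtime bool into a compile-time constant named CONST_NAME, then
// invokes the lambda so CONST_NAME can appear in template argument lists.
#define BOOL_SWITCH(COND, CONST_NAME, ...)       \
  [&] {                                          \
    if (COND) {                                  \
      constexpr static bool CONST_NAME = true;   \
      return __VA_ARGS__();                      \
    } else {                                     \
      constexpr static bool CONST_NAME = false;  \
      return __VA_ARGS__();                      \
    }                                            \
  }()

// Stand-in for a templated kernel launcher; not part of the library.
template <bool Is_dropout, bool Is_causal>
void launch() {
    std::printf("instantiated launch<Is_dropout=%d, Is_causal=%d>\n",
                int(Is_dropout), int(Is_causal));
}

void run(bool p_dropout, bool is_causal) {
    BOOL_SWITCH(p_dropout, Is_dropout, [&] {
        BOOL_SWITCH(is_causal, Is_causal, [&] {  // constant renamed from IsCausalConst in this commit
            launch<Is_dropout, Is_causal>();
        });
    });
}

int main() {
    run(false, true);   // -> launch<0, 1>
    run(true,  false);  // -> launch<1, 0>
}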

csrc/flash_attn/src/flash_fwd_launch_template.h

+6 −4

@@ -43,14 +43,16 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     const bool return_softmax = params.p_ptr != nullptr;
     BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
         BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
-            BOOL_SWITCH(params.window_size_left >= 0 || params.window_size_right >= 0, Is_local, [&] {
+            BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
                 BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
                     // Will only return softmax if dropout, to reduce compilation time.
                     // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
                     // If return_softmax, set IsEvenMNConst to false to reduce number of templates
                     // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
                     // If Is_local, set Is_causal to false
-                    auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal && !Is_local, Is_local, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
+                    auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
+                    // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
+                    // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
                     if (smem_size >= 48 * 1024) {
                         C10_CUDA_CHECK(cudaFuncSetAttribute(
                             kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));

@@ -79,13 +81,13 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     BOOL_SWITCH(params.is_causal, Is_causal, [&] {
         BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
             BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
-                BOOL_SWITCH(params.window_size_left >= 0 || params.window_size_right >= 0, Is_local, [&] {
+                BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
                     BOOL_SWITCH(params.num_splits > 1, Split, [&] {
                         BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
                             // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
                             // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
                             // If Is_local, set Is_causal to false
-                            auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal && !Is_local, Is_local, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Split, Append_KV>;
+                            auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, Is_local && !Is_causal, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Split, Append_KV>;
                             // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split, Append_KV>;
                             // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
                             if (smem_size >= 48 * 1024) {
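The forward and split-KV launchers get the same treatment: the causal flag is folded into the Is_local switch, so a call that sets both a causal mask and a sliding window is dispatched as causal, and the (causal, local) template combination is never instantiated. A small standalone compile-time check of the new guard's truth table (is_local() here is a hypothetical helper written for illustration, not library code):

// Truth table of the guard added in this commit, checked at compile time.
#include <cstdio>

// Mirrors "(window_size_left >= 0 || window_size_right >= 0) && !is_causal".
constexpr bool is_local(bool has_window, bool is_causal) {
    return has_window && !is_causal;
}

static_assert(!is_local(/*has_window=*/true,  /*is_causal=*/true),
              "causal + window dispatches as causal, not local");
static_assert( is_local(/*has_window=*/true,  /*is_causal=*/false),
              "non-causal sliding window dispatches as local");
static_assert(!is_local(/*has_window=*/false, /*is_causal=*/false),
              "dense attention is neither causal nor local");

int main() {
    std::puts("guard truth table verified at compile time");
}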

flash_attn/__init__.py

+1 −1

@@ -1,4 +1,4 @@
-__version__ = "2.3.4"
+__version__ = "2.3.5"

 from flash_attn.flash_attn_interface import (
     flash_attn_func,

training/Dockerfile

+1 −1

@@ -85,7 +85,7 @@ RUN pip install transformers==4.25.1 datasets==2.8.0 pytorch-lightning==1.8.6 tr
 RUN pip install git+https://github.com/mlcommons/logging.git@2.1.0

 # Install FlashAttention
-RUN pip install flash-attn==2.3.4
+RUN pip install flash-attn==2.3.5

 # Install CUDA extensions for fused dense, layer norm
 RUN git clone https://github.com/HazyResearch/flash-attention \
