
Commit 945ee7b

Merge remote-tracking branch 'Dao-AILab/main'
2 parents afef461 + 320fb59 commit 945ee7b

31 files changed: +372 -196 lines

.github/workflows/publish.yml

+22 -14
@@ -43,8 +43,8 @@ jobs:
       # Using ubuntu-20.04 instead of 22.04 for more compatibility (glibc). Ideally we'd use the
       # manylinux docker image, but I haven't figured out how to install CUDA on manylinux.
       os: [ubuntu-20.04]
-      python-version: ['3.7', '3.8', '3.9', '3.10', '3.11']
-      torch-version: ['1.12.1', '1.13.1', '2.0.1', '2.1.2', '2.2.0', '2.3.0.dev20240105']
+      python-version: ['3.7', '3.8', '3.9', '3.10', '3.11', '3.12']
+      torch-version: ['1.12.1', '1.13.1', '2.0.1', '2.1.2', '2.2.2', '2.3.0', '2.4.0.dev20240407']
       cuda-version: ['11.8.0', '12.2.2']
       # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not.
       # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI.
@@ -53,6 +53,15 @@ jobs:
       cxx11_abi: ['FALSE', 'TRUE']
       exclude:
         # see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix
+        # Pytorch < 2.2 does not support Python 3.12
+        - torch-version: '1.12.1'
+          python-version: '3.12'
+        - torch-version: '1.13.1'
+          python-version: '3.12'
+        - torch-version: '2.0.1'
+          python-version: '3.12'
+        - torch-version: '2.1.2'
+          python-version: '3.12'
         # Pytorch <= 1.12 does not support Python 3.11
         - torch-version: '1.12.1'
           python-version: '3.11'
@@ -61,9 +70,11 @@ jobs:
           python-version: '3.7'
         - torch-version: '2.1.2'
           python-version: '3.7'
-        - torch-version: '2.2.0'
+        - torch-version: '2.2.2'
+          python-version: '3.7'
+        - torch-version: '2.3.0'
           python-version: '3.7'
-        - torch-version: '2.3.0.dev20240105'
+        - torch-version: '2.4.0.dev20240407'
           python-version: '3.7'
         # Pytorch <= 2.0 only supports CUDA <= 11.8
         - torch-version: '1.12.1'
@@ -123,23 +134,19 @@ jobs:
           # If we don't install before installing Pytorch, we get error for torch 2.0.1
          # ERROR: Could not find a version that satisfies the requirement setuptools>=40.8.0 (from versions: none)
           pip install lit
+          # For some reason torch 2.2.0 on python 3.12 errors saying no setuptools
+          pip install setuptools
           # We want to figure out the CUDA version to download pytorch
           # e.g. we can have system CUDA version being 11.7 but if torch==1.12 then we need to download the wheel from cu116
           # see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix
           # This code is ugly, maybe there's a better way to do this.
           export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \
-            minv = {'1.12': 113, '1.13': 116, '2.0': 117, '2.1': 118, '2.2': 118, '2.3': 118}[env['MATRIX_TORCH_VERSION']]; \
-            maxv = {'1.12': 116, '1.13': 117, '2.0': 118, '2.1': 121, '2.2': 121, '2.3': 121}[env['MATRIX_TORCH_VERSION']]; \
+            minv = {'1.12': 113, '1.13': 116, '2.0': 117, '2.1': 118, '2.2': 118, '2.3': 118, '2.4': 118}[env['MATRIX_TORCH_VERSION']]; \
+            maxv = {'1.12': 116, '1.13': 117, '2.0': 118, '2.1': 121, '2.2': 121, '2.3': 121, '2.4': 121}[env['MATRIX_TORCH_VERSION']]; \
            print(max(min(int(env['MATRIX_CUDA_VERSION']), maxv), minv))" \
           )
           if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then
-            if [[ ${MATRIX_TORCH_VERSION} == "2.2" ]]; then
-              # --no-deps because we can't install old versions of pytorch-triton
-              pip install typing-extensions jinja2
-              pip install --no-cache-dir --no-deps --pre https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}/torch-${{ matrix.torch-version }}%2Bcu${TORCH_CUDA_VERSION}-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl
-            else
-              pip install --no-cache-dir --pre torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}
-            fi
+            pip install --no-cache-dir --pre torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}
           else
             pip install --no-cache-dir torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION}
           fi
@@ -161,7 +168,8 @@ jobs:
           export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH
           export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
           # Limit MAX_JOBS otherwise the github runner goes OOM
-          MAX_JOBS=2 FLASH_ATTENTION_FORCE_BUILD="TRUE" FLASH_ATTENTION_FORCE_CXX11_ABI=${{ matrix.cxx11_abi}} python setup.py bdist_wheel --dist-dir=dist
+          # CUDA 11.8 can compile with 2 jobs, but CUDA 12.2 goes OOM
+          MAX_JOBS=$([ "$MATRIX_CUDA_VERSION" == "122" ] && echo 1 || echo 2) FLASH_ATTENTION_FORCE_BUILD="TRUE" FLASH_ATTENTION_FORCE_CXX11_ABI=${{ matrix.cxx11_abi}} python setup.py bdist_wheel --dist-dir=dist
           tmpname=cu${MATRIX_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ matrix.cxx11_abi }}
           wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2")
           ls dist/*whl |xargs -I {} mv {} dist/${wheel_name}
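
Note: the TORCH_CUDA_VERSION step above clamps the system CUDA version into the range of CUDA builds that a given torch release publishes wheels for. The standalone C++ sketch below restates that rule; it is illustrative only (the function name, the table, and the main() driver are not part of the repository), with the min/max values mirroring the workflow's minv/maxv dictionaries.

// Standalone sketch of the wheel-selection rule (C++17 for std::clamp and structured bindings).
#include <algorithm>
#include <cstdio>
#include <map>
#include <string>
#include <utility>

int torch_wheel_cuda(const std::string& torch_minor, int system_cuda) {
    // Minimum/maximum CUDA wheel versions per torch minor release, mirroring the workflow tables
    // (e.g. torch 2.4 nightlies ship cu118..cu121 wheels).
    static const std::map<std::string, std::pair<int, int>> range = {
        {"1.12", {113, 116}}, {"1.13", {116, 117}}, {"2.0", {117, 118}},
        {"2.1",  {118, 121}}, {"2.2",  {118, 121}}, {"2.3",  {118, 121}},
        {"2.4",  {118, 121}},
    };
    const auto [minv, maxv] = range.at(torch_minor);
    // Clamp the detected system CUDA version into the supported range.
    return std::clamp(system_cuda, minv, maxv);
}

int main() {
    // System CUDA 12.2 with torch 2.2 -> install the cu121 torch wheel.
    std::printf("cu%d\n", torch_wheel_cuda("2.2", 122));  // prints cu121
}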

README.md

+4 -3
@@ -404,12 +404,13 @@ If you use this codebase, or otherwise found our work valuable, please cite:
 @inproceedings{dao2022flashattention,
   title={Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
   author={Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
-  booktitle={Advances in Neural Information Processing Systems},
+  booktitle={Advances in Neural Information Processing Systems (NeurIPS)},
   year={2022}
 }
-@article{dao2023flashattention2,
+@inproceedings{dao2023flashattention2,
   title={Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning},
   author={Dao, Tri},
-  year={2023}
+  booktitle={International Conference on Learning Representations (ICLR)},
+  year={2024}
 }
 ```

csrc/cutlass

Submodule cutlass updated 548 files

csrc/flash_attn/flash_api.cpp

+50 -12
@@ -46,7 +46,7 @@ void set_params_fprop(Flash_fwd_params &params,
                       bool seqlenq_ngroups_swapped=false) {

     // Reset the parameters
-    memset(&params, 0, sizeof(params));
+    params = {};

     params.is_bf16 = q.dtype() == torch::kBFloat16;

@@ -282,7 +282,8 @@ void set_params_splitkv(Flash_fwd_params &params, const int batch_size,
     params.num_splits = num_splits;
     if (p_dropout == 0.0f) {  // SplitKV is not implemented for dropout
         if (num_splits < 1) {
-            params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount, num_n_blocks, 128);
+            // We multiply number of SMs by 2 to hard-code the fact that we're using 128 threads per block.
+            params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount * 2, num_n_blocks, 128);
         }
         if (params.num_splits > 1) {
             at::Tensor softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
@@ -372,8 +373,8 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
     // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
     // H/t Daniel Haziza
     const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
+    const int ngroups = num_heads / num_heads_k;
     if (seqlenq_ngroups_swapped) {
-        const int ngroups = num_heads / num_heads_k;
         q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
         seqlen_q = ngroups;
         num_heads = num_heads_k;
@@ -400,7 +401,10 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
         TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
         CHECK_DEVICE(out);
         TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
-        CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);
+        CHECK_SHAPE(out, batch_size, sizes[1], sizes[2], head_size_og);
+        if (seqlenq_ngroups_swapped) {
+            out = out.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
+        }
         if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
     } else {
         out = torch::empty_like(q_padded);
@@ -494,12 +498,13 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size

 std::vector<at::Tensor>
 mha_varlen_fwd(at::Tensor &q,  // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
-               const at::Tensor &k,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
-               const at::Tensor &v,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
+               const at::Tensor &k,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
+               const at::Tensor &v,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
                c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
                const at::Tensor &cu_seqlens_q,  // b+1
                const at::Tensor &cu_seqlens_k,  // b+1
                c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
+               c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
                c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
                int max_seqlen_q,
                const int max_seqlen_k,
@@ -535,6 +540,15 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
     CHECK_DEVICE(cu_seqlens_q);
     CHECK_DEVICE(cu_seqlens_k);

+    at::Tensor block_table;
+    const bool paged_KV = block_table_.has_value();
+    if (paged_KV) {
+        block_table = block_table_.value();
+        CHECK_DEVICE(block_table);
+        TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
+        TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
+    }
+
     TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
     TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
     TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
@@ -546,8 +560,12 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
     const int batch_size = cu_seqlens_q.numel() - 1;
     int num_heads = sizes[1];
     const int head_size_og = sizes[2];
-    const int total_k = k.size(0);
-    const int num_heads_k = k.size(1);
+    const int num_heads_k = paged_KV ? k.size(2) : k.size(1);
+
+    const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
+    const int num_blocks = !paged_KV ? 0 : k.size(0);
+    const int page_block_size = !paged_KV ? 1 : k.size(1);
+    TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256");

     if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }  // causal=true is the same as causal=false in this case
     if (is_causal) { window_size_right = 0; }
@@ -557,8 +575,8 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
     // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
     // H/t Daniel Haziza
     const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
+    const int ngroups = num_heads / num_heads_k;
     if (seqlenq_ngroups_swapped) {
-        const int ngroups = num_heads / num_heads_k;
         q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size_og});
         max_seqlen_q = ngroups;
         num_heads = num_heads_k;
@@ -575,8 +593,16 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
     if (window_size_right >= max_seqlen_k) { window_size_right = -1; }

     CHECK_SHAPE(q, total_q, num_heads, head_size_og);
-    CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
-    CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
+    if (!paged_KV) {
+        const int total_k = k.size(0);
+        CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
+        CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
+    } else {
+        CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size_og);
+        CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_k, head_size_og);
+        CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
+    }
+
     CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
     CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
     if (seqused_k.has_value()){
@@ -605,6 +631,10 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
         CHECK_DEVICE(out);
         TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
         CHECK_SHAPE(out, total_q, num_heads, head_size_og);
+        CHECK_SHAPE(out, sizes[0], sizes[1], head_size_og);
+        if (seqlenq_ngroups_swapped) {
+            out = out.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size_og});
+        }
         if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
     } else {
         out = torch::empty_like(q_padded);
@@ -654,6 +684,14 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
                      window_size_left,
                      window_size_right,
                      seqlenq_ngroups_swapped);
+
+    if (paged_KV) {
+        params.block_table = block_table.data_ptr<int>();
+        params.block_table_batch_stride = block_table.stride(0);
+        params.k_batch_stride = k_padded.stride(0);
+        params.v_batch_stride = v_padded.stride(0);
+    }
+    params.page_block_size = page_block_size;
     if (seqlenq_ngroups_swapped) {
         // Only apply split-k for decoding
         set_params_splitkv(params, batch_size, num_heads,
@@ -682,7 +720,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s

     if (max_seqlen_k > 0) {
         auto stream = at::cuda::getCurrentCUDAStream().stream();
-        run_mha_fwd(params, stream);
+        run_mha_fwd(params, stream, paged_KV);
     } else {
         // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
         out.zero_();
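
Note on the paged KV cache support added above: when block_table_ is passed, K and V are laid out as num_blocks x page_block_size x num_heads_k x head_size, and block_table (batch_size x max_num_blocks_per_seq, int32) maps each sequence's logical pages to physical blocks. The sketch below is illustrative only, not the kernel code; the struct and function names are assumptions made for this example.

// How a (batch, token) position could map into a paged KV cache of the shape checked above.
#include <cstdint>
#include <cstdio>

struct PagedKvIndex {
    int physical_block;   // row in the num_blocks dimension of the K/V cache
    int offset_in_block;  // position inside that page (0 .. page_block_size - 1)
};

inline PagedKvIndex locate_kv(const int32_t* block_table,       // [batch_size, max_num_blocks_per_seq]
                              int64_t block_table_batch_stride, // elements between consecutive batch rows
                              int batch_idx,
                              int token_idx,                    // position within this sequence's keys
                              int page_block_size) {
    const int logical_block = token_idx / page_block_size;
    const int offset        = token_idx % page_block_size;
    const int physical      = block_table[batch_idx * block_table_batch_stride + logical_block];
    return {physical, offset};
}

int main() {
    // Two sequences, pages of 256 tokens; sequence 1's second page lives in physical block 7.
    const int32_t block_table[2][4] = {{0, 1, 2, 3}, {5, 7, 9, 11}};
    const PagedKvIndex idx = locate_kv(&block_table[0][0], 4, /*batch_idx=*/1, /*token_idx=*/300, 256);
    std::printf("block %d, offset %d\n", idx.physical_block, idx.offset_in_block);  // block 7, offset 44
}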

csrc/flash_attn/src/flash_bwd_kernel.h

+1 -1
@@ -4,7 +4,7 @@

 #pragma once

-#include <cute/algorithm/copy.hpp>
+#include <cute/tensor.hpp>

 #include <cutlass/cutlass.h>
 #include <cutlass/array.h>

csrc/flash_attn/src/flash_bwd_launch_template.h

+34 -11
@@ -11,6 +11,40 @@
 #include "flash_bwd_preprocess_kernel.h"
 #include "flash_bwd_kernel.h"

+// Determine if the architecture supports FLASH and define a macro to handle parameter modifiers
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+#define ARCH_SUPPORTS_FLASH
+#define KERNEL_PARAM_MODIFIER __grid_constant__
+#else
+#define KERNEL_PARAM_MODIFIER
+#endif
+
+// Define a macro for unsupported architecture handling to centralize the error message
+#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashAttention requires building with sm version sm80-sm90, but was built for < 8.0!");
+
+// Use a macro to clean up kernel definitions
+#define DEFINE_FLASH_BACKWARD_KERNEL(kernelName, ...) \
+template<typename Kernel_traits, __VA_ARGS__> \
+__global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_bwd_params params)
+
+DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_kernel, bool Is_dropout, bool Is_causal, bool Has_alibi, bool Is_even_M, bool Is_even_K) {
+    #if defined(ARCH_SUPPORTS_FLASH)
+        flash::compute_dq_dk_dv<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_M, Is_even_K>(params);
+    #else
+        FLASH_UNSUPPORTED_ARCH
+    #endif
+}
+
+DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K) {
+    #if defined(ARCH_SUPPORTS_FLASH)
+        static_assert(!(Is_causal && Is_local));  // If Is_local is true, Is_causal should be false
+        flash::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K>(params);
+    #else
+        FLASH_UNSUPPORTED_ARCH
+    #endif
+}
+
+
 template<bool Clear_dQaccum=true, typename Kernel_traits>
 __global__ void flash_bwd_dot_do_o_kernel(const Flash_bwd_params params) {
     flash::compute_dot_do_o<Clear_dQaccum, Kernel_traits>(params);
@@ -21,17 +55,6 @@ __global__ void flash_bwd_clear_dkvaccum_kernel(const Flash_bwd_params params) {
     flash::clear_dKVaccum<Kernel_traits>(params);
 }

-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Has_alibi, bool Is_even_M, bool Is_even_K>
-__global__ void flash_bwd_dq_dk_dv_loop_kernel(__grid_constant__ const Flash_bwd_params params) {
-    flash::compute_dq_dk_dv<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_M, Is_even_K>(params);
-}
-
-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K>
-__global__ void flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel(__grid_constant__ const Flash_bwd_params params) {
-    static_assert(!(Is_causal && Is_local));  // If Is_local is true, Is_causal should be false
-    flash::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K>(params);
-}
-
 template<typename Kernel_traits>
 __global__ void flash_bwd_convert_dq_kernel(const Flash_bwd_params params, const int nsplits) {
     flash::convert_dQ<Kernel_traits>(params, nsplits);
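
Note on the macros introduced above: DEFINE_FLASH_BACKWARD_KERNEL stamps out the templated kernel signatures, and KERNEL_PARAM_MODIFIER inserts __grid_constant__ only when compiling for sm80+. The self-contained C++ sketch below shows the same macro pattern on a plain host function, illustrative only; the real code defines __global__ kernels, and none of the demo names exist in the repository.

// Minimal host-side demo of the "macro supplies the template header and signature,
// the invocation supplies the body" pattern, with an architecture-dependent modifier.
#include <cstdio>

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#define DEMO_ARCH_SUPPORTED
#define DEMO_PARAM_MODIFIER __grid_constant__  // only meaningful for CUDA kernel parameters
#else
#define DEMO_PARAM_MODIFIER
#endif

struct DemoParams { int n; };

#define DEFINE_DEMO_FUNC(name, ...)        \
    template<typename Traits, __VA_ARGS__> \
    void name(DEMO_PARAM_MODIFIER const DemoParams params)

DEFINE_DEMO_FUNC(demo_func, bool Flag) {
#if defined(DEMO_ARCH_SUPPORTED)
    if (Flag) { std::printf("n = %d\n", params.n); }
#else
    // Fallback branch, taken in a plain host-only compile (no __CUDA_ARCH__).
    std::printf("compiled without sm80+ support, n = %d\n", params.n);
#endif
}

int main() {
    struct DummyTraits {};
    demo_func<DummyTraits, true>(DemoParams{42});  // expands into the templated definition above
}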

csrc/flash_attn/src/flash_bwd_preprocess_kernel.h

+1 -1
@@ -4,7 +4,7 @@

 #pragma once

-#include <cute/algorithm/copy.hpp>
+#include <cute/tensor.hpp>

 #include <cutlass/cutlass.h>
 #include <cutlass/array.h>
