
Commit b5f5399

Merge remote-tracking branch 'Dao-AILab/main'
2 parents: f24e91b + 6c9e60d


65 files changed (+4972, -3086 lines)

.github/workflows/publish.yml (+33, -30)
@@ -44,45 +44,34 @@ jobs:
           # manylinux docker image, but I haven't figured out how to install CUDA on manylinux.
           os: [ubuntu-20.04]
           python-version: ['3.7', '3.8', '3.9', '3.10', '3.11']
-          torch-version: ['1.12.1', '1.13.1', '2.0.1', '2.1.0']
-          cuda-version: ['11.6.2', '11.7.1', '11.8.0', '12.1.0', '12.2.0']
+          torch-version: ['1.12.1', '1.13.1', '2.0.1', '2.1.2', '2.2.0', '2.3.0.dev20240105']
+          cuda-version: ['11.8.0', '12.2.2']
           # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not.
           # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI.
           # Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs)
           # when building without C++11 ABI and using it on nvcr images.
           cxx11_abi: ['FALSE', 'TRUE']
           exclude:
+            # see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix
             # Pytorch <= 1.12 does not support Python 3.11
             - torch-version: '1.12.1'
               python-version: '3.11'
             # Pytorch >= 2.0 only supports Python >= 3.8
             - torch-version: '2.0.1'
               python-version: '3.7'
-            - torch-version: '2.1.0'
+            - torch-version: '2.1.2'
+              python-version: '3.7'
+            - torch-version: '2.2.0'
+              python-version: '3.7'
+            - torch-version: '2.3.0.dev20240105'
               python-version: '3.7'
             # Pytorch <= 2.0 only supports CUDA <= 11.8
             - torch-version: '1.12.1'
-              cuda-version: '12.1.0'
-            - torch-version: '1.12.1'
-              cuda-version: '12.2.0'
-            - torch-version: '1.13.1'
-              cuda-version: '12.1.0'
+              cuda-version: '12.2.2'
             - torch-version: '1.13.1'
-              cuda-version: '12.2.0'
+              cuda-version: '12.2.2'
             - torch-version: '2.0.1'
-              cuda-version: '12.1.0'
-            - torch-version: '2.0.1'
-              cuda-version: '12.2.0'
-            # Pytorch >= 2.1 only supports CUDA >= 11.8
-            - torch-version: '2.1.0'
-              cuda-version: '11.6.2'
-            - torch-version: '2.1.0'
-              cuda-version: '11.7.1'
-            # Pytorch >= 2.1 with nvcc 12.1.0 segfaults during compilation, so
-            # we only use CUDA 12.2. setup.py as a special case that will
-            # download the wheel for CUDA 12.2 instead.
-            - torch-version: '2.1.0'
-              cuda-version: '12.1.0'
+              cuda-version: '12.2.2'
 
     steps:
       - name: Checkout
@@ -97,6 +86,7 @@ jobs:
         run: |
           echo "MATRIX_CUDA_VERSION=$(echo ${{ matrix.cuda-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV
           echo "MATRIX_TORCH_VERSION=$(echo ${{ matrix.torch-version }} | awk -F \. {'print $1 "." $2'})" >> $GITHUB_ENV
+          echo "MATRIX_PYTHON_VERSION=$(echo ${{ matrix.python-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV
 
       - name: Free up disk space
         if: ${{ runner.os == 'Linux' }}
@@ -107,9 +97,15 @@ jobs:
           sudo rm -rf /opt/ghc
           sudo rm -rf /opt/hostedtoolcache/CodeQL
 
+      - name: Set up swap space
+        if: runner.os == 'Linux'
+        uses: pierotofy/set-swap-space@v1.0
+        with:
+          swap-size-gb: 10
+
       - name: Install CUDA ${{ matrix.cuda-version }}
         if: ${{ matrix.cuda-version != 'cpu' }}
-        uses: Jimver/cuda-toolkit@v0.2.11
+        uses: Jimver/cuda-toolkit@v0.2.14
         id: cuda-toolkit
         with:
           cuda: ${{ matrix.cuda-version }}
@@ -129,10 +125,21 @@ jobs:
           pip install lit
           # We want to figure out the CUDA version to download pytorch
           # e.g. we can have system CUDA version being 11.7 but if torch==1.12 then we need to download the wheel from cu116
+          # see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix
           # This code is ugly, maybe there's a better way to do this.
-          export TORCH_CUDA_VERSION=$(python -c "import os; minv = {'1.12': 113, '1.13': 116, '2.0': 117, '2.1': 118}[os.environ['MATRIX_TORCH_VERSION']]; maxv = {'1.12': 116, '1.13': 117, '2.0': 118, '2.1': 121}[os.environ['MATRIX_TORCH_VERSION']]; print(max(min(int(os.environ['MATRIX_CUDA_VERSION']), maxv), minv))")
+          export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \
+            minv = {'1.12': 113, '1.13': 116, '2.0': 117, '2.1': 118, '2.2': 118, '2.3': 118}[env['MATRIX_TORCH_VERSION']]; \
+            maxv = {'1.12': 116, '1.13': 117, '2.0': 118, '2.1': 121, '2.2': 121, '2.3': 121}[env['MATRIX_TORCH_VERSION']]; \
+            print(max(min(int(env['MATRIX_CUDA_VERSION']), maxv), minv))" \
+          )
           if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then
-            pip install --no-cache-dir --pre torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}
+            if [[ ${MATRIX_TORCH_VERSION} == "2.2" ]]; then
+              # --no-deps because we can't install old versions of pytorch-triton
+              pip install typing-extensions jinja2
+              pip install --no-cache-dir --no-deps --pre https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}/torch-${{ matrix.torch-version }}%2Bcu${TORCH_CUDA_VERSION}-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl
+            else
+              pip install --no-cache-dir --pre torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}
+            fi
           else
             pip install --no-cache-dir torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION}
           fi
@@ -153,12 +160,8 @@ jobs:
           pip install ninja packaging wheel
           export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH
           export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
-          # Currently for this setting the runner goes OOM if we pass --threads 4 to nvcc
-          if [[ ( ${MATRIX_CUDA_VERSION} == "121" || ${MATRIX_CUDA_VERSION} == "122" ) && ${MATRIX_TORCH_VERSION} == "2.1" ]]; then
-            export FLASH_ATTENTION_FORCE_SINGLE_THREAD="TRUE"
-          fi
           # Limit MAX_JOBS otherwise the github runner goes OOM
-          MAX_JOBS=1 FLASH_ATTENTION_FORCE_BUILD="TRUE" FLASH_ATTENTION_FORCE_CXX11_ABI=${{ matrix.cxx11_abi}} python setup.py bdist_wheel --dist-dir=dist
+          MAX_JOBS=2 FLASH_ATTENTION_FORCE_BUILD="TRUE" FLASH_ATTENTION_FORCE_CXX11_ABI=${{ matrix.cxx11_abi}} python setup.py bdist_wheel --dist-dir=dist
           tmpname=cu${MATRIX_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ matrix.cxx11_abi }}
           wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2")
           ls dist/*whl |xargs -I {} mv {} dist/${wheel_name}
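
A note for readers skimming the workflow change above: the `python -c` one-liner added in the install step just clamps the toolkit's CUDA version into the range of CUDA builds that each PyTorch release publishes wheels for. Below is a minimal standalone sketch of that logic; the function name and the example assertions are ours, purely illustrative, while the version tables mirror the inline snippet.

```python
# Sketch of the TORCH_CUDA_VERSION selection done inline via `python -c` in the workflow:
# clamp the system CUDA version (e.g. "122" for CUDA 12.2) into the range of CUDA builds
# that the chosen torch minor version ships wheels for.
MIN_CUDA = {'1.12': 113, '1.13': 116, '2.0': 117, '2.1': 118, '2.2': 118, '2.3': 118}
MAX_CUDA = {'1.12': 116, '1.13': 117, '2.0': 118, '2.1': 121, '2.2': 121, '2.3': 121}

def torch_cuda_version(matrix_torch_version: str, matrix_cuda_version: str) -> int:
    """Return the cuXXX tag to download PyTorch wheels from (e.g. 117, 118, 121)."""
    minv = MIN_CUDA[matrix_torch_version]
    maxv = MAX_CUDA[matrix_torch_version]
    return max(min(int(matrix_cuda_version), maxv), minv)

# Building with CUDA 12.2 against torch 2.2 -> download wheels from cu121.
assert torch_cuda_version('2.2', '122') == 121
# Building with CUDA 11.8 against torch 1.13 -> clamped down to cu117.
assert torch_cuda_version('1.13', '118') == 117
```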

README.md (+61, -9)
@@ -74,7 +74,7 @@ FlashAttention-2 currently supports:
    GPUs (T4, RTX 2080) is coming soon, please use FlashAttention 1.x for Turing
    GPUs for now.
 2. Datatype fp16 and bf16 (bf16 requires Ampere, Ada, or Hopper GPUs).
-3. All head dimensions up to 256. Head dim > 192 backward requires A100/A800 or H100/H800.
+3. All head dimensions up to 256. ~~Head dim > 192 backward requires A100/A800 or H100/H800~~. Head dim 256 backward now works on consumer GPUs (if there's no dropout) as of flash-attn 2.5.5.
 
 
 ## How to use FlashAttention
@@ -86,7 +86,8 @@ from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
 ```
 
 ```python
-flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1)):
+flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False,
+                          window_size=(-1, -1), alibi_slopes=None, deterministic=False):
 """dropout_p should be set to 0.0 during evaluation
 If Q, K, V are already stacked into 1 tensor, this function will be faster than
 calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
@@ -100,13 +101,18 @@ Arguments:
         Default to 1 / sqrt(headdim).
     causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
     window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+    alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
+        the attention score of query i and key j.
+    deterministic: bool. Whether to use the deterministic implementation of the backward pass,
+        which is slightly slower and uses more memory. The forward pass is always deterministic.
 Return:
     out: (batch_size, seqlen, nheads, headdim).
 """
 ```
 
 ```python
-flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1)):
+flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False,
+                window_size=(-1, -1), alibi_slopes=None, deterministic=False):
 """dropout_p should be set to 0.0 during evaluation
 Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
 than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
@@ -125,6 +131,11 @@ Arguments:
         Default to 1 / sqrt(headdim).
     causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
     window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+    alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
+        (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
+        is added to the attention score of query i and key j.
+    deterministic: bool. Whether to use the deterministic implementation of the backward pass,
+        which is slightly slower and uses more memory. The forward pass is always deterministic.
 Return:
     out: (batch_size, seqlen, nheads, headdim).
 """
@@ -141,17 +152,23 @@ def flash_attn_with_kvcache(
     rotary_sin=None,
     cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
     cache_batch_idx: Optional[torch.Tensor] = None,
+    block_table: Optional[torch.Tensor] = None,
     softmax_scale=None,
     causal=False,
     window_size=(-1, -1),  # -1 means infinite context window
     rotary_interleaved=True,
+    alibi_slopes=None,
 ):
     """
     If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
     k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
     the previous step, and update them with the new keys/values from the current step, and do
     attention with the updated cache, all in 1 kernel.
 
+    If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
+    For example, the KV cache could be pre-allocated with the max sequence length, and you can use
+    cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.
+
     Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
     rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
     If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
@@ -161,12 +178,36 @@ def flash_attn_with_kvcache(
 
     See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.
 
+    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
+    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
+    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
+    0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
+
+    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
+    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
+        1 1 1 1 0
+        1 1 1 1 1
+    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
+        0 0
+        0 0
+        0 0
+        1 0
+        1 1
+    If the row of the mask is all zero, the output will be zero.
+
+    If window_size != (-1, -1), implements sliding window local attention. Query at position i
+    will only attend to keys between
+    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
+
     Note: Does not support backward pass.
 
     Arguments:
         q: (batch_size, seqlen, nheads, headdim)
-        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim)
-        v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim)
+        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
+            or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
+            page_block_size must be a multiple of 256.
+        v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
+            or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
         k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
             k with k_cache, starting at the indices specified by cache_seqlens.
         v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.
@@ -175,6 +216,7 @@ def flash_attn_with_kvcache(
         rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
         cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
             KV cache.
+        block_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
         cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
             If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
             If the indices are not distinct, and k and v are provided, the values updated in the cache
@@ -187,10 +229,9 @@ def flash_attn_with_kvcache(
             If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
             rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
             (i.e. GPT-NeoX style).
-        num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
-            If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
-            to automatically determine the number of splits.
-            Don't change this unless you know what you are doing.
+        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
+            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
+            is added to the attention score of query i and key j.
 
     Return:
         out: (batch_size, seqlen, nheads, headdim).
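
To make the incremental-decoding flow described above concrete, here is a sketch of a single decoding step with a pre-allocated cache. The tensor sizes are illustrative, and tests/test_flash_attn.py::test_flash_attn_kvcache remains the authoritative reference.

```python
# Illustrative decoding step: pre-allocate the KV cache to max_seqlen and track
# per-sequence lengths with cache_seqlens, as described in the docstring above.
import torch
from flash_attn import flash_attn_with_kvcache

batch_size, max_seqlen, nheads, nheads_k, headdim = 2, 4096, 8, 2, 64
device, dtype = "cuda", torch.float16

k_cache = torch.zeros(batch_size, max_seqlen, nheads_k, headdim, device=device, dtype=dtype)
v_cache = torch.zeros(batch_size, max_seqlen, nheads_k, headdim, device=device, dtype=dtype)
# Suppose 100 and 37 tokens of each sequence are already cached.
cache_seqlens = torch.tensor([100, 37], dtype=torch.int32, device=device)

# One new token per sequence (seqlen_new = 1); MQA/GQA since nheads_k < nheads.
q = torch.randn(batch_size, 1, nheads, headdim, device=device, dtype=dtype)
k_new = torch.randn(batch_size, 1, nheads_k, headdim, device=device, dtype=dtype)
v_new = torch.randn(batch_size, 1, nheads_k, headdim, device=device, dtype=dtype)

# k_cache / v_cache are updated in place at positions cache_seqlens, then attention
# runs against the updated cache, all in one kernel.
out = flash_attn_with_kvcache(
    q, k_cache, v_cache, k_new, v_new,
    cache_seqlens=cache_seqlens,
    causal=True,
)
print(out.shape)  # (batch_size, 1, nheads, headdim)
```
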
@@ -266,6 +307,17 @@ Implement sliding window attention (i.e., local attention). Thanks to [Mistral
 AI](https://mistral.ai/) and in particular Timothée Lacroix for this
 contribution. Sliding window was used in the [Mistral 7B](https://mistral.ai/news/announcing-mistral-7b/) model.
 
+### 2.4: ALiBi (attention with linear bias), deterministic backward pass.
+
+Implement ALiBi (Press et al., 2021). Thanks to Sanghun Cho from Kakao Brain for this contribution.
+
+Implement deterministic backward pass. Thanks to engineers from [Meituan](www.meituan.com) for this contribution.
+
+### 2.5: Paged KV cache.
+
+Support paged KV cache (i.e., [PagedAttention](https://arxiv.org/abs/2309.06180)).
+Thanks to @beginlner for this contribution.
+
 ## Performance
 
 We present expected speedup (combined forward + backward pass) and memory savings from using FlashAttention against PyTorch standard attention, depending on sequence length, on different GPUs (speedup depends on memory bandwidth - we see more speedup on slower GPU memory).
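
As a companion to the 2.5 note above, here is a sketch of the paged KV cache layout implied by the flash_attn_with_kvcache docstring changes in this commit. The trivial block assignment and all sizes are illustrative assumptions, not an API requirement.

```python
# Illustrative paged KV cache layout per the updated flash_attn_with_kvcache docstring:
# caches are (num_blocks, page_block_size, nheads_k, headdim), block_table maps each
# sequence to the blocks it owns, and page_block_size must be a multiple of 256.
import torch
from flash_attn import flash_attn_with_kvcache

batch_size, nheads, nheads_k, headdim = 2, 8, 2, 64
page_block_size = 256
max_num_blocks_per_seq = 4
num_blocks = batch_size * max_num_blocks_per_seq
device, dtype = "cuda", torch.float16

k_cache = torch.zeros(num_blocks, page_block_size, nheads_k, headdim, device=device, dtype=dtype)
v_cache = torch.zeros(num_blocks, page_block_size, nheads_k, headdim, device=device, dtype=dtype)

# Simplest possible mapping: sequence b owns blocks [4b, 4b+1, 4b+2, 4b+3].
block_table = torch.arange(num_blocks, dtype=torch.int32, device=device).reshape(
    batch_size, max_num_blocks_per_seq
)
# Current number of cached tokens per sequence.
cache_seqlens = torch.tensor([100, 37], dtype=torch.int32, device=device)

q = torch.randn(batch_size, 1, nheads, headdim, device=device, dtype=dtype)
out = flash_attn_with_kvcache(
    q, k_cache, v_cache,
    cache_seqlens=cache_seqlens,
    block_table=block_table,
    causal=True,
)
print(out.shape)  # (batch_size, 1, nheads, headdim)
```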
