Commit 6a4b69b

Merge remote-tracking branch 'bdashore3/main' into bdashore3

2 parents: f24e91b + 1247caa

66 files changed, +4901 -3059 lines (large commit; only a subset of the changed files is shown below)

.github/workflows/build-wheels.yml (+5, -6)

@@ -21,12 +21,13 @@ permissions:
 
 jobs:
   build_wheels:
-    name: Build wheels for Python ${{ matrix.pyver }} and CUDA ${{ matrix.cuda }}
+    name: Build wheels for Python ${{ matrix.pyver }}, CUDA ${{ matrix.cuda }}, and Torch ${{ matrix.torchver }}
     runs-on: windows-latest
     strategy:
       matrix:
         pyver: ["3.10", "3.11"]
-        cuda: ["12.1.1"]
+        cuda: ["12.2.2"]
+        torchver: ["2.1.2", "2.2.0"]
     defaults:
       run:
         shell: pwsh
@@ -37,8 +38,6 @@ jobs:
     steps:
       - uses: actions/checkout@v4
         with:
-          repository: 'Dao-AILab/flash-attention'
-          ref: ${{ inputs.version }}
          submodules: 'recursive'
 
       - uses: actions/setup-python@v4
@@ -67,7 +66,7 @@ jobs:
          if (!(mamba list cuda)[-1].contains('cuda')) {sleep -s 10; mamba install -y 'cuda' $cudaChannels.TrimEnd().Split()}
          if (!(mamba list cuda)[-1].contains('cuda')) {throw 'CUDA Toolkit failed to install!'}
 
-          python -m pip install --upgrade build setuptools wheel packaging ninja torch==2.1.0 --extra-index-url "https://download.pytorch.org/whl/cu$cudaVersionPytorch"
+          python -m pip install --upgrade build setuptools wheel packaging ninja torch==${{ matrix.torchver }} --extra-index-url "https://download.pytorch.org/whl/cu121"
 
       - name: Build Wheel
         id: build-wheel
@@ -85,7 +84,7 @@ jobs:
          python -m build -n --wheel
 
          $wheel = (gi '.\dist\*.whl')[0]
-          $wheelname = $wheel.name.replace("flash_attn-$packageVersion-","flash_attn-$packageVersion+cu$cudaVersion"+"torch2.1cxx11abiFALSE-")
+          $wheelname = $wheel.name.replace("flash_attn-$packageVersion-","flash_attn-$packageVersion+cu$cudaVersion"+"torch${{ matrix.torchver }}cxx11abiFALSE-")
          Move-Item $wheel.fullname ".\dist\$wheelname"
 
       - uses: actions/upload-artifact@v3
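The rename step above now derives the local version tag from the matrix Torch version instead of a hard-coded torch2.1. A minimal Python sketch of the same string replacement, with illustrative placeholder values for the package version, CUDA tag, and wheel name (the real `$packageVersion` and `$cudaVersion` are computed elsewhere in the workflow):

```python
# Sketch of the PowerShell wheel-renaming step above, in Python.
# All concrete values here are illustrative assumptions, not taken from this commit.
package_version = "2.5.0"   # assumed flash_attn version
cuda_version = "122"        # e.g. cuda "12.2.2" -> "122"
torch_ver = "2.2.0"         # one entry of matrix.torchver
wheel = f"flash_attn-{package_version}-cp311-cp311-win_amd64.whl"

# "flash_attn-<ver>-" -> "flash_attn-<ver>+cu<cuda>torch<torch>cxx11abiFALSE-"
renamed = wheel.replace(
    f"flash_attn-{package_version}-",
    f"flash_attn-{package_version}+cu{cuda_version}torch{torch_ver}cxx11abiFALSE-",
)
print(renamed)
# flash_attn-2.5.0+cu122torch2.2.0cxx11abiFALSE-cp311-cp311-win_amd64.whl
```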

.github/workflows/publish.yml (+33, -30)

@@ -44,45 +44,34 @@ jobs:
          # manylinux docker image, but I haven't figured out how to install CUDA on manylinux.
          os: [ubuntu-20.04]
          python-version: ['3.7', '3.8', '3.9', '3.10', '3.11']
-          torch-version: ['1.12.1', '1.13.1', '2.0.1', '2.1.0']
-          cuda-version: ['11.6.2', '11.7.1', '11.8.0', '12.1.0', '12.2.0']
+          torch-version: ['1.12.1', '1.13.1', '2.0.1', '2.1.2', '2.2.0', '2.3.0.dev20240126']
+          cuda-version: ['11.8.0', '12.2.2']
          # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not.
          # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI.
          # Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs)
          # when building without C++11 ABI and using it on nvcr images.
          cxx11_abi: ['FALSE', 'TRUE']
          exclude:
+            # see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix
            # Pytorch <= 1.12 does not support Python 3.11
            - torch-version: '1.12.1'
              python-version: '3.11'
            # Pytorch >= 2.0 only supports Python >= 3.8
            - torch-version: '2.0.1'
              python-version: '3.7'
-            - torch-version: '2.1.0'
+            - torch-version: '2.1.2'
+              python-version: '3.7'
+            - torch-version: '2.2.0'
+              python-version: '3.7'
+            - torch-version: '2.3.0.dev20240126'
              python-version: '3.7'
            # Pytorch <= 2.0 only supports CUDA <= 11.8
            - torch-version: '1.12.1'
-              cuda-version: '12.1.0'
-            - torch-version: '1.12.1'
-              cuda-version: '12.2.0'
-            - torch-version: '1.13.1'
-              cuda-version: '12.1.0'
+              cuda-version: '12.2.2'
            - torch-version: '1.13.1'
-              cuda-version: '12.2.0'
+              cuda-version: '12.2.2'
            - torch-version: '2.0.1'
-              cuda-version: '12.1.0'
-            - torch-version: '2.0.1'
-              cuda-version: '12.2.0'
-            # Pytorch >= 2.1 only supports CUDA >= 11.8
-            - torch-version: '2.1.0'
-              cuda-version: '11.6.2'
-            - torch-version: '2.1.0'
-              cuda-version: '11.7.1'
-            # Pytorch >= 2.1 with nvcc 12.1.0 segfaults during compilation, so
-            # we only use CUDA 12.2. setup.py as a special case that will
-            # download the wheel for CUDA 12.2 instead.
-            - torch-version: '2.1.0'
-              cuda-version: '12.1.0'
+              cuda-version: '12.2.2'
 
     steps:
       - name: Checkout
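The rewritten exclude list encodes the PyTorch release compatibility matrix linked in the new comment. As a sanity check (not part of the workflow, and with illustrative helper names), a small Python sketch that expands the matrix and applies the same exclusion rules:

```python
from itertools import product

python_versions = ["3.7", "3.8", "3.9", "3.10", "3.11"]
torch_versions = ["1.12.1", "1.13.1", "2.0.1", "2.1.2", "2.2.0", "2.3.0.dev20240126"]
cuda_versions = ["11.8.0", "12.2.2"]

def excluded(py, torch, cuda):
    # Pytorch <= 1.12 does not support Python 3.11
    if torch == "1.12.1" and py == "3.11":
        return True
    # Pytorch >= 2.0 only supports Python >= 3.8
    if torch in ("2.0.1", "2.1.2", "2.2.0", "2.3.0.dev20240126") and py == "3.7":
        return True
    # Pytorch <= 2.0 only supports CUDA <= 11.8
    if torch in ("1.12.1", "1.13.1", "2.0.1") and cuda == "12.2.2":
        return True
    return False

builds = [
    (py, t, c)
    for py, t, c in product(python_versions, torch_versions, cuda_versions)
    if not excluded(py, t, c)
]
print(len(builds), "builds per cxx11_abi value")
```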
@@ -97,6 +86,7 @@ jobs:
        run: |
          echo "MATRIX_CUDA_VERSION=$(echo ${{ matrix.cuda-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV
          echo "MATRIX_TORCH_VERSION=$(echo ${{ matrix.torch-version }} | awk -F \. {'print $1 "." $2'})" >> $GITHUB_ENV
+          echo "MATRIX_PYTHON_VERSION=$(echo ${{ matrix.python-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV
 
      - name: Free up disk space
        if: ${{ runner.os == 'Linux' }}
@@ -107,9 +97,15 @@ jobs:
          sudo rm -rf /opt/ghc
          sudo rm -rf /opt/hostedtoolcache/CodeQL
 
+      - name: Set up swap space
+        if: runner.os == 'Linux'
+        uses: pierotofy/set-swap-space@v1.0
+        with:
+          swap-size-gb: 10
+
      - name: Install CUDA ${{ matrix.cuda-version }}
        if: ${{ matrix.cuda-version != 'cpu' }}
-        uses: Jimver/cuda-toolkit@v0.2.11
+        uses: Jimver/cuda-toolkit@v0.2.14
        id: cuda-toolkit
        with:
          cuda: ${{ matrix.cuda-version }}
@@ -129,10 +125,21 @@ jobs:
          pip install lit
          # We want to figure out the CUDA version to download pytorch
          # e.g. we can have system CUDA version being 11.7 but if torch==1.12 then we need to download the wheel from cu116
+          # see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix
          # This code is ugly, maybe there's a better way to do this.
-          export TORCH_CUDA_VERSION=$(python -c "import os; minv = {'1.12': 113, '1.13': 116, '2.0': 117, '2.1': 118}[os.environ['MATRIX_TORCH_VERSION']]; maxv = {'1.12': 116, '1.13': 117, '2.0': 118, '2.1': 121}[os.environ['MATRIX_TORCH_VERSION']]; print(max(min(int(os.environ['MATRIX_CUDA_VERSION']), maxv), minv))")
+          export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \
+            minv = {'1.12': 113, '1.13': 116, '2.0': 117, '2.1': 118, '2.2': 118, '2.3': 118}[env['MATRIX_TORCH_VERSION']]; \
+            maxv = {'1.12': 116, '1.13': 117, '2.0': 118, '2.1': 121, '2.2': 121, '2.3': 121}[env['MATRIX_TORCH_VERSION']]; \
+            print(max(min(int(env['MATRIX_CUDA_VERSION']), maxv), minv))" \
+          )
          if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then
-            pip install --no-cache-dir --pre torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}
+            if [[ ${MATRIX_TORCH_VERSION} == "2.2" ]]; then
+              # --no-deps because we can't install old versions of pytorch-triton
+              pip install typing-extensions jinja2
+              pip install --no-cache-dir --no-deps --pre https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}/torch-${{ matrix.torch-version }}%2Bcu${TORCH_CUDA_VERSION}-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl
+            else
+              pip install --no-cache-dir --pre torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}
+            fi
          else
            pip install --no-cache-dir torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION}
          fi
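The reformatted python -c expression keeps the original clamping logic and extends the tables to torch 2.2/2.3: the PyTorch wheel's CUDA tag is the system CUDA version clamped into the range of CUDA builds that the given torch release ships. A standalone sketch of that computation (illustrative, outside the workflow; function and variable names are assumptions):

```python
# Clamp the system CUDA version into the range of CUDA builds available
# for a given torch release (same tables as the workflow step above).
MIN_CUDA = {"1.12": 113, "1.13": 116, "2.0": 117, "2.1": 118, "2.2": 118, "2.3": 118}
MAX_CUDA = {"1.12": 116, "1.13": 117, "2.0": 118, "2.1": 121, "2.2": 121, "2.3": 121}

def torch_cuda_version(matrix_torch_version: str, matrix_cuda_version: int) -> int:
    minv = MIN_CUDA[matrix_torch_version]
    maxv = MAX_CUDA[matrix_torch_version]
    return max(min(matrix_cuda_version, maxv), minv)

# e.g. building against system CUDA 12.2 ("122"):
print(torch_cuda_version("2.2", 122))   # 121 -> install torch wheels tagged cu121
print(torch_cuda_version("1.13", 122))  # 117 -> install torch wheels tagged cu117
```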
@@ -153,12 +160,8 @@ jobs:
          pip install ninja packaging wheel
          export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH
          export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
-          # Currently for this setting the runner goes OOM if we pass --threads 4 to nvcc
-          if [[ ( ${MATRIX_CUDA_VERSION} == "121" || ${MATRIX_CUDA_VERSION} == "122" ) && ${MATRIX_TORCH_VERSION} == "2.1" ]]; then
-            export FLASH_ATTENTION_FORCE_SINGLE_THREAD="TRUE"
-          fi
          # Limit MAX_JOBS otherwise the github runner goes OOM
-          MAX_JOBS=1 FLASH_ATTENTION_FORCE_BUILD="TRUE" FLASH_ATTENTION_FORCE_CXX11_ABI=${{ matrix.cxx11_abi}} python setup.py bdist_wheel --dist-dir=dist
+          MAX_JOBS=2 FLASH_ATTENTION_FORCE_BUILD="TRUE" FLASH_ATTENTION_FORCE_CXX11_ABI=${{ matrix.cxx11_abi}} python setup.py bdist_wheel --dist-dir=dist
          tmpname=cu${MATRIX_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ matrix.cxx11_abi }}
          wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2")
          ls dist/*whl |xargs -I {} mv {} dist/${wheel_name}

README.md (+60, -8)

@@ -86,7 +86,8 @@ from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
 ```
 
 ```python
-flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1)):
+flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False,
+                          window_size=(-1, -1), alibi_slopes=None, deterministic=False):
 """dropout_p should be set to 0.0 during evaluation
 If Q, K, V are already stacked into 1 tensor, this function will be faster than
 calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
@@ -100,13 +101,18 @@ Arguments:
         Default to 1 / sqrt(headdim).
     causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
     window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+    alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
+        the attention score of query i and key j.
+    deterministic: bool. Whether to use the deterministic implementation of the backward pass,
+        which is slightly slower and uses more memory. The forward pass is always deterministic.
 Return:
     out: (batch_size, seqlen, nheads, headdim).
 """
 ```
 
 ```python
-flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1)):
+flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False,
+                window_size=(-1, -1), alibi_slopes=None, deterministic=False):
 """dropout_p should be set to 0.0 during evaluation
 Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
 than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
@@ -125,6 +131,11 @@ Arguments:
         Default to 1 / sqrt(headdim).
     causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
     window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+    alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
+        (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
+        is added to the attention score of query i and key j.
+    deterministic: bool. Whether to use the deterministic implementation of the backward pass,
+        which is slightly slower and uses more memory. The forward pass is always deterministic.
 Return:
     out: (batch_size, seqlen, nheads, headdim).
 """
@@ -141,17 +152,23 @@ def flash_attn_with_kvcache(
     rotary_sin=None,
     cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
     cache_batch_idx: Optional[torch.Tensor] = None,
+    block_table: Optional[torch.Tensor] = None,
     softmax_scale=None,
     causal=False,
     window_size=(-1, -1),  # -1 means infinite context window
     rotary_interleaved=True,
+    alibi_slopes=None,
 ):
     """
     If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
     k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
     the previous step, and update them with the new keys/values from the current step, and do
     attention with the updated cache, all in 1 kernel.
 
+    If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
+    For example, the KV cache could be pre-allocated with the max sequence length, and you can use
+    cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.
+
     Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
     rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
     If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
@@ -161,12 +178,36 @@ def flash_attn_with_kvcache(
 
     See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.
 
+    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
+    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
+    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
+    0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
+
+    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
+    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
+        1 1 1 1 0
+        1 1 1 1 1
+    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
+        0 0
+        0 0
+        0 0
+        1 0
+        1 1
+    If the row of the mask is all zero, the output will be zero.
+
+    If window_size != (-1, -1), implements sliding window local attention. Query at position i
+    will only attend to keys between
+    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
+
     Note: Does not support backward pass.
 
     Arguments:
         q: (batch_size, seqlen, nheads, headdim)
-        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim)
-        v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim)
+        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
+            or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
+            page_block_size must be a multiple of 256.
+        v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
+            or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
         k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
             k with k_cache, starting at the indices specified by cache_seqlens.
         v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.
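The bottom-right alignment of the causal mask described in the new docstring text can be reproduced in a few lines of PyTorch; this small sketch (the helper name is illustrative) regenerates the two examples above, with 1 = keep and 0 = masked out:

```python
import torch

def bottom_right_causal_mask(seqlen_q: int, seqlen_k: int) -> torch.Tensor:
    # Query i may attend to key j iff j <= i + seqlen_k - seqlen_q
    i = torch.arange(seqlen_q)[:, None]
    j = torch.arange(seqlen_k)[None, :]
    return (j <= i + seqlen_k - seqlen_q).int()

print(bottom_right_causal_mask(2, 5))
# tensor([[1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1]], dtype=torch.int32)

print(bottom_right_causal_mask(5, 2))
# tensor([[0, 0],
#         [0, 0],
#         [0, 0],
#         [1, 0],
#         [1, 1]], dtype=torch.int32)
```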
@@ -175,6 +216,7 @@ def flash_attn_with_kvcache(
         rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
         cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
             KV cache.
+        block_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
         cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
             If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
             If the indices are not distinct, and k and v are provided, the values updated in the cache
@@ -187,10 +229,9 @@ def flash_attn_with_kvcache(
             If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
             rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
             (i.e. GPT-NeoX style).
-        num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
-            If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
-            to automatically determine the number of splits.
-            Don't change this unless you know what you are doing.
+        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
+            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
+            is added to the attention score of query i and key j.
 
     Return:
         out: (batch_size, seqlen, nheads, headdim).
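Putting the flash_attn_with_kvcache arguments together, here is a hedged sketch of single-token incremental decoding with a pre-allocated, non-paged cache, following the docstring above. The shapes, cache size, and toy loop are illustrative assumptions; it requires a CUDA GPU with flash-attn installed.

```python
import torch
from flash_attn import flash_attn_with_kvcache

batch, nheads, nheads_k, headdim = 2, 8, 2, 64   # MQA/GQA: nheads divisible by nheads_k
max_seqlen = 4096                                # cache pre-allocated to the max length

# Pre-allocated KV cache, updated in place by flash_attn_with_kvcache.
k_cache = torch.zeros(batch, max_seqlen, nheads_k, headdim, device="cuda", dtype=torch.float16)
v_cache = torch.zeros_like(k_cache)
cache_seqlens = torch.zeros(batch, dtype=torch.int32, device="cuda")  # current length per sequence

for step in range(16):  # toy decode loop
    # New query/key/value for this step (seqlen == 1 when decoding one token at a time).
    q = torch.randn(batch, 1, nheads, headdim, device="cuda", dtype=torch.float16)
    k = torch.randn(batch, 1, nheads_k, headdim, device="cuda", dtype=torch.float16)
    v = torch.randn(batch, 1, nheads_k, headdim, device="cuda", dtype=torch.float16)

    # Appends k/v into the cache at cache_seqlens and attends to the updated cache, in one kernel.
    out = flash_attn_with_kvcache(
        q, k_cache, v_cache, k=k, v=v,
        cache_seqlens=cache_seqlens,
        causal=True,
    )
    cache_seqlens += 1  # the caller tracks the per-sequence lengths

# With a paged KV cache (flash-attn 2.5+), k_cache/v_cache would instead have shape
# (num_blocks, page_block_size, nheads_k, headdim) and block_table would be
# (batch, max_num_blocks_per_seq), dtype torch.int32.
```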
@@ -266,6 +307,17 @@ Implement sliding window attention (i.e., local attention). Thanks to [Mistral
 AI](https://mistral.ai/) and in particular Timothée Lacroix for this
 contribution. Sliding window was used in the [Mistral 7B](https://mistral.ai/news/announcing-mistral-7b/) model.
 
+### 2.4: ALiBi (attention with linear bias), deterministic backward pass.
+
+Implement ALiBi (Press et al., 2021). Thanks to Sanghun Cho from Kakao Brain for this contribution.
+
+Implement deterministic backward pass. Thanks to engineers from [Meituan](www.meituan.com) for this contribution.
+
+### 2.5: Paged KV cache.
+
+Support paged KV cache (i.e., [PagedAttention](https://arxiv.org/abs/2309.06180)).
+Thanks to @beginlner for this contribution.
+
 ## Performance
 
 We present expected speedup (combined forward + backward pass) and memory savings from using FlashAttention against PyTorch standard attention, depending on sequence length, on different GPUs (speedup depends on memory bandwidth - we see more speedup on slower GPU memory).
