
Commit 620c16a

Merge remote-tracking branch 'Dao-AILab/main'

2 parents 0c8577d + f86e3dd

42 files changed (+7293 −391 lines)

.github/workflows/publish.yml (+32 −22)
@@ -43,9 +43,9 @@ jobs:
           # Using ubuntu-20.04 instead of 22.04 for more compatibility (glibc). Ideally we'd use the
           # manylinux docker image, but I haven't figured out how to install CUDA on manylinux.
           os: [ubuntu-20.04]
-          python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']
-          torch-version: ['2.1.2', '2.2.2', '2.3.1', '2.4.0', '2.5.1']
-          cuda-version: ['11.8.0', '12.4.1']
+          python-version: ['3.9', '3.10', '3.11', '3.12', '3.13']
+          torch-version: ['2.1.2', '2.2.2', '2.3.1', '2.4.0', '2.5.1', '2.6.0.dev20241001']
+          cuda-version: ['11.8.0', '12.3.2']
           # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not.
           # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI.
           # Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs)
@@ -56,16 +56,22 @@ jobs:
           # Pytorch < 2.2 does not support Python 3.12
           - torch-version: '2.1.2'
             python-version: '3.12'
-          # Pytorch >= 2.5 does not support Python 3.8
-          - torch-version: '2.5.1'
-            python-version: '3.8'
+          # Pytorch < 2.5 does not support Python 3.13
+          - torch-version: '2.1.2'
+            python-version: '3.13'
+          - torch-version: '2.2.2'
+            python-version: '3.13'
+          - torch-version: '2.3.1'
+            python-version: '3.13'
+          - torch-version: '2.4.0'
+            python-version: '3.13'
 
     steps:
       - name: Checkout
-        uses: actions/checkout@v3
+        uses: actions/checkout@v4
 
       - name: Set up Python
-        uses: actions/setup-python@v4
+        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
 
@@ -74,6 +80,7 @@ jobs:
          echo "MATRIX_CUDA_VERSION=$(echo ${{ matrix.cuda-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV
          echo "MATRIX_TORCH_VERSION=$(echo ${{ matrix.torch-version }} | awk -F \. {'print $1 "." $2'})" >> $GITHUB_ENV
          echo "WHEEL_CUDA_VERSION=$(echo ${{ matrix.cuda-version }} | awk -F \. {'print $1'})" >> $GITHUB_ENV
+          echo "MATRIX_PYTHON_VERSION=$(echo ${{ matrix.python-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV
 
       - name: Free up disk space
        if: ${{ runner.os == 'Linux' }}
@@ -92,37 +99,40 @@ jobs:
 
       - name: Install CUDA ${{ matrix.cuda-version }}
        if: ${{ matrix.cuda-version != 'cpu' }}
-        uses: Jimver/cuda-toolkit@v0.2.18
+        uses: Jimver/cuda-toolkit@v0.2.19
        id: cuda-toolkit
        with:
          cuda: ${{ matrix.cuda-version }}
          linux-local-args: '["--toolkit"]'
          # default method is "local", and we're hitting some error with caching for CUDA 11.8 and 12.1
          # method: ${{ (matrix.cuda-version == '11.8.0' || matrix.cuda-version == '12.1.0') && 'network' || 'local' }}
          method: 'network'
-          # We need the cuda libraries (e.g. cuSparse, cuSolver) for compiling PyTorch extensions,
-          # not just nvcc
-          # sub-packages: '["nvcc"]'
+          sub-packages: '["nvcc"]'
 
       - name: Install PyTorch ${{ matrix.torch-version }}+cu${{ matrix.cuda-version }}
        run: |
          pip install --upgrade pip
-          # If we don't install before installing Pytorch, we get error for torch 2.0.1
-          # ERROR: Could not find a version that satisfies the requirement setuptools>=40.8.0 (from versions: none)
-          pip install lit
          # For some reason torch 2.2.0 on python 3.12 errors saying no setuptools
-          pip install setuptools
+          pip install setuptools==68.0.0
+          # With python 3.13 and torch 2.5.1, unless we update typing-extensions, we get error
+          # AttributeError: attribute '__default__' of 'typing.ParamSpec' objects is not writable
+          pip install typing-extensions==4.12.2
          # We want to figure out the CUDA version to download pytorch
          # e.g. we can have system CUDA version being 11.7 but if torch==1.12 then we need to download the wheel from cu116
          # see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix
          # This code is ugly, maybe there's a better way to do this.
          export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \
-            minv = {'2.1': 118, '2.2': 118, '2.3': 118, '2.4': 118, '2.5': 118}[env['MATRIX_TORCH_VERSION']]; \
-            maxv = {'2.1': 121, '2.2': 121, '2.3': 121, '2.4': 124, '2.5': 124}[env['MATRIX_TORCH_VERSION']]; \
+            minv = {'2.1': 118, '2.2': 118, '2.3': 118, '2.4': 118, '2.5': 118, '2.6': 118}[env['MATRIX_TORCH_VERSION']]; \
+            maxv = {'2.1': 121, '2.2': 121, '2.3': 121, '2.4': 124, '2.5': 124, '2.6': 124}[env['MATRIX_TORCH_VERSION']]; \
            print(minv if int(env['MATRIX_CUDA_VERSION']) < 120 else maxv)" \
          )
          if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then
-            pip install --no-cache-dir --pre torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}
+            # pip install --no-cache-dir --pre torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}
+            # Can't use --no-deps because we need cudnn etc.
+            # Hard-coding this version of pytorch-triton for torch 2.6.0.dev20241001
+            pip install jinja2
+            pip install https://download.pytorch.org/whl/nightly/pytorch_triton-3.1.0%2Bcf34004b8a-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl
+            pip install --no-cache-dir --pre https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}/torch-${{ matrix.torch-version }}%2Bcu${TORCH_CUDA_VERSION}-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl
          else
            pip install --no-cache-dir torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION}
          fi
@@ -144,7 +154,7 @@ jobs:
          export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH
          export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
          # Limit MAX_JOBS otherwise the github runner goes OOM
-          # CUDA 11.8 can compile with 2 jobs, but CUDA 12.3 goes OOM
+          # nvcc 11.8 can compile with 2 jobs, but nvcc 12.3 goes OOM
          MAX_JOBS=$([ "$MATRIX_CUDA_VERSION" == "123" ] && echo 1 || echo 2) FLASH_ATTENTION_FORCE_BUILD="TRUE" FLASH_ATTENTION_FORCE_CXX11_ABI=${{ matrix.cxx11_abi}} python setup.py bdist_wheel --dist-dir=dist
          tmpname=cu${WHEEL_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ matrix.cxx11_abi }}
          wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2")
@@ -185,9 +195,9 @@ jobs:
    runs-on: ubuntu-latest
 
    steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
 
-      - uses: actions/setup-python@v4
+      - uses: actions/setup-python@v5
        with:
          python-version: '3.10'
 
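Aside on the wheel-selection change in the hunk above: the inline `python -c` one-liner picks which PyTorch CUDA wheel index (cu118 vs cu121/cu124) to install from, based on the torch and CUDA versions in the matrix. Below is a minimal standalone sketch of that same logic; the example values are illustrative stand-ins for the env vars the workflow derives with awk.

```python
# Standalone sketch of the TORCH_CUDA_VERSION selection done inline in the workflow.
# The example values below are illustrative; in CI they come from MATRIX_TORCH_VERSION
# and MATRIX_CUDA_VERSION, which the job derives from the build matrix with awk
# (e.g. torch '2.6.0.dev20241001' -> '2.6', CUDA '12.3.2' -> '123').
env = {"MATRIX_TORCH_VERSION": "2.6", "MATRIX_CUDA_VERSION": "123"}

# Oldest and newest CUDA wheel index available for each torch minor version.
minv = {'2.1': 118, '2.2': 118, '2.3': 118, '2.4': 118, '2.5': 118, '2.6': 118}[env['MATRIX_TORCH_VERSION']]
maxv = {'2.1': 121, '2.2': 121, '2.3': 121, '2.4': 124, '2.5': 124, '2.6': 124}[env['MATRIX_TORCH_VERSION']]

# System CUDA < 12.0 -> use the oldest (cu118) index; otherwise use the newest one.
torch_cuda_version = minv if int(env['MATRIX_CUDA_VERSION']) < 120 else maxv
print(f"cu{torch_cuda_version}")  # -> cu124 for torch 2.6 with CUDA 12.3.2
```

Likewise, the new `MATRIX_PYTHON_VERSION` variable (e.g. `313` for Python 3.13) is used only to fill in the `cp${MATRIX_PYTHON_VERSION}` tags of the hard-coded nightly torch and pytorch-triton wheel URLs.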

README.md (+50 −4)
@@ -59,8 +59,11 @@ To run the test:
 export PYTHONPATH=$PWD
 pytest -q -s test_flash_attn.py
 ```
-
-
+Once the package is installed, you can import it as follows:
+```python
+import flash_attn_interface
+flash_attn_interface.flash_attn_func()
+```
 
 ## Installation and features
 **Requirements:**
@@ -112,7 +115,7 @@ FlashAttention-2 with CUDA currently supports:
 3. All head dimensions up to 256. ~~Head dim > 192 backward requires A100/A800 or H100/H800~~. Head dim 256 backward now works on consumer GPUs (if there's no dropout) as of flash-attn 2.5.5.
 
 ### AMD ROCm Support
-ROCm version uses [composable_kernel](https://github.com/ROCm/composable_kernel) as the backend. It provides the implementation of FlashAttention-2.
+ROCm version has two backends. There is [composable_kernel](https://github.com/ROCm/composable_kernel) (ck) which is the default backend and a [Triton](https://github.com/triton-lang/triton) backend. They provide an implementation of FlashAttention-2.
 
 **Requirements:**
 - ROCm 6.0 and above.
@@ -121,11 +124,54 @@ We recommend the
 [Pytorch](https://hub.docker.com/r/rocm/pytorch)
 container from ROCm, which has all the required tools to install FlashAttention.
 
-FlashAttention-2 with ROCm currently supports:
+#### Composable Kernel Backend
+FlashAttention-2 ROCm CK backend currently supports:
 1. MI200 or MI300 GPUs.
 2. Datatype fp16 and bf16
 3. Forward's head dimensions up to 256. Backward head dimensions up to 128.
 
+#### Triton Backend
+The Triton implementation of the [Flash Attention v2](https://tridao.me/publications/flash2/flash2.pdf) is currently a work in progress.
+
+It supports AMD's CDNA (MI200, MI300) and RDNA GPU's using fp16, bf16 and fp32 datatypes.
+
+These features are supported in Fwd and Bwd
+1) Fwd and Bwd with causal masking
+2) Variable sequence lengths
+3) Arbitrary Q and KV sequence lengths
+4) Arbitrary head sizes
+
+These features are supported in Fwd for now. We will add them to backward soon.
+1) Multi and grouped query attention
+2) ALiBi and matrix bias
+
+These features are in development
+1) Paged Attention
+2) Sliding Window
+3) Rotary embeddings
+4) Dropout
+5) Performance Improvements
+
+#### Getting Started
+To get started with the triton backend for AMD, follow the steps below.
+
+First install the recommended Triton [commit](https://github.com/triton-lang/triton/commit/3ca2f498e98ed7249b82722587c511a5610e00c4).
+
+```
+git clone https://github.com/triton-lang/triton
+cd triton
+git checkout 3ca2f498e98ed7249b82722587c511a5610e00c4
+pip install --verbose -e python
+```
+Then install and test Flash Attention with the flag `FLASH_ATTENTION_TRITON_AMD_ENABLE` set to `"TRUE"`.
+
+```
+export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
+cd flash-attention
+python setup.py install
+pytest tests/test_flash_attn.py
+```
+
 
 ## How to use FlashAttention
 
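A note on the README snippet added in the first hunk above: it calls `flash_attn_interface.flash_attn_func()` with no arguments, which only demonstrates the import path. A fuller, hypothetical usage sketch follows; the `(q, k, v, causal=...)` call and the `(batch, seqlen, nheads, headdim)` fp16 tensor layout are assumptions carried over from the FlashAttention-2 API, not something this diff specifies.

```python
# Hypothetical usage sketch only; the argument names, the causal flag, and the
# return type are assumptions based on the FlashAttention-2 API, not this diff.
import torch
import flash_attn_interface

batch, seqlen, nheads, headdim = 2, 1024, 8, 128
q = torch.randn(batch, seqlen, nheads, headdim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

out = flash_attn_interface.flash_attn_func(q, k, v, causal=True)
# Some builds return (out, softmax_lse); unpack accordingly if so.
if isinstance(out, tuple):
    out = out[0]
print(out.shape)  # expected: (2, 1024, 8, 128)
```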

csrc/cutlass (Submodule cutlass updated 582 files)
