
Commit 0c8577d

Merge remote-tracking branch 'Dao-AILab/main'
2 parents 007f06e + 7153673

117 files changed (+12917 -4057 lines)

.github/workflows/publish.yml (+11 -12)

@@ -44,8 +44,8 @@ jobs:
 # manylinux docker image, but I haven't figured out how to install CUDA on manylinux.
 os: [ubuntu-20.04]
 python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']
-torch-version: ['2.0.1', '2.1.2', '2.2.2', '2.3.1', '2.4.0']
-cuda-version: ['11.8.0', '12.3.2']
+torch-version: ['2.1.2', '2.2.2', '2.3.1', '2.4.0', '2.5.1']
+cuda-version: ['11.8.0', '12.4.1']
 # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not.
 # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI.
 # Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs)

@@ -54,13 +54,11 @@ jobs:
 exclude:
 # see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix
 # Pytorch < 2.2 does not support Python 3.12
-- torch-version: '2.0.1'
-  python-version: '3.12'
 - torch-version: '2.1.2'
   python-version: '3.12'
-# Pytorch <= 2.0 only supports CUDA <= 11.8
-- torch-version: '2.0.1'
-  cuda-version: '12.3.2'
+# Pytorch >= 2.5 does not support Python 3.8
+- torch-version: '2.5.1'
+  python-version: '3.8'

 steps:
 - name: Checkout

@@ -75,6 +73,7 @@ jobs:
 run: |
   echo "MATRIX_CUDA_VERSION=$(echo ${{ matrix.cuda-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV
   echo "MATRIX_TORCH_VERSION=$(echo ${{ matrix.torch-version }} | awk -F \. {'print $1 "." $2'})" >> $GITHUB_ENV
+  echo "WHEEL_CUDA_VERSION=$(echo ${{ matrix.cuda-version }} | awk -F \. {'print $1'})" >> $GITHUB_ENV

 - name: Free up disk space
   if: ${{ runner.os == 'Linux' }}

@@ -93,7 +92,7 @@ jobs:

 - name: Install CUDA ${{ matrix.cuda-version }}
   if: ${{ matrix.cuda-version != 'cpu' }}
-  uses: Jimver/cuda-toolkit@v0.2.14
+  uses: Jimver/cuda-toolkit@v0.2.18
   id: cuda-toolkit
   with:
     cuda: ${{ matrix.cuda-version }}

@@ -118,9 +117,9 @@ jobs:
 # see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix
 # This code is ugly, maybe there's a better way to do this.
 export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \
-minv = {'2.0': 117, '2.1': 118, '2.2': 118, '2.3': 118, '2.4': 118}[env['MATRIX_TORCH_VERSION']]; \
-maxv = {'2.0': 118, '2.1': 121, '2.2': 121, '2.3': 121, '2.4': 121}[env['MATRIX_TORCH_VERSION']]; \
-print(max(min(int(env['MATRIX_CUDA_VERSION']), maxv), minv))" \
+minv = {'2.1': 118, '2.2': 118, '2.3': 118, '2.4': 118, '2.5': 118}[env['MATRIX_TORCH_VERSION']]; \
+maxv = {'2.1': 121, '2.2': 121, '2.3': 121, '2.4': 124, '2.5': 124}[env['MATRIX_TORCH_VERSION']]; \
+print(minv if int(env['MATRIX_CUDA_VERSION']) < 120 else maxv)" \
 )
 if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then
 pip install --no-cache-dir --pre torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}

@@ -147,7 +146,7 @@ jobs:
 # Limit MAX_JOBS otherwise the github runner goes OOM
 # CUDA 11.8 can compile with 2 jobs, but CUDA 12.3 goes OOM
 MAX_JOBS=$([ "$MATRIX_CUDA_VERSION" == "123" ] && echo 1 || echo 2) FLASH_ATTENTION_FORCE_BUILD="TRUE" FLASH_ATTENTION_FORCE_CXX11_ABI=${{ matrix.cxx11_abi}} python setup.py bdist_wheel --dist-dir=dist
-tmpname=cu${MATRIX_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ matrix.cxx11_abi }}
+tmpname=cu${WHEEL_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ matrix.cxx11_abi }}
 wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2")
 ls dist/*whl |xargs -I {} mv {} dist/${wheel_name}
 echo "wheel_name=${wheel_name}" >> $GITHUB_ENV


README.md (+40 -42)

@@ -43,15 +43,12 @@ This is a beta release for testing / benchmarking before we integrate that with
 the rest of the repo.

 Currently released:
-- FP16 forward and backward
-
-Coming soon in the next couple of days / next week:
-- BF16
-- Variable length (FP16, BF16)
-- FP8 forward.
+- FP16 / BF16 forward and backward, FP8 forward

 Requirements: H100 / H800 GPU, CUDA >= 12.3.

+For now, we highly recommend CUDA 12.3 for best performance.
+
 To install:
 ```sh
 cd hopper

@@ -66,26 +63,21 @@ pytest -q -s test_flash_attn.py


 ## Installation and features
-
-Requirements:
-- CUDA 11.6 and above.
+**Requirements:**
+- CUDA toolkit or ROCm toolkit
 - PyTorch 1.12 and above.
+- `packaging` Python package (`pip install packaging`)
+- `ninja` Python package (`pip install ninja`) *
 - Linux. Might work for Windows starting v2.3.2 (we've seen a few positive [reports](https://github.com/Dao-AILab/flash-attention/issues/595)) but Windows compilation still requires more testing. If you have ideas on how to set up prebuilt CUDA wheels for Windows, please reach out via Github issue.

-We recommend the
-[Pytorch](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch)
-container from Nvidia, which has all the required tools to install FlashAttention.
-
-To install:
-1. Make sure that PyTorch is installed.
-2. Make sure that `packaging` is installed (`pip install packaging`)
-3. Make sure that `ninja` is installed and that it works correctly (e.g. `ninja
+\* Make sure that `ninja` is installed and that it works correctly (e.g. `ninja
 --version` then `echo $?` should return exit code 0). If not (sometimes `ninja
 --version` then `echo $?` returns a nonzero exit code), uninstall then reinstall
 `ninja` (`pip uninstall -y ninja && pip install ninja`). Without `ninja`,
 compiling can take a very long time (2h) since it does not use multiple CPU
-cores. With `ninja` compiling takes 3-5 minutes on a 64-core machine.
-4. Then:
+cores. With `ninja` compiling takes 3-5 minutes on a 64-core machine using CUDA toolkit.
+
+**To install:**
 ```sh
 pip install flash-attn --no-build-isolation
 ```

@@ -102,15 +94,38 @@ variable `MAX_JOBS`:
 MAX_JOBS=4 pip install flash-attn --no-build-isolation
 ```

-Interface: `src/flash_attention_interface.py`
+**Interface:** `src/flash_attention_interface.py`
+
+### NVIDIA CUDA Support
+**Requirements:**
+- CUDA 11.7 and above.

-FlashAttention-2 currently supports:
+We recommend the
+[Pytorch](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch)
+container from Nvidia, which has all the required tools to install FlashAttention.
+
+FlashAttention-2 with CUDA currently supports:
 1. Ampere, Ada, or Hopper GPUs (e.g., A100, RTX 3090, RTX 4090, H100). Support for Turing
 GPUs (T4, RTX 2080) is coming soon, please use FlashAttention 1.x for Turing
 GPUs for now.
 2. Datatype fp16 and bf16 (bf16 requires Ampere, Ada, or Hopper GPUs).
 3. All head dimensions up to 256. ~~Head dim > 192 backward requires A100/A800 or H100/H800~~. Head dim 256 backward now works on consumer GPUs (if there's no dropout) as of flash-attn 2.5.5.

+### AMD ROCm Support
+ROCm version uses [composable_kernel](https://github.com/ROCm/composable_kernel) as the backend. It provides the implementation of FlashAttention-2.
+
+**Requirements:**
+- ROCm 6.0 and above.
+
+We recommend the
+[Pytorch](https://hub.docker.com/r/rocm/pytorch)
+container from ROCm, which has all the required tools to install FlashAttention.
+
+FlashAttention-2 with ROCm currently supports:
+1. MI200 or MI300 GPUs.
+2. Datatype fp16 and bf16
+3. Forward's head dimensions up to 256. Backward head dimensions up to 128.
+

 ## How to use FlashAttention

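For context on the interface the reorganized README points to, here is a minimal usage sketch (not part of this commit) of the `flash_attn_func` entry point from the `flash_attn` package, assuming the (batch, seqlen, nheads, headdim) tensor layout and the fp16/bf16 dtypes from the support list above:

```python
# Minimal usage sketch (illustrative, not from this commit).
import torch
from flash_attn import flash_attn_func

batch, seqlen, nheads, headdim = 2, 1024, 16, 64
q = torch.randn(batch, seqlen, nheads, headdim, device='cuda', dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

out = flash_attn_func(q, k, v, causal=True)  # causal self-attention
print(out.shape)  # torch.Size([2, 1024, 16, 64])
```

Dropout and the softcapping mentioned in the changelog below are exposed as additional keyword arguments.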

@@ -358,6 +373,10 @@ Thanks to @beginlner for this contribution.
 Support attention with softcapping, as used in Gemma-2 and Grok models.
 Thanks to @Narsil and @lucidrains for this contribution.

+### 2.7: Compatibility with torch compile
+
+Thanks to @ani300 for this contribution.
+
 ## Performance

 We present expected speedup (combined forward + backward pass) and memory savings from using FlashAttention against PyTorch standard attention, depending on sequence length, on different GPUs (speedup depends on memory bandwidth - we see more speedup on slower GPU memory).
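
The new 2.7 changelog entry is a one-liner; as a hedged illustration (my reading of the compatibility note, not code from this commit), it should let a `torch.compile`-d function call `flash_attn_func` directly:

```python
# Hedged sketch of torch.compile compatibility (assumption, not from this commit).
import torch
from flash_attn import flash_attn_func

@torch.compile
def attention(q, k, v):
    # As of 2.7 this call should no longer need to fall back to eager mode.
    return flash_attn_func(q, k, v, causal=True)

q = torch.randn(2, 1024, 16, 64, device='cuda', dtype=torch.bfloat16)
out = attention(q, q, q)
```
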
@@ -437,27 +456,6 @@ This new release of FlashAttention-2 has been tested on several GPT-style
 models, mostly on A100 GPUs.

 If you encounter bugs, please open a GitHub Issue!
-## AMD GPU/ROCm Support
-ROCm version use [composable_kernel](https://github.com/ROCm/composable_kernel) as backend. It provides the implementation of FlashAttention-2.
-
-## Installation and features
-Requirements:
-- ROCm 6.0+
-- PyTorch 1.12.1+
-
-We recommend the
-[Pytorch](https://hub.docker.com/r/rocm/pytorch)
-container from ROCm, which has all the required tools to install FlashAttention.
-
-To compile from source:
-```sh
-python setup.py install
-```
-
-FlashAttention-2 on ROCm currently supports:
-1. MI200 or MI300 GPUs.
-2. Datatype fp16 and bf16
-3. Forward's head dimensions up to 256. Backward head dimensions up to 128.

 ## Tests
 To run the tests:

benchmarks/benchmark_gemm.py (new file, +43)

```python
import time
import torch
import torch.utils.benchmark as benchmark

from triton.testing import do_bench


def benchmark_forward(fn, *inputs, repeats=10, desc='', verbose=True, **kwinputs):
    """Use Pytorch Benchmark on the forward pass of an arbitrary function."""
    if verbose:
        print(desc, '- Forward pass')
    t = benchmark.Timer(
        stmt='fn(*inputs, **kwinputs)',
        globals={'fn': fn, 'inputs': inputs, 'kwinputs': kwinputs},
        num_threads=torch.get_num_threads(),
    )
    m = t.timeit(repeats)
    if verbose:
        print(m)
    return t, m


torch.manual_seed(0)
repeats = 30
dtype = torch.float16
device = 'cuda'
verbose = False
m, n = 8192, 8192

tflops_matmul = {}
tflops_matmul1 = {}
for k in [512, 1024, 1536, 2048, 2560, 3072, 3584, 4096, 4608, 5120, 5632, 6144, 6656, 7168, 7680, 8192]:
    a = torch.randn(m, k, device=device, dtype=dtype)
    b = torch.randn(n, k, device=device, dtype=dtype).transpose(-1, -2)
    nFLOPS_matmul = 2 * m * n * k
    time.sleep(2)  # to reduce power throttling
    timing = benchmark_forward(torch.matmul, a, b, desc='cuBLAS', verbose=verbose, repeats=repeats)[1]
    tflops_matmul[k] = nFLOPS_matmul / timing.mean * 1e-12
    print(f'[torch.utils.benchmark] cuBLAS, {m = }, {n = }, {k = }: {timing.mean * 1e3:.3f}ms, {tflops_matmul[k]:.1f} TFLOPS')
    time.sleep(2)  # to reduce power throttling
    ms = do_bench(lambda: torch.matmul(a, b), warmup=10, rep=repeats)
    tflops_matmul1[k] = nFLOPS_matmul / ms * 1e-9
    print(f'[triton.test.do_bench] cuBLAS, {m = }, {n = }, {k = }: {ms:.3f}ms, {tflops_matmul1[k]:.1f} TFLOPS')
```
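
Usage note (not part of the commit itself): the script is self-contained, so on a machine with a CUDA GPU, PyTorch, and Triton installed it can be run directly as `python benchmarks/benchmark_gemm.py`. For each `k` it times the same cuBLAS GEMM twice, once with `torch.utils.benchmark` and once with Triton's `do_bench`, and reports TFLOPS computed from the usual `2 * m * n * k` FLOP count (one multiply and one add per accumulated product).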

csrc/composable_kernel

Submodule composable_kernel updated 921 files
