- Linux. Windows might work starting with v2.3.2 (we've seen a few positive [reports](https://github.com/Dao-AILab/flash-attention/issues/595)), but Windows compilation still requires more testing. If you have ideas on how to set up prebuilt CUDA wheels for Windows, please reach out via a GitHub issue.

We recommend the Pytorch container from Nvidia, which has all the required tools to install FlashAttention.

FlashAttention-2 with CUDA currently supports:
1. Ampere, Ada, or Hopper GPUs (e.g., A100, RTX 3090, RTX 4090, H100). Support for Turing
GPUs (T4, RTX 2080) is coming soon; please use FlashAttention 1.x for Turing
GPUs for now.
2. Datatypes fp16 and bf16 (bf16 requires Ampere, Ada, or Hopper GPUs).
3. All head dimensions up to 256. ~~Head dim > 192 backward requires A100/A800 or H100/H800~~. Head dim 256 backward now works on consumer GPUs (if there's no dropout) as of flash-attn 2.5.5.
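
As a quick sanity check against the list above, here is a minimal sketch (illustrative only; it assumes `flash-attn` and a CUDA build of PyTorch are installed, and the shapes are made up) that verifies the compute capability and calls `flash_attn_func` with fp16 inputs and a supported head dimension:

```python
import torch
from flash_attn import flash_attn_func

# Requirements above: Ampere/Ada/Hopper GPU (compute capability >= 8.0),
# fp16 or bf16 inputs, head dimension <= 256.
major, _ = torch.cuda.get_device_capability()
assert major >= 8, "FlashAttention-2 needs an Ampere, Ada, or Hopper GPU (sm80+)"

batch, seqlen, nheads, headdim = 2, 1024, 16, 64
q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

out = flash_attn_func(q, k, v, causal=True)  # output: (batch, seqlen, nheads, headdim)
print(out.shape)
```
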
### AMD ROCm Support
The ROCm version uses [composable_kernel](https://github.com/ROCm/composable_kernel) as the backend, which provides the implementation of FlashAttention-2.

**Requirements:**
- ROCm 6.0 and above.

We recommend the
[Pytorch](https://hub.docker.com/r/rocm/pytorch)
container from ROCm, which has all the required tools to install FlashAttention.

FlashAttention-2 with ROCm currently supports:
1. MI200 or MI300 GPUs.
2. Datatypes fp16 and bf16.
3. Head dimensions up to 256 in the forward pass, and up to 128 in the backward pass.
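
As an illustration only, assuming the ROCm build exposes the same `flash_attn_func` Python interface as the CUDA build, the sketch below keeps the head dimension at 128 so that both the forward and backward passes stay within the limits above:

```python
import torch
from flash_attn import flash_attn_func

# On ROCm, PyTorch still uses the "cuda" device string; torch.version.hip tells
# the builds apart. Head dim 128 keeps the backward pass within the supported range.
assert torch.version.hip is not None, "expected a ROCm build of PyTorch"

q = torch.randn(2, 1024, 16, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)

out = flash_attn_func(q, k, v, causal=True)
out.sum().backward()  # backward pass works because head dim <= 128
```
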
## How to use FlashAttention

### 2.6: Softcapping

Support attention with softcapping, as used in Gemma-2 and Grok models.
Thanks to @Narsil and @lucidrains for this contribution.
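
Softcapping bounds the attention scores by passing them through `softcap * tanh(scores / softcap)` before the softmax. A brief, illustrative sketch (assuming flash-attn 2.6 or newer; the shapes and the value 50.0, which mirrors Gemma-2's attention softcap, are only examples):

```python
import torch
from flash_attn import flash_attn_func

# softcap > 0 enables logit softcapping: scores -> softcap * tanh(scores / softcap),
# keeping them bounded in (-softcap, softcap) before the softmax.
q = torch.randn(1, 512, 8, 128, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)

out = flash_attn_func(q, k, v, causal=True, softcap=50.0)
```
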
### 2.7: Compatibility with torch compile
Thanks to @ani300 for this contribution.
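
As a rough illustration of what this enables (behaviour depends on your torch and flash-attn versions, so treat this as a sketch rather than a guarantee of full-graph capture), `flash_attn_func` can be called from code compiled with `torch.compile`:

```python
import torch
from flash_attn import flash_attn_func

# With flash-attn >= 2.7, calling flash_attn_func inside a torch.compile'd
# function no longer has to fall back to eager execution at the attention call.
@torch.compile
def attn(q, k, v):
    return flash_attn_func(q, k, v, causal=True)

q = torch.randn(2, 1024, 16, 64, device="cuda", dtype=torch.float16)
out = attn(q, q.clone(), q.clone())
```
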
## Performance
We present expected speedup (combined forward + backward pass) and memory savings from using FlashAttention against PyTorch standard attention, depending on sequence length, on different GPUs (speedup depends on memory bandwidth - we see more speedup on slower GPU memory).
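
For a rough local comparison in the same spirit (an illustrative sketch, not the repository's benchmark script, and forward pass only, whereas the reported numbers combine forward and backward), you can time `flash_attn_func` against a naive PyTorch attention that materializes the full score matrix:

```python
import time

import torch
import torch.nn.functional as F
from flash_attn import flash_attn_func

def bench(fn, *args, iters=30):
    # Simple CUDA timing helper: warm up, then average over `iters` runs.
    for _ in range(5):
        fn(*args)
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(iters):
        fn(*args)
    torch.cuda.synchronize()
    return (time.time() - start) / iters

batch, seqlen, nheads, headdim = 8, 2048, 16, 64
q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

def standard_attn(q, k, v):
    # Naive attention: materializes the full (seqlen x seqlen) score matrix.
    qt, kt, vt = (x.transpose(1, 2) for x in (q, k, v))  # (batch, nheads, seqlen, headdim)
    scores = qt @ kt.transpose(-2, -1) / headdim**0.5
    return (F.softmax(scores, dim=-1) @ vt).transpose(1, 2)

print(f"standard attention: {bench(standard_attn, q, k, v) * 1e3:.2f} ms")
print(f"flash attention:    {bench(flash_attn_func, q, k, v) * 1e3:.2f} ms")
```

The naive version also allocates the full `seqlen x seqlen` attention matrix, which is where the memory savings reported here come from.
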
This new release of FlashAttention-2 has been tested on several GPT-style
models, mostly on A100 GPUs.
If you encounter bugs, please open a GitHub Issue!