README.md (+50, -4)
@@ -59,8 +59,11 @@ To run the test:
```
export PYTHONPATH=$PWD
pytest -q -s test_flash_attn.py
```
Once the package is installed, you can import it as follows:
```python
import flash_attn_interface
# q, k, v: (batch, seqlen, nheads, headdim) fp16/bf16 tensors on the GPU
flash_attn_interface.flash_attn_func(q, k, v)
```
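For a fuller picture, here is a hedged end-to-end sketch; the shapes, the bf16 dtype, and the `causal=True` flag are illustrative choices, not requirements of the interface:

```python
import torch
import flash_attn_interface

# Illustrative shapes: batch 2, sequence length 512, 8 heads, head dim 128.
q = torch.randn(2, 512, 8, 128, device="cuda", dtype=torch.bfloat16)
k = torch.randn(2, 512, 8, 128, device="cuda", dtype=torch.bfloat16)
v = torch.randn(2, 512, 8, 128, device="cuda", dtype=torch.bfloat16)

result = flash_attn_interface.flash_attn_func(q, k, v, causal=True)
# Depending on the release, the call returns the output tensor directly or a tuple
# whose first element is the output; the attention output has the same shape as q.
out = result[0] if isinstance(result, tuple) else result
```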
## Installation and features
**Requirements:**
@@ -112,7 +115,7 @@ FlashAttention-2 with CUDA currently supports:
3. All head dimensions up to 256. ~~Head dim > 192 backward requires A100/A800 or H100/H800~~. Head dim 256 backward now works on consumer GPUs (if there's no dropout) as of flash-attn 2.5.5.
### AMD ROCm Support
The ROCm version has two backends: [composable_kernel](https://github.com/ROCm/composable_kernel) (CK), which is the default, and a [Triton](https://github.com/triton-lang/triton) backend. Both provide an implementation of FlashAttention-2.
**Requirements:**
- ROCm 6.0 and above.
@@ -121,11 +124,54 @@ We recommend the
[Pytorch](https://hub.docker.com/r/rocm/pytorch)
container from ROCm, which has all the required tools to install FlashAttention.

#### Composable Kernel Backend
The FlashAttention-2 ROCm CK backend currently supports (a usage sketch follows this list):
1. MI200 or MI300 GPUs.
2. Datatypes fp16 and bf16.
3. Forward pass head dimensions up to 256; backward pass head dimensions up to 128.
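
As a hedged sketch of a call that stays within these limits (head dim 128 so both Fwd and Bwd are covered; the shapes and fp16 dtype are illustrative), using the standard `flash_attn` API:

```python
import torch
from flash_attn import flash_attn_func

# On ROCm builds of PyTorch the GPU is still exposed under the "cuda" device name.
# Head dim 128 keeps both the forward and the backward pass within the CK limits above.
q = torch.randn(4, 1024, 16, 128, device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn(4, 1024, 16, 128, device="cuda", dtype=torch.float16, requires_grad=True)
v = torch.randn(4, 1024, 16, 128, device="cuda", dtype=torch.float16, requires_grad=True)

out = flash_attn_func(q, k, v, causal=True)  # forward
out.sum().backward()                         # backward
```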

#### Triton Backend
The Triton implementation of [Flash Attention v2](https://tridao.me/publications/flash2/flash2.pdf) is currently a work in progress.

It supports AMD's CDNA (MI200, MI300) and RDNA GPUs using the fp16, bf16, and fp32 datatypes.

These features are supported in both Fwd and Bwd (a varlen sketch follows these lists):
1) Fwd and Bwd with causal masking
2) Variable sequence lengths
3) Arbitrary Q and KV sequence lengths
4) Arbitrary head sizes

These features are currently supported in Fwd only; Bwd support will be added soon:
1) Multi-query and grouped-query attention
2) ALiBi and matrix bias

These features are in development:
1) Paged Attention
2) Sliding Window
3) Rotary embeddings
4) Dropout
5) Performance Improvements
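
As a hedged sketch of the first list above (causal masking with variable, packed sequence lengths), using the standard varlen entry point of the `flash_attn` package; the sequence lengths, head count, and head dim are illustrative assumptions:

```python
import torch
from flash_attn import flash_attn_varlen_func

# Two packed sequences of (illustrative) lengths 100 and 156, 8 heads, head dim 64.
# The tokens of all sequences are concatenated along the first dimension.
total_tokens = 100 + 156
q = torch.randn(total_tokens, 8, 64, device="cuda", dtype=torch.float16)
k = torch.randn(total_tokens, 8, 64, device="cuda", dtype=torch.float16)
v = torch.randn(total_tokens, 8, 64, device="cuda", dtype=torch.float16)

# Cumulative sequence-length offsets of the packed batch (int32): [0, 100, 256].
cu_seqlens = torch.tensor([0, 100, 256], device="cuda", dtype=torch.int32)

out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=156, max_seqlen_k=156,
    causal=True,  # causal masking is supported in both Fwd and Bwd
)
```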
#### Getting Started
To get started with the Triton backend for AMD, follow the steps below.
First, install the recommended Triton [commit](https://github.com/triton-lang/triton/commit/3ca2f498e98ed7249b82722587c511a5610e00c4).
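
For example, one assumed way to build that commit from source (the exact workflow may differ in your environment):

```bash
# Assumed build-from-source workflow for the pinned Triton commit.
git clone https://github.com/triton-lang/triton
cd triton
git checkout 3ca2f498e98ed7249b82722587c511a5610e00c4
cd python
pip install .
```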
0 commit comments