README.md (+61 -9)
@@ -74,7 +74,7 @@ FlashAttention-2 currently supports:
 GPUs (T4, RTX 2080) is coming soon, please use FlashAttention 1.x for Turing
 GPUs for now.
 2. Datatype fp16 and bf16 (bf16 requires Ampere, Ada, or Hopper GPUs).
-3. All head dimensions up to 256. Head dim > 192 backward requires A100/A800 or H100/H800.
+3. All head dimensions up to 256. ~~Head dim > 192 backward requires A100/A800 or H100/H800~~. Head dim 256 backward now works on consumer GPUs (if there's no dropout) as of flash-attn 2.5.5.


 ## How to use FlashAttention
@@ -86,7 +86,8 @@ from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
-    num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
-        If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
-        to automatically determine the number of splits.
-        Don't change this unless you know what you are doing.
+    alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
+        (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
+        is added to the attention score of query i and key j.

     Return:
         out: (batch_size, seqlen, nheads, headdim).
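The new `alibi_slopes` argument takes only the per-head slopes; the `-slope * |i + seqlen_k - seqlen_q - j|` bias is computed inside the kernel, so no `(seqlen_q, seqlen_k)` bias tensor is ever materialized. Below is a minimal sketch (not part of this diff) of calling `flash_attn_func` with ALiBi: the slope schedule is the standard geometric recipe from Press et al. (exact for a power-of-two head count), and the `deterministic` kwarg is assumed to be the switch for the new deterministic backward pass.

```python
import torch
from flash_attn import flash_attn_func

def alibi_slopes(nheads: int) -> torch.Tensor:
    # Geometric slopes 2^(-8*(h+1)/nheads) from Press et al., 2021
    # (the standard recipe when nheads is a power of two).
    return torch.tensor(
        [2.0 ** (-8.0 * (h + 1) / nheads) for h in range(nheads)],
        dtype=torch.float32, device="cuda",
    )

batch, seqlen, nheads, headdim = 2, 1024, 8, 64
q, k, v = (
    torch.randn(batch, seqlen, nheads, headdim, dtype=torch.float16, device="cuda")
    for _ in range(3)
)

out = flash_attn_func(
    q, k, v,
    causal=True,
    alibi_slopes=alibi_slopes(nheads),  # (nheads,) fp32, bias applied in-kernel
    deterministic=True,                 # assumed kwarg enabling the deterministic backward
)
```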
@@ -266,6 +307,17 @@ Implement sliding window attention (i.e., local attention). Thanks to [Mistral
 AI](https://mistral.ai/) and in particular Timothée Lacroix for this
 contribution. Sliding window was used in the [Mistral 7B](https://mistral.ai/news/announcing-mistral-7b/) model.
 
+### 2.4: ALiBi (attention with linear bias), deterministic backward pass.
+
+Implement ALiBi (Press et al., 2021). Thanks to Sanghun Cho from Kakao Brain for this contribution.
+
+Implement deterministic backward pass. Thanks to engineers from [Meituan](https://www.meituan.com) for this contribution.
+
+### 2.5: Paged KV cache.
+
+Support paged KV cache (i.e., [PagedAttention](https://arxiv.org/abs/2309.06180)).
+Thanks to @beginlner for this contribution.
+
 ## Performance
 
 We present expected speedup (combined forward + backward pass) and memory savings from using FlashAttention against PyTorch standard attention, depending on sequence length, on different GPUs (speedup depends on memory bandwidth - we see more speedup on slower GPU memory).
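To make the paged-KV-cache change in section 2.5 above concrete, here is a minimal sketch (not part of this diff; shapes, block counts, and the page size of 256 are made-up assumptions) of attending over a paged cache through `flash_attn_with_kvcache`, assuming the `block_table` argument introduced in the 2.5 series.

```python
import torch
from flash_attn import flash_attn_with_kvcache

batch, nheads, headdim = 2, 8, 64
block_size, num_blocks = 256, 16   # a shared pool of fixed-size KV pages

# Paged layout: (num_blocks, block_size, nheads_k, headdim) instead of one big
# contiguous (batch, max_seqlen, nheads_k, headdim) buffer per sequence.
k_cache = torch.randn(num_blocks, block_size, nheads, headdim,
                      dtype=torch.float16, device="cuda")
v_cache = torch.randn_like(k_cache)

# Each row maps one sequence's logical pages onto physical blocks in the pool,
# so sequences of very different lengths share the pool without padding.
block_table = torch.tensor([[0, 1], [4, 5]], dtype=torch.int32, device="cuda")
cache_seqlens = torch.tensor([300, 17], dtype=torch.int32, device="cuda")  # valid tokens per sequence

# One new query token per sequence (a typical decoding step).
q = torch.randn(batch, 1, nheads, headdim, dtype=torch.float16, device="cuda")

out = flash_attn_with_kvcache(
    q, k_cache, v_cache,
    cache_seqlens=cache_seqlens,
    block_table=block_table,
    causal=True,
)
print(out.shape)  # (batch, 1, nheads, headdim)
```

In a full decode loop one would also write each step's new key/value into the cache (the optional `k`/`v` arguments of `flash_attn_with_kvcache` are meant for that in-place update) and extend a sequence's row of `block_table` whenever its `cache_seqlens` crosses a page boundary.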