@@ -86,7 +86,8 @@ from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
```

```python
- flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1)):
+ flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False,
+                           window_size=(-1, -1), alibi_slopes=None, deterministic=False):
"""dropout_p should be set to 0.0 during evaluation
If Q, K, V are already stacked into 1 tensor, this function will be faster than
calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
@@ -100,13 +101,18 @@ Arguments:
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+ alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
+ the attention score of query i and key j.
+ deterministic: bool. Whether to use the deterministic implementation of the backward pass,
+ which is slightly slower and uses more memory. The forward pass is always deterministic.
Return:
out: (batch_size, seqlen, nheads, headdim).
"""
```
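A minimal usage sketch of the updated signature (assuming `qkv` packed as `(batch_size, seqlen, 3, nheads, headdim)` in fp16 on a CUDA device, as documented elsewhere in this README; the sizes below are illustrative):

```python
import torch
from flash_attn import flash_attn_qkvpacked_func

# Packed QKV: (batch_size, seqlen, 3, nheads, headdim), fp16/bf16, on GPU.
batch_size, seqlen, nheads, headdim = 2, 1024, 16, 64
qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim,
                  dtype=torch.float16, device="cuda", requires_grad=True)

# Causal self-attention; deterministic=True selects the new deterministic backward pass.
out = flash_attn_qkvpacked_func(qkv, dropout_p=0.0, causal=True, deterministic=True)
# out: (batch_size, seqlen, nheads, headdim)
```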

```python
- flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1)):
+ flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False,
+                 window_size=(-1, -1), alibi_slopes=None, deterministic=False):
"""dropout_p should be set to 0.0 during evaluation
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
@@ -125,6 +131,11 @@ Arguments:
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+ alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
+ (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
+ is added to the attention score of query i and key j.
+ deterministic: bool. Whether to use the deterministic implementation of the backward pass,
+ which is slightly slower and uses more memory. The forward pass is always deterministic.
Return:
out: (batch_size, seqlen, nheads, headdim).
"""
@@ -141,17 +152,23 @@ def flash_attn_with_kvcache(
rotary_sin=None,
cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
cache_batch_idx: Optional[torch.Tensor] = None,
+ block_table: Optional[torch.Tensor] = None,
softmax_scale=None,
causal=False,
window_size=(-1, -1),  # -1 means infinite context window
rotary_interleaved=True,
+ alibi_slopes=None,
):
"""
If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
the previous step, and update them with the new keys/values from the current step, and do
attention with the updated cache, all in 1 kernel.

+ If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
+ For example, the KV cache could be pre-allocated with the max sequence length, and you can use
+ cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.
+
Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
@@ -161,12 +178,36 @@ def flash_attn_with_kvcache(

See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.

+ Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
+ than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
+ For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
+ 0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.
+
+ If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
+ For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
+     1 1 1 1 0
+     1 1 1 1 1
+ If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
+     0 0
+     0 0
+     0 0
+     1 0
+     1 1
+ If the row of the mask is all zero, the output will be zero.
+
+ If window_size != (-1, -1), implements sliding window local attention. Query at position i
+ will only attend to keys between
+ [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
+
Note: Does not support backward pass.

Arguments:
q: (batch_size, seqlen, nheads, headdim)
- k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim)
- v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim)
+ k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
+     or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache);
+     page_block_size must be a multiple of 256.
+ v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
+     or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache).
k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
k with k_cache, starting at the indices specified by cache_seqlens.
v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.
@@ -175,6 +216,7 @@ def flash_attn_with_kvcache(
rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
KV cache.
+ block_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
If the indices are not distinct, and k and v are provided, the values updated in the cache
@@ -187,10 +229,9 @@ def flash_attn_with_kvcache(
If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
(i.e. GPT-NeoX style).
- num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
- If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
- to automatically determine the number of splits.
- Don't change this unless you know what you are doing.
+ alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
+ (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
+ is added to the attention score of query i and key j.

Return:
out: (batch_size, seqlen, nheads, headdim).
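A sketch of single-token incremental decoding with a pre-allocated cache (argument names follow the signature above; the buffer sizes and initial sequence lengths below are illustrative):

```python
import torch
from flash_attn import flash_attn_with_kvcache

batch_size, nheads, nheads_k, headdim = 2, 8, 8, 64
max_seqlen = 4096

# Pre-allocate the KV cache for the maximum sequence length.
k_cache = torch.zeros(batch_size, max_seqlen, nheads_k, headdim,
                      dtype=torch.float16, device="cuda")
v_cache = torch.zeros_like(k_cache)
# Number of tokens already cached for each sequence in the batch.
cache_seqlens = torch.full((batch_size,), 128, dtype=torch.int32, device="cuda")

# One new token per sequence for this decoding step.
q = torch.randn(batch_size, 1, nheads, headdim, dtype=torch.float16, device="cuda")
k = torch.randn(batch_size, 1, nheads_k, headdim, dtype=torch.float16, device="cuda")
v = torch.randn(batch_size, 1, nheads_k, headdim, dtype=torch.float16, device="cuda")

# k/v are written into k_cache/v_cache in place at index cache_seqlens, and attention
# is computed against the updated cache, all in one kernel.
out = flash_attn_with_kvcache(q, k_cache, v_cache, k=k, v=v,
                              cache_seqlens=cache_seqlens, causal=True)
# out: (batch_size, 1, nheads, headdim); advance cache_seqlens by 1 before the next step.
```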
@@ -266,6 +307,17 @@ Implement sliding window attention (i.e., local attention). Thanks to [Mistral
AI](https://mistral.ai/) and in particular Timothée Lacroix for this
contribution. Sliding window was used in the [Mistral 7B](https://mistral.ai/news/announcing-mistral-7b/) model.

+ ### 2.4: ALiBi (attention with linear bias), deterministic backward pass.
+
+ Implement ALiBi (Press et al., 2021). Thanks to Sanghun Cho from Kakao Brain for this contribution.
+
+ Implement deterministic backward pass. Thanks to engineers from [Meituan](www.meituan.com) for this contribution.
+
+ ### 2.5: Paged KV cache.
+
+ Support paged KV cache (i.e., [PagedAttention](https://arxiv.org/abs/2309.06180)).
+ Thanks to @beginlner for this contribution.
+
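A rough sketch of how a paged layout fits the flash_attn_with_kvcache arguments documented above (the block counts and the identity block mapping below are made up for illustration; only the cache shapes and block_table dtype come from the docstring):

```python
import torch
from flash_attn import flash_attn_with_kvcache

batch_size, nheads, nheads_k, headdim = 2, 8, 8, 64
page_block_size = 256                      # must be a multiple of 256
max_num_blocks_per_seq = 16
num_blocks = batch_size * max_num_blocks_per_seq

# Paged KV cache: a pool of fixed-size blocks instead of one contiguous buffer per sequence.
k_cache = torch.zeros(num_blocks, page_block_size, nheads_k, headdim,
                      dtype=torch.float16, device="cuda")
v_cache = torch.zeros_like(k_cache)

# block_table maps each sequence's logical blocks to physical blocks in the pool.
block_table = torch.arange(num_blocks, dtype=torch.int32, device="cuda").reshape(
    batch_size, max_num_blocks_per_seq)
cache_seqlens = torch.full((batch_size,), 500, dtype=torch.int32, device="cuda")

q = torch.randn(batch_size, 1, nheads, headdim, dtype=torch.float16, device="cuda")
out = flash_attn_with_kvcache(q, k_cache, v_cache, cache_seqlens=cache_seqlens,
                              block_table=block_table, causal=True)
```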
## Performance

We present expected speedup (combined forward + backward pass) and memory savings from using FlashAttention against PyTorch standard attention, depending on sequence length, on different GPUs (speedup depends on memory bandwidth - we see more speedup on slower GPU memory).