Commit 63863ab

Use Flash Attention if available (#2058)
* Use Flash Attention if available
* Torch's `dot_product_attention` doesn't support `bias`.
1 parent 221ea6b commit 63863ab

7 files changed, 110 additions and 25 deletions


keras_hub/src/models/falcon/falcon_attention.py

Lines changed: 6 additions & 3 deletions
@@ -110,16 +110,19 @@ def call(
 
         attention_scores = ops.einsum("bqnh,bknh->bnqk", query, key)
         attention_scores = ops.add(attention_scores, alibi)
-        attention_scores = (
-            attention_scores * self.inv_norm_factor
-        )  # [batch_size, num_heads, query_length, kv_length]
+        # [batch_size, num_heads, query_length, kv_length]
+        attention_scores = ops.multiply(
+            attention_scores,
+            ops.cast(self.inv_norm_factor, self.compute_dtype),
+        )
         attention_scores = self.softmax(
             attention_scores, ops.expand_dims(attention_mask, 1)
         )
         attention_scores = self.attention_dropout(attention_scores)
         attention_output = ops.einsum(
             "bnqk,bknh->bqnh", attention_scores, value
         )
+
         attention_output = ops.reshape(
             attention_output,
             [batch_size, seq_length, self.num_heads * self.head_dim],
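
In Falcon the inverse norm factor was already a multiplication; the change wraps it in `ops.multiply` with an explicit cast to the layer's compute dtype (relevant under mixed precision). The other attention layers below replace an in-graph division by `sqrt(head_dim)` with the same pattern using a precomputed Python-float inverse. A minimal sketch of the equivalence, assuming Keras 3; the `head_dim` value and tensor shapes are illustrative, not taken from the commit:

    import math

    import numpy as np
    from keras import ops

    head_dim = 64  # illustrative head size
    inv_norm_factor = 1.0 / math.sqrt(head_dim)

    # [batch, heads, q_len, kv_len] attention scores with made-up shapes.
    scores = ops.convert_to_tensor(
        np.random.rand(2, 4, 8, 8).astype("float32")
    )
    old = scores / ops.sqrt(ops.cast(head_dim, "float32"))
    new = ops.multiply(scores, ops.cast(inv_norm_factor, "float32"))

    # The two paths agree up to float rounding.
    np.testing.assert_allclose(
        ops.convert_to_numpy(old), ops.convert_to_numpy(new), rtol=1e-6
    )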

keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py

Lines changed: 23 additions & 6 deletions
@@ -1,8 +1,11 @@
+import math
+
 import keras
 from keras import ops
 
 from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
 from keras_hub.src.utils.keras_utils import clone_initializer
+from keras_hub.src.utils.keras_utils import has_flash_attention_support
 
 
 class GPTNeoXAttention(keras.layers.Layer):
@@ -58,6 +61,8 @@ def __init__(
         self.bias_initializer = keras.initializers.get(bias_initializer)
         self.max_sequence_length = max_sequence_length
 
+        self._inv_norm_factor = 1.0 / math.sqrt(self.attn_head_size)
+
     def build(self, input_shape):
         self._qkv_dense = keras.layers.EinsumDense(
             equation="abc,cde->abde",
@@ -120,14 +125,26 @@ def _masked_softmax(self, attention_scores, attention_mask=None):
     def _compute_attention(
         self, query, key, value, attention_mask=None, training=None
     ):
-        attention_scores = ops.einsum("aecd,abcd->acbe", key, query)
+        if has_flash_attention_support() and self.dropout == 0:
+            # Use `dot_product_attention` with Flash Attention support if
+            # available.
+            if attention_mask is not None:
+                attention_mask = ops.expand_dims(attention_mask, axis=1)
+                attention_mask = ops.cast(attention_mask, dtype="bool")
+            attention_output = ops.dot_product_attention(
+                query,
+                key,
+                value,
+                mask=attention_mask,
+                scale=self._inv_norm_factor,
+            )
+            return attention_output
 
-        norm_factor = ops.sqrt(
-            ops.convert_to_tensor(self.attn_head_size, self.compute_dtype)
+        attention_scores = ops.einsum("aecd,abcd->acbe", key, query)
+        attention_scores = ops.multiply(
+            attention_scores,
+            ops.cast(self._inv_norm_factor, self.compute_dtype),
         )
-
-        attention_scores /= norm_factor
-
         attention_scores = self._masked_softmax(
             attention_scores, attention_mask
         )
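
The same fast-path gate recurs in the Llama, Mistral, and Phi-3 layers below; GPT-NeoX additionally requires `self.dropout == 0`, presumably because the fused op does not apply attention dropout. A self-contained sketch of the pattern, assuming a Keras 3 build that provides `ops.dot_product_attention`; the shapes and causal mask are illustrative, not from the commit:

    import math

    import numpy as np
    from keras import ops

    from keras_hub.src.utils.keras_utils import has_flash_attention_support

    batch, q_len, kv_len, num_heads, head_dim = 2, 8, 8, 4, 16  # illustrative
    rng = np.random.default_rng(0)
    query = ops.convert_to_tensor(
        rng.standard_normal((batch, q_len, num_heads, head_dim), dtype=np.float32)
    )
    key = ops.convert_to_tensor(
        rng.standard_normal((batch, kv_len, num_heads, head_dim), dtype=np.float32)
    )
    value = ops.convert_to_tensor(
        rng.standard_normal((batch, kv_len, num_heads, head_dim), dtype=np.float32)
    )
    mask = np.tril(np.ones((batch, q_len, kv_len), dtype=bool))  # causal mask

    if has_flash_attention_support():
        # Fast path: add a head axis to the mask and cast to bool, as above.
        attn_mask = ops.cast(ops.expand_dims(mask, axis=1), "bool")
        out = ops.dot_product_attention(
            query, key, value, mask=attn_mask, scale=1.0 / math.sqrt(head_dim)
        )
    else:
        # Fallback: explicit scaled dot-product attention via einsum + softmax.
        scores = ops.einsum("bqnh,bknh->bnqk", query, key)
        scores = scores * (1.0 / math.sqrt(head_dim))
        scores = ops.where(ops.expand_dims(mask, axis=1), scores, -1e9)
        probs = ops.softmax(scores, axis=-1)
        out = ops.einsum("bnqk,bknh->bqnh", probs, value)

    print(ops.shape(out))  # (2, 8, 4, 16)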

keras_hub/src/models/llama/llama_attention.py

Lines changed: 23 additions & 2 deletions
@@ -1,8 +1,11 @@
+import math
+
 import keras
 from keras import ops
 
 from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
 from keras_hub.src.utils.keras_utils import clone_initializer
+from keras_hub.src.utils.keras_utils import has_flash_attention_support
 
 
 class LlamaAttention(keras.layers.Layer):
@@ -43,7 +46,7 @@ def build(self, inputs_shape):
         # h = head dim
         hidden_dim = inputs_shape[-1]
         head_dim = hidden_dim // self.num_query_heads
-        self._norm_factor = ops.sqrt(ops.cast(head_dim, self.compute_dtype))
+        self._inv_norm_factor = 1.0 / math.sqrt(head_dim)
 
         self._query_dense = keras.layers.EinsumDense(
             equation="bqm,muh->bquh",
@@ -182,9 +185,27 @@ def _masked_softmax(self, attention_scores, attention_mask=None):
         return self._softmax(attention_scores)
 
     def _compute_attention(self, query, key, value, attention_mask=None):
+        if has_flash_attention_support():
+            # Use `dot_product_attention` with Flash Attention support if
+            # available.
+            if attention_mask is not None:
+                attention_mask = ops.expand_dims(attention_mask, axis=1)
+                attention_mask = ops.cast(attention_mask, dtype="bool")
+            attention_output = ops.dot_product_attention(
+                query,
+                key,
+                value,
+                mask=attention_mask,
+                scale=self._inv_norm_factor,
+            )
+            return attention_output
+
         attention_scores = ops.einsum(self._dot_product_equation, query, key)
 
-        attention_scores = attention_scores / self._norm_factor
+        attention_scores = ops.multiply(
+            attention_scores,
+            ops.cast(self._inv_norm_factor, self.compute_dtype),
+        )
         attention_scores = self._masked_softmax(
             attention_scores, attention_mask
         )

keras_hub/src/models/mistral/mistral_attention.py

Lines changed: 23 additions & 4 deletions
@@ -1,8 +1,11 @@
+import math
+
 import keras
 from keras import ops
 
 from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
 from keras_hub.src.utils.keras_utils import clone_initializer
+from keras_hub.src.utils.keras_utils import has_flash_attention_support
 
 
 # This is just a self-attention layer in Mistral. But it can be generalized
@@ -52,6 +55,7 @@ def build(self, inputs_shape):
         # h = head dim
         self._hidden_dim = inputs_shape[-1]
         self._head_dim = self._hidden_dim // self._num_query_heads
+        self._inv_norm_factor = 1.0 / math.sqrt(self._head_dim)
 
         self._query_dense = keras.layers.EinsumDense(
             equation="bqm,muh->bquh",
@@ -192,11 +196,26 @@ def _masked_softmax(self, attention_scores, attention_mask=None):
         return self._softmax(attention_scores)
 
     def _compute_attention(self, query, key, value, attention_mask=None):
-        attention_scores = ops.einsum(self._dot_product_equation, query, key)
-
-        norm_factor = ops.sqrt(ops.cast(self._head_dim, self.compute_dtype))
+        if has_flash_attention_support():
+            # Use `dot_product_attention` with Flash Attention support if
+            # available.
+            if attention_mask is not None:
+                attention_mask = ops.expand_dims(attention_mask, axis=1)
+                attention_mask = ops.cast(attention_mask, dtype="bool")
+            attention_output = ops.dot_product_attention(
+                query,
+                key,
+                value,
+                mask=attention_mask,
+                scale=self._inv_norm_factor,
+            )
+            return attention_output
 
-        attention_scores = attention_scores / norm_factor
+        attention_scores = ops.einsum(self._dot_product_equation, query, key)
+        attention_scores = ops.multiply(
+            attention_scores,
+            ops.cast(self._inv_norm_factor, self.compute_dtype),
+        )
         attention_scores = self._masked_softmax(
             attention_scores, attention_mask
         )

keras_hub/src/models/phi3/phi3_attention.py

Lines changed: 23 additions & 2 deletions
@@ -1,3 +1,5 @@
+import math
+
 import keras
 from keras import ops
 
@@ -6,6 +8,7 @@
     Phi3SuScaledRotaryEmbedding,
 )
 from keras_hub.src.utils.keras_utils import clone_initializer
+from keras_hub.src.utils.keras_utils import has_flash_attention_support
 
 
 class Phi3Attention(keras.layers.Layer):
@@ -53,7 +56,7 @@ def build(self, inputs_shape):
         # h = head dim
         hidden_dim = inputs_shape[-1]
         head_dim = hidden_dim // self.num_query_heads
-        self._norm_factor = ops.sqrt(ops.cast(head_dim, self.compute_dtype))
+        self._inv_norm_factor = 1.0 / math.sqrt(head_dim)
 
         self.query_dense = keras.layers.EinsumDense(
             equation="bqm,muh->bquh",
@@ -214,8 +217,26 @@ def _masked_softmax(self, attention_scores, attention_mask=None):
         return self.softmax(attention_scores)
 
     def _compute_attention(self, query, key, value, attention_mask=None):
+        if has_flash_attention_support():
+            # Use `dot_product_attention` with Flash Attention support if
+            # available.
+            if attention_mask is not None:
+                attention_mask = ops.expand_dims(attention_mask, axis=1)
+                attention_mask = ops.cast(attention_mask, dtype="bool")
+            attention_output = ops.dot_product_attention(
+                query,
+                key,
+                value,
+                mask=attention_mask,
+                scale=self._inv_norm_factor,
+            )
+            return attention_output
+
         attention_scores = ops.einsum("bquh,bkuh->buqk", query, key)
-        attention_scores = attention_scores / self._norm_factor
+        attention_scores = ops.multiply(
+            attention_scores,
+            ops.cast(self._inv_norm_factor, self.compute_dtype),
+        )
         attention_scores = self._masked_softmax(
             attention_scores, attention_mask
         )

keras_hub/src/models/stable_diffusion_3/mmdit.py

Lines changed: 5 additions & 8 deletions
@@ -7,6 +7,7 @@
 from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding
 from keras_hub.src.models.backbone import Backbone
 from keras_hub.src.utils.keras_utils import gelu_approximate
+from keras_hub.src.utils.keras_utils import has_flash_attention_support
 from keras_hub.src.utils.keras_utils import standardize_data_format
 
 
@@ -770,17 +771,14 @@ def build(self, inputs_shape, context_shape, timestep_embedding_shape):
     def _compute_attention(self, query, key, value):
         batch_size = ops.shape(query)[0]
 
-        # Use the fast path when `ops.dot_product_attention` and flash attention
-        # are available.
-        if hasattr(ops, "dot_product_attention") and hasattr(
-            keras.config, "is_flash_attention_enabled"
-        ):
+        if has_flash_attention_support():
+            # Use `dot_product_attention` with Flash Attention support if
+            # available.
             encoded = ops.dot_product_attention(
                 query,
                 key,
                 value,
                 scale=self._inverse_sqrt_key_dim,
-                flash_attention=keras.config.is_flash_attention_enabled(),
             )
             return ops.reshape(
                 encoded, (batch_size, -1, self.num_heads * self.head_dim)
@@ -793,10 +791,9 @@ def _compute_attention(self, query, key, value):
         probs = self.softmax(logits)
         probs = ops.cast(probs, self.compute_dtype)
         encoded = ops.einsum("BNTS,BSNH->BTNH", probs, value)
-        encoded = ops.reshape(
+        return ops.reshape(
             encoded, (batch_size, -1, self.num_heads * self.head_dim)
         )
-        return encoded
 
     def call(self, inputs, context, timestep_embedding, training=None):
         # Compute pre-attention.
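
The MMDiT block passes no attention mask, and the explicit `flash_attention=keras.config.is_flash_attention_enabled()` argument is dropped in favor of the op's default behavior. A minimal sketch of this unmasked fast path plus the head-merging reshape, assuming a Keras 3 build that provides `ops.dot_product_attention`; the shapes are illustrative, not from the commit:

    import math

    import numpy as np
    from keras import ops

    batch, seq_len, num_heads, head_dim = 2, 16, 4, 32  # illustrative
    q = ops.convert_to_tensor(np.ones((batch, seq_len, num_heads, head_dim), "float32"))
    k = ops.convert_to_tensor(np.ones((batch, seq_len, num_heads, head_dim), "float32"))
    v = ops.convert_to_tensor(np.ones((batch, seq_len, num_heads, head_dim), "float32"))

    encoded = ops.dot_product_attention(q, k, v, scale=1.0 / math.sqrt(head_dim))
    # Merge the head and head_dim axes back into a single model dimension.
    encoded = ops.reshape(encoded, (batch, -1, num_heads * head_dim))
    print(ops.shape(encoded))  # (2, 16, 128)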

keras_hub/src/utils/keras_utils.py

Lines changed: 7 additions & 0 deletions
@@ -53,3 +53,10 @@ def standardize_data_format(data_format):
             f"Received: data_format={data_format}"
         )
     return data_format
+
+
+def has_flash_attention_support():
+    if hasattr(keras.config, "is_flash_attention_enabled"):
+        return True
+    else:
+        return False
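
A brief usage sketch of the new helper. It only checks that the running Keras build exposes `keras.config.is_flash_attention_enabled`; whether a fused flash kernel is actually used at runtime still depends on the backend, dtype, and hardware:

    import keras

    from keras_hub.src.utils.keras_utils import has_flash_attention_support

    if has_flash_attention_support():
        # Keras exposes the flash-attention config API, so the attention
        # layers above take the fused `ops.dot_product_attention` path.
        print("flash attention toggle:", keras.config.is_flash_attention_enabled())
    else:
        # Older Keras: the layers keep the einsum + masked-softmax fallback.
        print("no flash attention support detected")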
