
Commit 7a7a6bd

update Gemma attention for TPU (#2130)
* update Gemma attention for TPU
* add default fallback for GPU and CPU
* add fallback option if not running with JAX and TPU
* address review comments
* check input signature
* remove checking q length
* code reformat
* handle case when soft cap support is not needed
* fix format
* add tests for FA calls
* fix test
* update tests
* fix code format
* address review comments
* Update requirements-jax-cuda.txt
* Update gemma_causal_lm_test.py
* Update requirements-jax-cuda.txt
1 parent f06ad0f commit 7a7a6bd

File tree

3 files changed: +69 -12 lines changed


keras_hub/src/models/gemma/gemma_attention.py

+23-12
```diff
@@ -1,10 +1,13 @@
+import inspect
+
 import keras
 import numpy as np
 from keras import ops

 from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
 from keras_hub.src.utils.keras_utils import clone_initializer
 from keras_hub.src.utils.keras_utils import has_flash_attention_support
+from keras_hub.src.utils.keras_utils import running_on_tpu


 class CachedGemmaAttention(keras.layers.Layer):
@@ -103,6 +106,18 @@ def _apply_rope(self, x, start_index):
         )
         return x

+    def _can_use_flash_attention(self):
+        if not has_flash_attention_support():
+            return False
+        if self.dropout > 0.0:
+            return False
+        if self.logit_soft_cap is None:
+            return True
+        sig = inspect.signature(ops.dot_product_attention)
+        # We can currently only run soft capped attention for keras >= 3.10
+        # and only on TPU.
+        return running_on_tpu() and "attn_logits_soft_cap" in sig.parameters
+
     def _compute_attention(
         self,
         q,
@@ -118,27 +133,23 @@ def _compute_attention(
         query_normalization = 1 / np.sqrt(
             self.hidden_dim // self.num_query_heads
         )
-        use_dot_product_attention = not (
-            self.dropout > 0.0 or (len(q.shape) != 4)
-        )
-        if has_flash_attention_support() and use_dot_product_attention:
-            if self.dropout > 0.0:
-                raise ValueError(
-                    "Flash attention does not support dropout. "
-                    "Please set `dropout` to 0.0."
-                )
+        if self._can_use_flash_attention():
             if attention_mask is not None:
                 attention_mask = ops.expand_dims(attention_mask, axis=1)
                 attention_mask = ops.cast(attention_mask, dtype="bool")
-
-            attention_output = ops.dot_product_attention(
+            # Only pass soft cap if needed as not all keras versions support.
+            if self.logit_soft_cap:
+                kwargs = {"attn_logits_soft_cap": self.logit_soft_cap}
+            else:
+                kwargs = {}
+            return ops.dot_product_attention(
                 query=q,
                 key=k,
                 value=v,
                 mask=attention_mask,
                 scale=query_normalization,
+                **kwargs,
             )
-            return attention_output

         q *= ops.cast(query_normalization, dtype=q.dtype)
         q_shape = ops.shape(q)
```
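For context, the key piece of the new `_can_use_flash_attention` helper is a signature check: the `attn_logits_soft_cap` argument is only forwarded when the installed Keras exposes it on `ops.dot_product_attention`. Below is a minimal, hedged sketch of that check in isolation; the `soft_cap_kwargs` helper name is illustrative and not part of keras_hub.

```python
# Minimal sketch (assumes Keras 3 is installed). `soft_cap_kwargs` is an
# illustrative helper, not a keras_hub API.
import inspect

from keras import ops


def soft_cap_kwargs(logit_soft_cap):
    """Build the optional soft-cap kwarg for `ops.dot_product_attention`.

    Returns an empty dict when no soft cap is requested or when the
    installed Keras version does not accept `attn_logits_soft_cap`.
    """
    if logit_soft_cap is None:
        return {}
    sig = inspect.signature(ops.dot_product_attention)
    if "attn_logits_soft_cap" in sig.parameters:
        return {"attn_logits_soft_cap": logit_soft_cap}
    return {}
```

Inspecting the signature avoids a hard dependency on a newer Keras release while still letting the TPU path apply Gemma's logit soft cap when the argument is available.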

keras_hub/src/models/gemma/gemma_causal_lm_test.py

+14
```diff
@@ -12,6 +12,8 @@
 )
 from keras_hub.src.models.gemma.gemma_tokenizer import GemmaTokenizer
 from keras_hub.src.tests.test_case import TestCase
+from keras_hub.src.utils.keras_utils import has_flash_attention_support
+from keras_hub.src.utils.keras_utils import running_on_gpu


 class GemmaCausalLMTest(TestCase):
@@ -95,6 +97,18 @@ def test_generate(self):
             prompt_ids["padding_mask"][:, :4],
         )

+    def test_flash_attention_call(self):
+        if keras.config.backend() != "jax" or not has_flash_attention_support():
+            self.skipTest("`flash_attention` testing requires the Jax backend.")
+
+        with patch("keras.src.backend.nn.dot_product_attention") as mock_func:
+            causal_lm = GemmaCausalLM(**self.init_kwargs)
+            causal_lm.generate("the quick brown fox")
+            if running_on_gpu():
+                mock_func.assert_called()
+            else:
+                mock_func.assert_not_called()
+
     def test_generate_with_bfloat16(self):
         original_floatx = keras.config.floatx()
         keras.config.set_floatx("float16")
```
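The new test patches the backend-level `dot_product_attention` op to observe whether the fused attention path is taken during `generate`. A hedged sketch of that mocking pattern in isolation follows; `was_fused_attention_called` is an illustrative name, and the patch target is the same string used in the test above.

```python
# Sketch of the mocking pattern used by `test_flash_attention_call`.
# `was_fused_attention_called` is illustrative, not part of keras_hub.
from unittest.mock import patch


def was_fused_attention_called(fn):
    """Run `fn` and report whether the backend attention op was invoked."""
    with patch("keras.src.backend.nn.dot_product_attention") as mock_op:
        fn()
        return mock_op.called
```

Because the mock replaces the real op, generation output inside the patch is not meaningful; the test only asserts on whether the call happened, using `running_on_gpu()` to decide between `assert_called` and `assert_not_called`.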

keras_hub/src/utils/keras_utils.py

+32
```diff
@@ -72,3 +72,35 @@ def has_flash_attention_support():
         return True
     else:
         return False
+
+
+def running_on_tpu():
+    backend = keras.config.backend()
+    if backend == "jax":
+        import jax
+
+        devices = jax.devices()
+        return any(d.platform == "tpu" for d in devices)
+    elif backend == "tensorflow":
+        import tensorflow as tf
+
+        return bool(tf.config.list_logical_devices("TPU"))
+    elif backend == "torch":
+        return False
+
+
+def running_on_gpu():
+    backend = keras.config.backend()
+    if backend == "jax":
+        import jax
+
+        devices = jax.devices()
+        return any(d.platform == "gpu" for d in devices)
+    elif backend == "tensorflow":
+        import tensorflow as tf
+
+        return bool(tf.config.list_logical_devices("GPU"))
+    elif backend == "torch":
+        import torch
+
+        return torch.cuda.is_available()
```
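The two helpers above let callers branch on the accelerator without touching backend modules directly. A hedged usage sketch follows; `choose_attention_path` is illustrative, not a keras_hub API, and it simplifies the real gating in `CachedGemmaAttention._can_use_flash_attention` (it omits the `attn_logits_soft_cap` signature check).

```python
# Illustrative use of the new device helpers. `choose_attention_path` is a
# simplified, hypothetical version of the gating added in gemma_attention.py.
from keras_hub.src.utils.keras_utils import has_flash_attention_support
from keras_hub.src.utils.keras_utils import running_on_tpu


def choose_attention_path(dropout=0.0, logit_soft_cap=None):
    """Return "fused" when fused attention is viable, otherwise "eager"."""
    if not has_flash_attention_support() or dropout > 0.0:
        return "eager"
    if logit_soft_cap is not None and not running_on_tpu():
        # Soft-capped fused attention is currently limited to TPU.
        return "eager"
    return "fused"
```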
