from keras_hub.src.models.gemma.rms_normalization import RMSNormalization
from keras_hub.src.utils.keras_utils import clone_initializer
from keras_hub.src.utils.keras_utils import fused_attention_op_available
+from keras_hub.src.utils.keras_utils import gpu_supports_fused_attention_op
+from keras_hub.src.utils.keras_utils import running_on_gpu
from keras_hub.src.utils.keras_utils import running_on_tpu


class CachedGemma3Attention(keras.layers.Layer):
    """A cached grouped query attention layer for Gemma3.

-    This is different from Gemma and Gemma2 in several ways:
+    This is the same as the attention layer used for Gemma and Gemma2. It
+    exposes a few additional args:

-    - `use_query_key_norm`: Applies RMS Norm on query, key.
-    - `rope_wavelength`: RoPE wavelength differs from local to global attention
-      layers.
-    - `rope_scaling_factor`: RoPE scaling factor differs from local to global
-      attention layers.
+    `use_query_key_norm`: bool. If True, applies RMS normalization on the
+        query and key. For Gemma3, this is True.
+    `rope_wavelength`: float. Configurable value for the RoPE wavelength.
+        Gemma3 uses 10K for local attention layers and 1M for global
+        attention layers.
+    `gate_dim_reduction`: int. In the gating layers, the output dimension is
+        `intermediate_dim // gate_dim_reduction`. For Gemma and Gemma2, this
+        value is 2. For Gemma3, it is 1.
+
+    Moreover, the `call()` method takes a `cache_update_mask` argument so
+    that the key-value cache is updated only for non-prompt tokens during
+    generation.
    """

    def __init__(
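To make the per-layer RoPE configuration in the docstring concrete, here is a small standalone sketch (plain NumPy, not keras_hub code) of how the two wavelengths change the rotary frequencies; the `rope_angles` name, shapes, and the standard RoPE formula used are assumptions for illustration.

```python
import numpy as np

# Illustrative sketch only (not keras_hub code): standard RoPE frequencies
# under the two wavelengths named in the docstring. Gemma3 uses 10K for
# local attention layers and 1M for global attention layers.
def rope_angles(position, head_dim, wavelength, scaling_factor=1.0):
    # One frequency per pair of channels, as in standard RoPE.
    freqs = 1.0 / wavelength ** (np.arange(0, head_dim, 2) / head_dim)
    return (position / scaling_factor) * freqs

local_angles = rope_angles(position=128, head_dim=256, wavelength=10_000.0)
global_angles = rope_angles(position=128, head_dim=256, wavelength=1_000_000.0)
```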
@@ -139,17 +148,22 @@ def _apply_rope(self, x, start_index):
        x = self.rope_layer(x, start_index=start_index)
        return x

-    def _can_use_flash_attention(self):
+    def _use_fused_attention_op(self):
        if not fused_attention_op_available():
            return False
        if self.dropout > 0.0:
            return False
-        if self.logit_soft_cap is None:
-            return True
-        sig = inspect.signature(ops.dot_product_attention)
-        # We can currently only run soft capped attention for keras >= 3.10
-        # and only on TPU.
-        return running_on_tpu() and "attn_logits_soft_cap" in sig.parameters
+        if running_on_gpu():
+            # GPU never supports softcap in the fused op.
+            if self.logit_soft_cap is not None:
+                return False
+            return gpu_supports_fused_attention_op()
+        elif running_on_tpu():
+            # TPU supports softcap only with keras >= 3.10.
+            sig = inspect.signature(ops.dot_product_attention)
+            return "attn_logits_soft_cap" in sig.parameters
+        else:
+            return False

    def _compute_attention(
        self,
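For reference, a hedged sketch of the same feature detection outside the layer. It assumes Keras 3's `keras.ops.dot_product_attention`, whose `attn_logits_soft_cap` parameter is only present on newer Keras versions (>= 3.10 per the comment above); `fused_attention` is a hypothetical helper name, not keras_hub API.

```python
import inspect

from keras import ops

# Sketch only: pass `attn_logits_soft_cap` to the fused op when the installed
# Keras exposes it; otherwise call the op without soft capping.
def fused_attention(query, key, value, mask=None, scale=None, soft_cap=None):
    kwargs = {}
    sig = inspect.signature(ops.dot_product_attention)
    if soft_cap is not None and "attn_logits_soft_cap" in sig.parameters:
        kwargs["attn_logits_soft_cap"] = soft_cap
    return ops.dot_product_attention(
        query, key, value, mask=mask, scale=scale, **kwargs
    )
```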
@@ -166,7 +180,14 @@ def _compute_attention(
        query_normalization = 1 / np.sqrt(
            self.hidden_dim // self.num_query_heads
        )
-        if self._can_use_flash_attention():
+
+        if self.use_sliding_window_attention and attention_mask is not None:
+            attention_mask = self._mask_sliding_window(
+                attention_mask,
+                cache_update_index=cache_update_index,
+            )
+
+        if self._use_fused_attention_op():
            if attention_mask is not None:
                attention_mask = ops.expand_dims(attention_mask, axis=1)
                attention_mask = ops.cast(attention_mask, dtype="bool")
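The hunk above moves sliding-window masking ahead of the fused/unfused branch so both paths see the restricted mask. As a rough illustration of what such a mask encodes (not the actual `_mask_sliding_window` implementation), here is a sketch; the function name, shapes, and window convention are assumptions.

```python
import numpy as np

# Illustration only: a causal mask restricted to a local window. The real
# `_mask_sliding_window` combines a constraint like this with the incoming
# padding/causal mask and may use a different window convention.
def sliding_window_mask(query_len, kv_len, window_size, query_start=0):
    q_pos = np.arange(query_start, query_start + query_len)[:, None]
    kv_pos = np.arange(kv_len)[None, :]
    causal = kv_pos <= q_pos
    within_window = (q_pos - kv_pos) < window_size
    return causal & within_window  # True where attention is allowed
```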
@@ -205,13 +226,8 @@ def _compute_attention(
                ops.tanh(attention_logits), self.logit_soft_cap
            )

-        if self.use_sliding_window_attention:
-            attention_mask = self._mask_sliding_window(
-                attention_mask,
-                cache_update_index=cache_update_index,
-            )
-
-        attention_mask = attention_mask[:, None, None, :, :]
+        if attention_mask is not None:
+            attention_mask = attention_mask[:, None, None, :, :]
        orig_dtype = attention_logits.dtype
        attention_softmax = self.softmax(attention_logits, mask=attention_mask)
        attention_softmax = ops.cast(attention_softmax, orig_dtype)
@@ -256,6 +272,7 @@ def call(
        attention_mask=None,
        cache=None,
        cache_update_index=0,
+        cache_update_mask=None,
        training=False,
    ):
        query = self.query_dense(x)
@@ -275,7 +292,43 @@ def call(
            key_update = self._apply_rope(key_update, cache_update_index)
            value_update = self.value_dense(x)
+
+            # Update the cache. Note that the cache is updated only where the
+            # corresponding `cache_update_mask` value is True. This ensures
+            # that we don't update the cache at indices corresponding to the
+            # prompt. For Gemma3 in particular, this matters because image
+            # tokens have bidirectional attention. With uneven inputs during
+            # generation, we might otherwise end up with causal attention
+            # between image tokens, which is incorrect. To avoid this,
+            # bidirectional attention is handled in the prefill step, and
+            # during generation the cache is not updated for the prompt. The
+            # shape of `cache_update_mask` is `(bsz, seq_len)`, where
+            # `seq_len` is 1 when we are generating token-by-token.
            start = [0, cache_update_index, 0, 0]
+            if cache_update_mask is not None:
+                cache_update_mask = ops.expand_dims(
+                    ops.expand_dims(cache_update_mask, axis=-1),
+                    axis=-1,
+                )
+                key_original = ops.slice(
+                    key_cache, start, ops.shape(key_update)
+                )
+                value_original = ops.slice(
+                    value_cache, start, ops.shape(value_update)
+                )
+
+                key_update = ops.where(
+                    cache_update_mask,
+                    key_update,
+                    key_original,
+                )
+                value_update = ops.where(
+                    cache_update_mask,
+                    value_update,
+                    value_original,
+                )
+
            key = ops.slice_update(key_cache, start, key_update)
            value = ops.slice_update(value_cache, start, value_update)
            cache = ops.stack((key, value), axis=1)
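To see the masked update in isolation, here is a self-contained sketch of the same pattern using Keras ops. The `masked_cache_update` name and the `(batch, seq_len, num_heads, head_dim)` cache layout are assumptions for illustration, not keras_hub API; the where/slice_update logic mirrors the diff above.

```python
from keras import ops

# Sketch: write `update` into `cache` starting at `cache_update_index` along
# the sequence axis, but keep the previously cached values wherever
# `cache_update_mask` is False (e.g. prompt positions during generation).
# Assumed shapes: cache (batch, max_len, heads, dim), update
# (batch, update_len, heads, dim), cache_update_mask (batch, update_len).
def masked_cache_update(cache, update, cache_update_index, cache_update_mask):
    start = [0, cache_update_index, 0, 0]
    if cache_update_mask is not None:
        # Broadcast the (batch, update_len) mask over heads and head dim.
        mask = ops.expand_dims(ops.expand_dims(cache_update_mask, -1), -1)
        original = ops.slice(cache, start, ops.shape(update))
        update = ops.where(mask, update, original)
    return ops.slice_update(cache, start, update)
```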