
Commit 7bd7b57

Add vision for Gemma3 (keras-team#2170)
1 parent 87fca10 commit 7bd7b57

19 files changed: +1580 -900 lines changed

keras_hub/api/models/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -183,6 +183,9 @@
     Gemma3CausalLMPreprocessor,
 )
 from keras_hub.src.models.gemma3.gemma3_tokenizer import Gemma3Tokenizer
+from keras_hub.src.models.gemma3.gemma3_vision_encoder import (
+    Gemma3VisionEncoder,
+)
 from keras_hub.src.models.gpt2.gpt2_backbone import GPT2Backbone
 from keras_hub.src.models.gpt2.gpt2_causal_lm import GPT2CausalLM
 from keras_hub.src.models.gpt2.gpt2_causal_lm_preprocessor import (
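With this re-export in place, the vision encoder becomes part of the public API surface. A minimal sketch, assuming the generated `keras_hub.models` namespace mirrors `keras_hub/api/models/__init__.py` the way it does for the other exported symbols:

# Hypothetical usage: import from the public namespace rather than the
# private `keras_hub.src` path.
from keras_hub.models import Gemma3VisionEncoder

print(Gemma3VisionEncoder.__name__)  # "Gemma3VisionEncoder"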

keras_hub/src/models/gemma3/gemma3_attention.py

Lines changed: 74 additions & 21 deletions
@@ -8,19 +8,28 @@
 from keras_hub.src.models.gemma.rms_normalization import RMSNormalization
 from keras_hub.src.utils.keras_utils import clone_initializer
 from keras_hub.src.utils.keras_utils import fused_attention_op_available
+from keras_hub.src.utils.keras_utils import gpu_supports_fused_attention_op
+from keras_hub.src.utils.keras_utils import running_on_gpu
 from keras_hub.src.utils.keras_utils import running_on_tpu
 
 
 class CachedGemma3Attention(keras.layers.Layer):
     """A cached grouped query attention layer for Gemma3.
 
-    This is different from Gemma and Gemma2 in several ways:
+    This is the same as the attention layer used for Gemma and Gemma2. It
+    exposes a few additional args:
 
-    - `use_query_key_norm`: Applies RMS Norm on query, key.
-    - `rope_wavelength`: RoPE wavelength differs from local to global attention
-      layers.
-    - `rope_scaling_factor`: RoPE scaling factor differs from local to global
-      attention layers.
+    `use_query_key_norm`: bool. If True, apply RMS normalization on query
+        and key. For Gemma3, this is True.
+    `rope_wavelength`: float. Configurable value for RoPE wavelength. Gemma3
+        uses 10K for local attention layers and 1M for global attention layers.
+    `gate_dim_reduction`: int. In the gating layers, the output dimension is
+        `intermediate_dim // gate_dim_reduction`. For Gemma and Gemma2, this
+        value is 2. For Gemma3, it is 1.
+
+    Moreover, the call() method takes in a `cache_update_mask` so as to make
+    sure that the key-value cache is updated only for the non-prompt tokens
+    during generation.
     """
 
     def __init__(
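To make the updated docstring concrete, here is a small illustrative sketch of the `gate_dim_reduction` arithmetic and the RoPE wavelengths it describes (the `intermediate_dim` value below is an example, not a shipped Gemma3 config):

# Gating output dimension, as described in the docstring.
intermediate_dim = 8192
gemma2_gate_output_dim = intermediate_dim // 2  # gate_dim_reduction = 2
gemma3_gate_output_dim = intermediate_dim // 1  # gate_dim_reduction = 1

# RoPE wavelengths used by Gemma3, per the docstring.
local_rope_wavelength = 10_000      # local (sliding-window) attention layers
global_rope_wavelength = 1_000_000  # global attention layers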
@@ -139,17 +148,22 @@ def _apply_rope(self, x, start_index):
         x = self.rope_layer(x, start_index=start_index)
         return x
 
-    def _can_use_flash_attention(self):
+    def _use_fused_attention_op(self):
         if not fused_attention_op_available():
             return False
         if self.dropout > 0.0:
             return False
-        if self.logit_soft_cap is None:
-            return True
-        sig = inspect.signature(ops.dot_product_attention)
-        # We can currently only run soft capped attention for keras >= 3.10
-        # and only on TPU.
-        return running_on_tpu() and "attn_logits_soft_cap" in sig.parameters
+        if running_on_gpu():
+            # GPU never supports softcap in the fused op.
+            if self.logit_soft_cap is not None:
+                return False
+            return gpu_supports_fused_attention_op()
+        elif running_on_tpu():
+            # TPU supports softcap on keras >= 3.10.
+            sig = inspect.signature(ops.dot_product_attention)
+            return "attn_logits_soft_cap" in sig.parameters
+        else:
+            return False
 
     def _compute_attention(
         self,
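The renamed `_use_fused_attention_op` now branches on the accelerator. Below is a dependency-free sketch of the same decision tree; the boolean parameters stand in for the KerasHub utility checks (`fused_attention_op_available`, `running_on_gpu`, `gpu_supports_fused_attention_op`, `running_on_tpu`) and the `inspect.signature` probe, so this is an illustration rather than the layer method itself:

def can_use_fused_attention(
    op_available,            # fused_attention_op_available()
    dropout,                 # self.dropout
    logit_soft_cap,          # self.logit_soft_cap
    on_gpu,                  # running_on_gpu()
    gpu_fused_ok,            # gpu_supports_fused_attention_op()
    on_tpu,                  # running_on_tpu()
    keras_has_soft_cap_arg,  # "attn_logits_soft_cap" in the op signature
):
    if not op_available or dropout > 0.0:
        return False
    if on_gpu:
        # The fused GPU path never supports logit soft capping.
        return logit_soft_cap is None and gpu_fused_ok
    if on_tpu:
        # On TPU, require the `attn_logits_soft_cap` argument (Keras >= 3.10).
        return keras_has_soft_cap_arg
    return False

# Example: a GPU run with soft capping enabled falls back to the unfused path.
assert can_use_fused_attention(True, 0.0, 30.0, True, True, False, False) is False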
@@ -166,7 +180,14 @@ def _compute_attention(
         query_normalization = 1 / np.sqrt(
             self.hidden_dim // self.num_query_heads
         )
-        if self._can_use_flash_attention():
+
+        if self.use_sliding_window_attention and attention_mask is not None:
+            attention_mask = self._mask_sliding_window(
+                attention_mask,
+                cache_update_index=cache_update_index,
+            )
+
+        if self._use_fused_attention_op():
             if attention_mask is not None:
                 attention_mask = ops.expand_dims(attention_mask, axis=1)
                 attention_mask = ops.cast(attention_mask, dtype="bool")
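This hunk moves sliding-window masking ahead of the fused-attention branch so both the fused and unfused code paths see the banded mask. As a rough, generic illustration of what restricting a mask to a sliding window does (not the exact `_mask_sliding_window` logic):

import numpy as np

def apply_sliding_window(mask, window):
    # Keep only positions whose query/key distance is below `window`.
    q = np.arange(mask.shape[-2])[:, None]
    k = np.arange(mask.shape[-1])[None, :]
    return mask & (np.abs(q - k) < window)

causal = np.tril(np.ones((6, 6), dtype=bool))
print(apply_sliding_window(causal, window=3).astype(int))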
@@ -205,13 +226,8 @@
                 ops.tanh(attention_logits), self.logit_soft_cap
             )
 
-        if self.use_sliding_window_attention:
-            attention_mask = self._mask_sliding_window(
-                attention_mask,
-                cache_update_index=cache_update_index,
-            )
-
-        attention_mask = attention_mask[:, None, None, :, :]
+        if attention_mask is not None:
+            attention_mask = attention_mask[:, None, None, :, :]
         orig_dtype = attention_logits.dtype
         attention_softmax = self.softmax(attention_logits, mask=attention_mask)
         attention_softmax = ops.cast(attention_softmax, orig_dtype)
@@ -256,6 +272,7 @@ def call(
         attention_mask=None,
         cache=None,
         cache_update_index=0,
+        cache_update_mask=None,
         training=False,
     ):
         query = self.query_dense(x)
@@ -275,7 +292,43 @@
 
             key_update = self._apply_rope(key_update, cache_update_index)
             value_update = self.value_dense(x)
+
+            # Update cache. Note that the cache is updated only if the
+            # corresponding `cache_update_mask` value is True. This is to
+            # ensure that we don't update the cache at indices corresponding
+            # to the prompt. For Gemma3, in particular, this is useful because
+            # image tokens have bidirectional attention. If we have uneven
+            # inputs during generation, we might end up having causal
+            # attention between image tokens, which is incorrect. To avoid
+            # this, bidirectional attention is taken care of during the
+            # prefill step, and during generation, the cache is not updated
+            # for the prompt. The shape of `cache_update_mask` is
+            # `(bsz, seq_len)`, where `seq_len` is 1 when we are generating
+            # token-by-token.
             start = [0, cache_update_index, 0, 0]
+            if cache_update_mask is not None:
+                cache_update_mask = ops.expand_dims(
+                    ops.expand_dims(cache_update_mask, axis=-1),
+                    axis=-1,
+                )
+                key_original = ops.slice(
+                    key_cache, start, ops.shape(key_update)
+                )
+                value_original = ops.slice(
+                    value_cache, start, ops.shape(value_update)
+                )
+
+                key_update = ops.where(
+                    cache_update_mask,
+                    key_update,
+                    key_original,
+                )
+                value_update = ops.where(
+                    cache_update_mask,
+                    value_update,
+                    value_original,
+                )
+
             key = ops.slice_update(key_cache, start, key_update)
             value = ops.slice_update(value_cache, start, value_update)
             cache = ops.stack((key, value), axis=1)
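The effect of the masked update can be seen in a toy `keras.ops` sketch (illustrative shapes; in the layer, the key/value cache slices are roughly `(bsz, seq_len, num_key_value_heads, head_dim)`): positions whose `cache_update_mask` entry is False keep their previously cached values.

from keras import ops

# Toy demonstration of the masked cache update above.
bsz, seq_len, num_heads, head_dim = 1, 4, 1, 2
key_cache = ops.zeros((bsz, seq_len, num_heads, head_dim))
key_update = ops.ones((bsz, seq_len, num_heads, head_dim))

# True only for the last two (non-prompt) positions.
cache_update_mask = ops.convert_to_tensor([[False, False, True, True]])
mask = ops.expand_dims(ops.expand_dims(cache_update_mask, axis=-1), axis=-1)

start = [0, 0, 0, 0]
key_original = ops.slice(key_cache, start, ops.shape(key_update))
key_update = ops.where(mask, key_update, key_original)
key = ops.slice_update(key_cache, start, key_update)

# Prompt positions stay 0, generated positions become 1: [0. 0. 1. 1.]
print(ops.convert_to_numpy(key)[0, :, 0, 0])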
