
Commit d4e632d

Add Gemma3 [Text] (#2152)
1 parent d82fbec commit d4e632d

25 files changed: +3,899 -0 lines

keras_hub/api/layers/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -54,6 +54,9 @@
 from keras_hub.src.models.efficientnet.efficientnet_image_converter import (
     EfficientNetImageConverter,
 )
+from keras_hub.src.models.gemma3.gemma3_image_converter import (
+    Gemma3ImageConverter,
+)
 from keras_hub.src.models.mit.mit_image_converter import MiTImageConverter
 from keras_hub.src.models.mobilenet.mobilenet_image_converter import (
     MobileNetImageConverter,

keras_hub/api/models/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -177,6 +177,12 @@
     GemmaCausalLMPreprocessor,
 )
 from keras_hub.src.models.gemma.gemma_tokenizer import GemmaTokenizer
+from keras_hub.src.models.gemma3.gemma3_backbone import Gemma3Backbone
+from keras_hub.src.models.gemma3.gemma3_causal_lm import Gemma3CausalLM
+from keras_hub.src.models.gemma3.gemma3_causal_lm_preprocessor import (
+    Gemma3CausalLMPreprocessor,
+)
+from keras_hub.src.models.gemma3.gemma3_tokenizer import Gemma3Tokenizer
 from keras_hub.src.models.gpt2.gpt2_backbone import GPT2Backbone
 from keras_hub.src.models.gpt2.gpt2_causal_lm import GPT2CausalLM
 from keras_hub.src.models.gpt2.gpt2_causal_lm_preprocessor import (
keras_hub/api/tokenizers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@
 from keras_hub.src.models.f_net.f_net_tokenizer import FNetTokenizer
 from keras_hub.src.models.falcon.falcon_tokenizer import FalconTokenizer
 from keras_hub.src.models.gemma.gemma_tokenizer import GemmaTokenizer
+from keras_hub.src.models.gemma3.gemma3_tokenizer import Gemma3Tokenizer
 from keras_hub.src.models.gpt2.gpt2_tokenizer import GPT2Tokenizer
 from keras_hub.src.models.gpt_neo_x.gpt_neo_x_tokenizer import GPTNeoXTokenizer
 from keras_hub.src.models.llama.llama_tokenizer import LlamaTokenizer
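Taken together, the three api/*/__init__.py changes surface the new classes on the public namespaces (keras_hub/api/* backs the top-level keras_hub package). A quick sanity-check sketch, not part of this commit:

    import keras_hub

    # Public symbols added by this commit.
    assert hasattr(keras_hub.layers, "Gemma3ImageConverter")
    assert hasattr(keras_hub.models, "Gemma3Backbone")
    assert hasattr(keras_hub.models, "Gemma3CausalLM")
    assert hasattr(keras_hub.models, "Gemma3CausalLMPreprocessor")
    assert hasattr(keras_hub.models, "Gemma3Tokenizer")
    assert hasattr(keras_hub.tokenizers, "Gemma3Tokenizer")
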
keras_hub/src/models/gemma3/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
from keras_hub.src.models.gemma3.gemma3_backbone import Gemma3Backbone
from keras_hub.src.models.gemma3.gemma3_presets import backbone_presets
from keras_hub.src.utils.preset_utils import register_presets

register_presets(backbone_presets, Gemma3Backbone)
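register_presets attaches the entries of backbone_presets to Gemma3Backbone, which is what lets from_preset() on the backbone, tokenizer, and task classes resolve Gemma3 preset names. A small inspection sketch (not part of this commit; assumes the standard Backbone.presets property):

    import keras_hub

    # Preset names come from gemma3_presets.py (not shown in this diff).
    print(list(keras_hub.models.Gemma3Backbone.presets.keys()))
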
keras_hub/src/models/gemma3/gemma3_attention.py

Lines changed: 315 additions & 0 deletions
@@ -0,0 +1,315 @@
import inspect

import keras
import numpy as np
from keras import ops

from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
from keras_hub.src.models.gemma.rms_normalization import RMSNormalization
from keras_hub.src.utils.keras_utils import clone_initializer
from keras_hub.src.utils.keras_utils import has_flash_attention_support
from keras_hub.src.utils.keras_utils import running_on_tpu


class CachedGemma3Attention(keras.layers.Layer):
    """A cached grouped query attention layer for Gemma3.

    This is different from Gemma and Gemma2 in several ways:

    - `use_query_key_norm`: Applies RMS Norm on query, key.
    - `rope_wavelength`: RoPE wavelength differs from local to global attention
      layers.
    - `rope_scaling_factor`: RoPE scaling factor differs from local to global
      attention layers.
    """

    def __init__(
        self,
        head_dim,
        num_query_heads,
        num_key_value_heads,
        kernel_initializer="glorot_uniform",
        logit_soft_cap=None,
        use_sliding_window_attention=False,
        sliding_window_size=4096,
        query_head_dim_normalize=True,
        use_query_key_norm=False,
        layer_norm_epsilon=1e-6,
        rope_wavelength=10_000.0,
        rope_scaling_factor=1.0,
        dropout=0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.num_query_heads = num_query_heads
        self.num_key_value_heads = num_key_value_heads
        self.head_dim = head_dim
        self.logit_soft_cap = logit_soft_cap
        self.use_sliding_window_attention = use_sliding_window_attention
        self.sliding_window_size = sliding_window_size
        self.query_head_dim_normalize = query_head_dim_normalize
        self.use_query_key_norm = use_query_key_norm
        self.layer_norm_epsilon = layer_norm_epsilon
        self.rope_wavelength = rope_wavelength
        self.rope_scaling_factor = rope_scaling_factor
        self.dropout = dropout

        self._kernel_initializer = keras.initializers.get(
            clone_initializer(kernel_initializer)
        )
        self.num_key_value_groups = num_query_heads // num_key_value_heads
        self.query_head_dim_normalize = query_head_dim_normalize

    def build(self, inputs_shape):
        self.hidden_dim = inputs_shape[-1]

        self.query_dense = keras.layers.EinsumDense(
            "btd,ndh->btnh",
            output_shape=(None, self.num_query_heads, self.head_dim),
            kernel_initializer=self._kernel_initializer,
            dtype=self.dtype_policy,
            name="query",
        )
        self.query_dense.build(inputs_shape)

        self.key_dense = keras.layers.EinsumDense(
            "bsd,kdh->bskh",
            output_shape=(None, self.num_key_value_heads, self.head_dim),
            kernel_initializer=self._kernel_initializer,
            dtype=self.dtype_policy,
            name="key",
        )
        self.key_dense.build(inputs_shape)

        self.value_dense = keras.layers.EinsumDense(
            "bsd,kdh->bskh",
            output_shape=(None, self.num_key_value_heads, self.head_dim),
            kernel_initializer=self._kernel_initializer,
            dtype=self.dtype_policy,
            name="value",
        )
        self.value_dense.build(inputs_shape)

        if self.use_query_key_norm:
            self.query_norm = RMSNormalization(
                epsilon=self.layer_norm_epsilon,
                dtype=self.dtype_policy,
                name="query_norm",
            )
            self.query_norm.build(
                self.query_dense.compute_output_shape(inputs_shape)
            )

            self.key_norm = RMSNormalization(
                epsilon=self.layer_norm_epsilon,
                dtype=self.dtype_policy,
                name="key_norm",
            )
            self.key_norm.build(
                self.key_dense.compute_output_shape(inputs_shape)
            )

        self.dropout_layer = keras.layers.Dropout(
            rate=self.dropout,
            dtype=self.dtype_policy,
        )

        self.output_dense = keras.layers.EinsumDense(
            equation="btnh,nhd->btd",
            output_shape=(None, self.hidden_dim),
            kernel_initializer=self._kernel_initializer,
            dtype=self.dtype_policy,
            name="attention_output",
        )
        self.output_dense.build(
            (None, None, self.num_query_heads, self.head_dim)
        )
        self.softmax = keras.layers.Softmax(dtype="float32")

        self.rope_layer = RotaryEmbedding(
            max_wavelength=self.rope_wavelength,
            scaling_factor=self.rope_scaling_factor,
            dtype=self.dtype_policy,
        )

        self.built = True

    def _apply_rope(self, x, start_index):
        """Rope rotate q or k."""
        x = self.rope_layer(x, start_index=start_index)
        return x

    def _can_use_flash_attention(self):
        if not has_flash_attention_support():
            return False
        if self.dropout > 0.0:
            return False
        if self.logit_soft_cap is None:
            return True
        sig = inspect.signature(ops.dot_product_attention)
        # We can currently only run soft capped attention for keras >= 3.10
        # and only on TPU.
        return running_on_tpu() and "attn_logits_soft_cap" in sig.parameters

    def _compute_attention(
        self,
        q,
        k,
        v,
        attention_mask,
        training=False,
        cache_update_index=0,
    ):
        if self.query_head_dim_normalize:
            query_normalization = 1 / np.sqrt(self.head_dim)
        else:
            query_normalization = 1 / np.sqrt(
                self.hidden_dim // self.num_query_heads
            )
        if self._can_use_flash_attention():
            if attention_mask is not None:
                attention_mask = ops.expand_dims(attention_mask, axis=1)
                attention_mask = ops.cast(attention_mask, dtype="bool")
            # Only pass soft cap if needed as not all keras versions support.
            if self.logit_soft_cap:
                kwargs = {"attn_logits_soft_cap": self.logit_soft_cap}
            else:
                kwargs = {}
            return ops.dot_product_attention(
                query=q,
                key=k,
                value=v,
                mask=attention_mask,
                scale=query_normalization,
                **kwargs,
            )

        q *= ops.cast(query_normalization, dtype=q.dtype)
        q_shape = ops.shape(q)
        q = ops.reshape(
            q,
            (
                *q_shape[:-2],
                self.num_key_value_heads,
                self.num_query_heads // self.num_key_value_heads,
                q_shape[-1],
            ),
        )
        b, q_len, _, _, h = ops.shape(q)

        # Fallback to standard attention if flash attention is disabled
        attention_logits = ops.einsum("btkgh,bskh->bkgts", q, k)
        if self.logit_soft_cap is not None:
            attention_logits = ops.divide(attention_logits, self.logit_soft_cap)
            attention_logits = ops.multiply(
                ops.tanh(attention_logits), self.logit_soft_cap
            )

        if self.use_sliding_window_attention:
            attention_mask = self._mask_sliding_window(
                attention_mask,
                cache_update_index=cache_update_index,
            )

        attention_mask = attention_mask[:, None, None, :, :]
        orig_dtype = attention_logits.dtype
        attention_softmax = self.softmax(attention_logits, mask=attention_mask)
        attention_softmax = ops.cast(attention_softmax, orig_dtype)

        if self.dropout:
            attention_softmax = self.dropout_layer(
                attention_softmax, training=training
            )

        results = ops.einsum("bkgts,bskh->btkgh", attention_softmax, v)
        return ops.reshape(results, (b, q_len, self.num_query_heads, h))

    def _mask_sliding_window(
        self,
        attention_mask,
        cache_update_index=0,
    ):
        batch_size, query_len, key_len = ops.shape(attention_mask)
        # Compute the sliding window for square attention.
        all_ones = ops.ones((key_len, key_len), "bool")
        if keras.config.backend() == "tensorflow":
            # TODO: trui/tril has issues with dynamic shape on the tensorflow
            # backend. We should fix, but use `band_part` for now.
            import tensorflow as tf

            band_size = ops.minimum(key_len, self.sliding_window_size - 1)
            band_size = ops.cast(band_size, "int32")
            sliding_mask = tf.linalg.band_part(all_ones, band_size, band_size)
        else:
            sliding_mask = ops.triu(
                all_ones, -1 * self.sliding_window_size + 1
            ) * ops.tril(all_ones, self.sliding_window_size - 1)
        # Slice the window for short queries during generation.
        start = (cache_update_index, 0)
        sliding_mask = ops.slice(sliding_mask, start, (query_len, key_len))
        sliding_mask = ops.expand_dims(sliding_mask, 0)
        return ops.logical_and(attention_mask, ops.cast(sliding_mask, "bool"))

    def call(
        self,
        x,
        attention_mask=None,
        cache=None,
        cache_update_index=0,
        training=False,
    ):
        query = self.query_dense(x)

        if self.use_query_key_norm:
            query = self.query_norm(query)

        query = self._apply_rope(query, cache_update_index)

        if cache is not None:
            key_cache = cache[:, 0, ...]
            value_cache = cache[:, 1, ...]
            key_update = self.key_dense(x)

            if self.use_query_key_norm:
                key_update = self.key_norm(key_update)

            key_update = self._apply_rope(key_update, cache_update_index)
            value_update = self.value_dense(x)
            start = [0, cache_update_index, 0, 0]
            key = ops.slice_update(key_cache, start, key_update)
            value = ops.slice_update(value_cache, start, value_update)
            cache = ops.stack((key, value), axis=1)
        else:
            key = self.key_dense(x)

            if self.use_query_key_norm:
                key = self.key_norm(key)

            key = self._apply_rope(key, cache_update_index)
            value = self.value_dense(x)

        attention_vec = self._compute_attention(
            query,
            key,
            value,
            attention_mask,
            training=training,
            cache_update_index=cache_update_index,
        )

        # Wipe attn vec if there are no attended tokens.
        no_attended_tokens = ops.all(
            ops.equal(attention_mask, 0), axis=-1, keepdims=True
        )[..., None]
        attention_vec = ops.where(
            no_attended_tokens, ops.zeros_like(attention_vec), attention_vec
        )

        attention_output = self.output_dense(attention_vec)

        if cache is not None:
            return attention_output, cache
        return attention_output

    def compute_output_shape(self, input_shape):
        return input_shape
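
The layer can be exercised on its own. Below is a minimal smoke-test sketch (not part of this commit; the sizes are arbitrary): a full-sequence forward pass with a causal mask, followed by a single cached decode step. Per the code above, the attention mask has shape (batch, query_len, key_len) and the key/value cache has shape (batch, 2, max_length, num_key_value_heads, head_dim).

    import numpy as np

    from keras_hub.src.models.gemma3.gemma3_attention import (
        CachedGemma3Attention,
    )

    batch, seq_len, hidden_dim = 1, 8, 64
    layer = CachedGemma3Attention(
        head_dim=16,
        num_query_heads=4,
        num_key_value_heads=2,  # grouped-query attention: 2 query heads per KV head
        use_query_key_norm=True,
        use_sliding_window_attention=True,
        sliding_window_size=4,  # tiny window so the band mask matters at seq_len=8
    )

    x = np.random.uniform(size=(batch, seq_len, hidden_dim)).astype("float32")
    # Causal mask: position t may attend to keys s <= t.
    causal_mask = np.tril(np.ones((batch, seq_len, seq_len), dtype=bool))
    y = layer(x, attention_mask=causal_mask)
    print(y.shape)  # (1, 8, 64)

    # Cached generation: prefill the first seq_len slots, then decode one token.
    max_length = 16
    # (batch, 2, max_length, num_key_value_heads, head_dim)
    cache = np.zeros((batch, 2, max_length, 2, 16), dtype="float32")
    prefill_mask = np.tril(np.ones((batch, seq_len, max_length), dtype=bool))
    y, cache = layer(
        x, attention_mask=prefill_mask, cache=cache, cache_update_index=0
    )

    next_token = np.random.uniform(size=(batch, 1, hidden_dim)).astype("float32")
    step_mask = np.zeros((batch, 1, max_length), dtype=bool)
    step_mask[:, :, : seq_len + 1] = True  # attend to the prefix and the new token
    step_out, cache = layer(
        next_token,
        attention_mask=step_mask,
        cache=cache,
        cache_update_index=seq_len,
    )
    print(step_out.shape)  # (1, 1, 64)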
