|
| 1 | +import inspect |
| 2 | + |
| 3 | +import keras |
| 4 | +import numpy as np |
| 5 | +from keras import ops |
| 6 | + |
| 7 | +from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding |
| 8 | +from keras_hub.src.models.gemma.rms_normalization import RMSNormalization |
| 9 | +from keras_hub.src.utils.keras_utils import clone_initializer |
| 10 | +from keras_hub.src.utils.keras_utils import has_flash_attention_support |
| 11 | +from keras_hub.src.utils.keras_utils import running_on_tpu |
| 12 | + |
| 13 | + |
| 14 | +class CachedGemma3Attention(keras.layers.Layer): |
| 15 | + """A cached grouped query attention layer for Gemma3. |
| 16 | +
|
| 17 | + This is different from Gemma and Gemma2 in several ways: |
| 18 | +
|
| 19 | + - `use_query_key_norm`: Applies RMS Norm on query, key. |
| 20 | + - `rope_wavelength`: RoPE wavelength differs from local to global attention |
| 21 | + layers. |
| 22 | + - `rope_scaling_factor`: RoPE scaling factor differs from local to global |
| 23 | + attention layers. |
| 24 | + """ |
| 25 | + |
| 26 | + def __init__( |
| 27 | + self, |
| 28 | + head_dim, |
| 29 | + num_query_heads, |
| 30 | + num_key_value_heads, |
| 31 | + kernel_initializer="glorot_uniform", |
| 32 | + logit_soft_cap=None, |
| 33 | + use_sliding_window_attention=False, |
| 34 | + sliding_window_size=4096, |
| 35 | + query_head_dim_normalize=True, |
| 36 | + use_query_key_norm=False, |
| 37 | + layer_norm_epsilon=1e-6, |
| 38 | + rope_wavelength=10_000.0, |
| 39 | + rope_scaling_factor=1.0, |
| 40 | + dropout=0, |
| 41 | + **kwargs, |
| 42 | + ): |
| 43 | + super().__init__(**kwargs) |
| 44 | + self.num_query_heads = num_query_heads |
| 45 | + self.num_key_value_heads = num_key_value_heads |
| 46 | + self.head_dim = head_dim |
| 47 | + self.logit_soft_cap = logit_soft_cap |
| 48 | + self.use_sliding_window_attention = use_sliding_window_attention |
| 49 | + self.sliding_window_size = sliding_window_size |
| 50 | + self.query_head_dim_normalize = query_head_dim_normalize |
| 51 | + self.use_query_key_norm = use_query_key_norm |
| 52 | + self.layer_norm_epsilon = layer_norm_epsilon |
| 53 | + self.rope_wavelength = rope_wavelength |
| 54 | + self.rope_scaling_factor = rope_scaling_factor |
| 55 | + self.dropout = dropout |
| 56 | + |
| 57 | + self._kernel_initializer = keras.initializers.get( |
| 58 | + clone_initializer(kernel_initializer) |
| 59 | + ) |
| 60 | + self.num_key_value_groups = num_query_heads // num_key_value_heads |
| 61 | + self.query_head_dim_normalize = query_head_dim_normalize |
| 62 | + |
| 63 | + def build(self, inputs_shape): |
| 64 | + self.hidden_dim = inputs_shape[-1] |
| 65 | + |
| 66 | + self.query_dense = keras.layers.EinsumDense( |
| 67 | + "btd,ndh->btnh", |
| 68 | + output_shape=(None, self.num_query_heads, self.head_dim), |
| 69 | + kernel_initializer=self._kernel_initializer, |
| 70 | + dtype=self.dtype_policy, |
| 71 | + name="query", |
| 72 | + ) |
| 73 | + self.query_dense.build(inputs_shape) |
| 74 | + |
| 75 | + self.key_dense = keras.layers.EinsumDense( |
| 76 | + "bsd,kdh->bskh", |
| 77 | + output_shape=(None, self.num_key_value_heads, self.head_dim), |
| 78 | + kernel_initializer=self._kernel_initializer, |
| 79 | + dtype=self.dtype_policy, |
| 80 | + name="key", |
| 81 | + ) |
| 82 | + self.key_dense.build(inputs_shape) |
| 83 | + |
| 84 | + self.value_dense = keras.layers.EinsumDense( |
| 85 | + "bsd,kdh->bskh", |
| 86 | + output_shape=(None, self.num_key_value_heads, self.head_dim), |
| 87 | + kernel_initializer=self._kernel_initializer, |
| 88 | + dtype=self.dtype_policy, |
| 89 | + name="value", |
| 90 | + ) |
| 91 | + self.value_dense.build(inputs_shape) |
| 92 | + |
| 93 | + if self.use_query_key_norm: |
| 94 | + self.query_norm = RMSNormalization( |
| 95 | + epsilon=self.layer_norm_epsilon, |
| 96 | + dtype=self.dtype_policy, |
| 97 | + name="query_norm", |
| 98 | + ) |
| 99 | + self.query_norm.build( |
| 100 | + self.query_dense.compute_output_shape(inputs_shape) |
| 101 | + ) |
| 102 | + |
| 103 | + self.key_norm = RMSNormalization( |
| 104 | + epsilon=self.layer_norm_epsilon, |
| 105 | + dtype=self.dtype_policy, |
| 106 | + name="key_norm", |
| 107 | + ) |
| 108 | + self.key_norm.build( |
| 109 | + self.key_dense.compute_output_shape(inputs_shape) |
| 110 | + ) |
| 111 | + |
| 112 | + self.dropout_layer = keras.layers.Dropout( |
| 113 | + rate=self.dropout, |
| 114 | + dtype=self.dtype_policy, |
| 115 | + ) |
| 116 | + |
| 117 | + self.output_dense = keras.layers.EinsumDense( |
| 118 | + equation="btnh,nhd->btd", |
| 119 | + output_shape=(None, self.hidden_dim), |
| 120 | + kernel_initializer=self._kernel_initializer, |
| 121 | + dtype=self.dtype_policy, |
| 122 | + name="attention_output", |
| 123 | + ) |
| 124 | + self.output_dense.build( |
| 125 | + (None, None, self.num_query_heads, self.head_dim) |
| 126 | + ) |
| 127 | + self.softmax = keras.layers.Softmax(dtype="float32") |
| 128 | + |
| 129 | + self.rope_layer = RotaryEmbedding( |
| 130 | + max_wavelength=self.rope_wavelength, |
| 131 | + scaling_factor=self.rope_scaling_factor, |
| 132 | + dtype=self.dtype_policy, |
| 133 | + ) |
| 134 | + |
| 135 | + self.built = True |
| 136 | + |
| 137 | + def _apply_rope(self, x, start_index): |
| 138 | + """Rope rotate q or k.""" |
| 139 | + x = self.rope_layer(x, start_index=start_index) |
| 140 | + return x |
| 141 | + |
| 142 | + def _can_use_flash_attention(self): |
| 143 | + if not has_flash_attention_support(): |
| 144 | + return False |
| 145 | + if self.dropout > 0.0: |
| 146 | + return False |
| 147 | + if self.logit_soft_cap is None: |
| 148 | + return True |
| 149 | + sig = inspect.signature(ops.dot_product_attention) |
| 150 | + # We can currently only run soft capped attention for keras >= 3.10 |
| 151 | + # and only on TPU. |
| 152 | + return running_on_tpu() and "attn_logits_soft_cap" in sig.parameters |
| 153 | + |
| 154 | + def _compute_attention( |
| 155 | + self, |
| 156 | + q, |
| 157 | + k, |
| 158 | + v, |
| 159 | + attention_mask, |
| 160 | + training=False, |
| 161 | + cache_update_index=0, |
| 162 | + ): |
| 163 | + if self.query_head_dim_normalize: |
| 164 | + query_normalization = 1 / np.sqrt(self.head_dim) |
| 165 | + else: |
| 166 | + query_normalization = 1 / np.sqrt( |
| 167 | + self.hidden_dim // self.num_query_heads |
| 168 | + ) |
| 169 | + if self._can_use_flash_attention(): |
| 170 | + if attention_mask is not None: |
| 171 | + attention_mask = ops.expand_dims(attention_mask, axis=1) |
| 172 | + attention_mask = ops.cast(attention_mask, dtype="bool") |
| 173 | + # Only pass soft cap if needed as not all keras versions support. |
| 174 | + if self.logit_soft_cap: |
| 175 | + kwargs = {"attn_logits_soft_cap": self.logit_soft_cap} |
| 176 | + else: |
| 177 | + kwargs = {} |
| 178 | + return ops.dot_product_attention( |
| 179 | + query=q, |
| 180 | + key=k, |
| 181 | + value=v, |
| 182 | + mask=attention_mask, |
| 183 | + scale=query_normalization, |
| 184 | + **kwargs, |
| 185 | + ) |
| 186 | + |
| 187 | + q *= ops.cast(query_normalization, dtype=q.dtype) |
| 188 | + q_shape = ops.shape(q) |
| 189 | + q = ops.reshape( |
| 190 | + q, |
| 191 | + ( |
| 192 | + *q_shape[:-2], |
| 193 | + self.num_key_value_heads, |
| 194 | + self.num_query_heads // self.num_key_value_heads, |
| 195 | + q_shape[-1], |
| 196 | + ), |
| 197 | + ) |
| 198 | + b, q_len, _, _, h = ops.shape(q) |
| 199 | + |
| 200 | + # Fallback to standard attention if flash attention is disabled |
| 201 | + attention_logits = ops.einsum("btkgh,bskh->bkgts", q, k) |
| 202 | + if self.logit_soft_cap is not None: |
| 203 | + attention_logits = ops.divide(attention_logits, self.logit_soft_cap) |
| 204 | + attention_logits = ops.multiply( |
| 205 | + ops.tanh(attention_logits), self.logit_soft_cap |
| 206 | + ) |
| 207 | + |
| 208 | + if self.use_sliding_window_attention: |
| 209 | + attention_mask = self._mask_sliding_window( |
| 210 | + attention_mask, |
| 211 | + cache_update_index=cache_update_index, |
| 212 | + ) |
| 213 | + |
| 214 | + attention_mask = attention_mask[:, None, None, :, :] |
| 215 | + orig_dtype = attention_logits.dtype |
| 216 | + attention_softmax = self.softmax(attention_logits, mask=attention_mask) |
| 217 | + attention_softmax = ops.cast(attention_softmax, orig_dtype) |
| 218 | + |
| 219 | + if self.dropout: |
| 220 | + attention_softmax = self.dropout_layer( |
| 221 | + attention_softmax, training=training |
| 222 | + ) |
| 223 | + |
| 224 | + results = ops.einsum("bkgts,bskh->btkgh", attention_softmax, v) |
| 225 | + return ops.reshape(results, (b, q_len, self.num_query_heads, h)) |
| 226 | + |
| 227 | + def _mask_sliding_window( |
| 228 | + self, |
| 229 | + attention_mask, |
| 230 | + cache_update_index=0, |
| 231 | + ): |
| 232 | + batch_size, query_len, key_len = ops.shape(attention_mask) |
| 233 | + # Compute the sliding window for square attention. |
| 234 | + all_ones = ops.ones((key_len, key_len), "bool") |
| 235 | + if keras.config.backend() == "tensorflow": |
| 236 | + # TODO: trui/tril has issues with dynamic shape on the tensorflow |
| 237 | + # backend. We should fix, but use `band_part` for now. |
| 238 | + import tensorflow as tf |
| 239 | + |
| 240 | + band_size = ops.minimum(key_len, self.sliding_window_size - 1) |
| 241 | + band_size = ops.cast(band_size, "int32") |
| 242 | + sliding_mask = tf.linalg.band_part(all_ones, band_size, band_size) |
| 243 | + else: |
| 244 | + sliding_mask = ops.triu( |
| 245 | + all_ones, -1 * self.sliding_window_size + 1 |
| 246 | + ) * ops.tril(all_ones, self.sliding_window_size - 1) |
| 247 | + # Slice the window for short queries during generation. |
| 248 | + start = (cache_update_index, 0) |
| 249 | + sliding_mask = ops.slice(sliding_mask, start, (query_len, key_len)) |
| 250 | + sliding_mask = ops.expand_dims(sliding_mask, 0) |
| 251 | + return ops.logical_and(attention_mask, ops.cast(sliding_mask, "bool")) |
| 252 | + |
| 253 | + def call( |
| 254 | + self, |
| 255 | + x, |
| 256 | + attention_mask=None, |
| 257 | + cache=None, |
| 258 | + cache_update_index=0, |
| 259 | + training=False, |
| 260 | + ): |
| 261 | + query = self.query_dense(x) |
| 262 | + |
| 263 | + if self.use_query_key_norm: |
| 264 | + query = self.query_norm(query) |
| 265 | + |
| 266 | + query = self._apply_rope(query, cache_update_index) |
| 267 | + |
| 268 | + if cache is not None: |
| 269 | + key_cache = cache[:, 0, ...] |
| 270 | + value_cache = cache[:, 1, ...] |
| 271 | + key_update = self.key_dense(x) |
| 272 | + |
| 273 | + if self.use_query_key_norm: |
| 274 | + key_update = self.key_norm(key_update) |
| 275 | + |
| 276 | + key_update = self._apply_rope(key_update, cache_update_index) |
| 277 | + value_update = self.value_dense(x) |
| 278 | + start = [0, cache_update_index, 0, 0] |
| 279 | + key = ops.slice_update(key_cache, start, key_update) |
| 280 | + value = ops.slice_update(value_cache, start, value_update) |
| 281 | + cache = ops.stack((key, value), axis=1) |
| 282 | + else: |
| 283 | + key = self.key_dense(x) |
| 284 | + |
| 285 | + if self.use_query_key_norm: |
| 286 | + key = self.key_norm(key) |
| 287 | + |
| 288 | + key = self._apply_rope(key, cache_update_index) |
| 289 | + value = self.value_dense(x) |
| 290 | + |
| 291 | + attention_vec = self._compute_attention( |
| 292 | + query, |
| 293 | + key, |
| 294 | + value, |
| 295 | + attention_mask, |
| 296 | + training=training, |
| 297 | + cache_update_index=cache_update_index, |
| 298 | + ) |
| 299 | + |
| 300 | + # Wipe attn vec if there are no attended tokens. |
| 301 | + no_attended_tokens = ops.all( |
| 302 | + ops.equal(attention_mask, 0), axis=-1, keepdims=True |
| 303 | + )[..., None] |
| 304 | + attention_vec = ops.where( |
| 305 | + no_attended_tokens, ops.zeros_like(attention_vec), attention_vec |
| 306 | + ) |
| 307 | + |
| 308 | + attention_output = self.output_dense(attention_vec) |
| 309 | + |
| 310 | + if cache is not None: |
| 311 | + return attention_output, cache |
| 312 | + return attention_output |
| 313 | + |
| 314 | + def compute_output_shape(self, input_shape): |
| 315 | + return input_shape |
0 commit comments