Skip to content

Commit 7209d67

Browse files
Update the int8 quant logic of ReversibleEmbedding. (#2250)
1 parent: c00db4e · commit: 7209d67

File tree

1 file changed

+25
-35
lines changed

1 file changed

+25
-35
lines changed

keras_hub/src/layers/modeling/reversible_embedding.py

+25-35
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import inspect
2+
13
import keras
24
from keras import ops
35

@@ -184,31 +186,33 @@ def quantized_call(self, inputs, reverse=False):
184186
else:
185187
self._quantization_mode_error(self.quantization_mode)
186188

187-
def _int8_build(
188-
self,
189-
embeddings_initializer="zeros",
190-
embeddings_scale_initializer="ones",
191-
reverse_embeddings_initializer="zeros",
192-
reverse_embeddings_scale_initializer="ones",
193-
):
194-
super()._int8_build(
195-
embeddings_initializer, embeddings_scale_initializer
196-
)
189+
def _int8_build(self, embeddings_shape=None):
190+
if (
191+
"embeddings_shape"
192+
in inspect.signature(super()._int8_build).parameters
193+
):
194+
if embeddings_shape is None:
195+
embeddings_shape = (self.input_dim, self.output_dim)
196+
super()._int8_build(embeddings_shape=embeddings_shape)
197+
else:
198+
# Backward compatibility for older versions of Keras.
199+
super()._int8_build()
197200
self.inputs_quantizer = keras.quantizers.AbsMaxQuantizer(axis=-1)
198201
if not self.tie_weights:
199202
self.reverse_embeddings = self.add_weight(
200203
name="reverse_embeddings",
201204
shape=(self.output_dim, self.input_dim),
202-
initializer=reverse_embeddings_initializer,
205+
initializer="zeros",
203206
dtype="int8",
204207
trainable=False,
205208
)
206209
self.reverse_embeddings_scale = self.add_weight(
207210
name="reverse_embeddings_scale",
208211
shape=(self.input_dim,),
209-
initializer=reverse_embeddings_scale_initializer,
212+
initializer="ones",
210213
trainable=False,
211214
)
215+
self._is_quantized = True
212216

213217
def _int8_call(self, inputs, reverse=False):
214218
if reverse:
@@ -232,27 +236,20 @@ def _int8_call(self, inputs, reverse=False):
232236
return super()._int8_call(inputs)
233237

234238
def quantize(self, mode, type_check=True):
235-
import gc
236-
237239
if type_check and type(self) is not ReversibleEmbedding:
238-
raise NotImplementedError(
239-
f"Layer {self.__class__.__name__} does not have a `quantize()` "
240-
"method implemented."
241-
)
242-
self._check_quantize_args(mode, self.compute_dtype)
240+
raise self._not_implemented_error(self.quantize)
243241

244242
def abs_max_quantize(inputs, axis):
245243
return keras.quantizers.abs_max_quantize(
246244
inputs, axis=axis, to_numpy=True
247245
)
248246

249-
self._tracker.unlock()
247+
embeddings_shape = (self.input_dim, self.output_dim)
250248
if mode == "int8":
251249
embeddings, embeddings_scale = abs_max_quantize(
252250
self._embeddings, axis=-1
253251
)
254252
embeddings_scale = ops.squeeze(embeddings_scale, axis=-1)
255-
self._untrack_variable(self._embeddings)
256253
del self._embeddings
257254
if not self.tie_weights:
258255
reverse_embeddings, reverse_embeddings_scale = abs_max_quantize(
@@ -261,24 +258,17 @@ def abs_max_quantize(inputs, axis):
261258
reverse_embeddings_scale = ops.squeeze(
262259
reverse_embeddings_scale, axis=0
263260
)
264-
self._untrack_variable(self.reverse_embeddings)
265261
del self.reverse_embeddings
266-
else:
267-
reverse_embeddings = None
268-
reverse_embeddings_scale = None
269-
self._int8_build(
270-
lambda shape, dtype: embeddings,
271-
lambda shape, dtype: embeddings_scale,
272-
lambda shape, dtype: reverse_embeddings,
273-
lambda shape, dtype: reverse_embeddings_scale,
274-
)
275-
else:
276-
raise self._quantization_mode_error(mode)
277-
self._tracker.lock()
262+
self.quantized_build(embeddings_shape, mode)
263+
if mode == "int8":
264+
self._embeddings.assign(embeddings)
265+
self.embeddings_scale.assign(embeddings_scale)
266+
if not self.tie_weights:
267+
self.reverse_embeddings.assign(reverse_embeddings)
268+
self.reverse_embeddings_scale.assign(reverse_embeddings_scale)
278269

279270
if self.dtype_policy.quantization_mode is None:
280271
policy = keras.dtype_policies.get(
281272
f"{mode}_from_{self.dtype_policy.name}"
282273
)
283274
self.dtype_policy = policy
284-
gc.collect()

0 commit comments

Comments (0)