+import inspect
+
 import keras
 from keras import ops
 
@@ -184,31 +186,33 @@ def quantized_call(self, inputs, reverse=False):
         else:
             self._quantization_mode_error(self.quantization_mode)
 
-    def _int8_build(
-        self,
-        embeddings_initializer="zeros",
-        embeddings_scale_initializer="ones",
-        reverse_embeddings_initializer="zeros",
-        reverse_embeddings_scale_initializer="ones",
-    ):
-        super()._int8_build(
-            embeddings_initializer, embeddings_scale_initializer
-        )
+    def _int8_build(self, embeddings_shape=None):
+        if (
+            "embeddings_shape"
+            in inspect.signature(super()._int8_build).parameters
+        ):
+            if embeddings_shape is None:
+                embeddings_shape = (self.input_dim, self.output_dim)
+            super()._int8_build(embeddings_shape=embeddings_shape)
+        else:
+            # Backward compatibility for older versions of Keras.
+            super()._int8_build()
         self.inputs_quantizer = keras.quantizers.AbsMaxQuantizer(axis=-1)
         if not self.tie_weights:
             self.reverse_embeddings = self.add_weight(
                 name="reverse_embeddings",
                 shape=(self.output_dim, self.input_dim),
-                initializer=reverse_embeddings_initializer,
+                initializer="zeros",
                 dtype="int8",
                 trainable=False,
             )
             self.reverse_embeddings_scale = self.add_weight(
                 name="reverse_embeddings_scale",
                 shape=(self.input_dim,),
-                initializer=reverse_embeddings_scale_initializer,
+                initializer="ones",
                 trainable=False,
             )
+        self._is_quantized = True
 
     def _int8_call(self, inputs, reverse=False):
         if reverse:
@@ -232,27 +236,20 @@ def _int8_call(self, inputs, reverse=False):
         return super()._int8_call(inputs)
 
     def quantize(self, mode, type_check=True):
-        import gc
-
         if type_check and type(self) is not ReversibleEmbedding:
-            raise NotImplementedError(
-                f"Layer {self.__class__.__name__} does not have a `quantize()` "
-                "method implemented."
-            )
-        self._check_quantize_args(mode, self.compute_dtype)
+            raise self._not_implemented_error(self.quantize)
 
         def abs_max_quantize(inputs, axis):
             return keras.quantizers.abs_max_quantize(
                 inputs, axis=axis, to_numpy=True
             )
 
-        self._tracker.unlock()
+        embeddings_shape = (self.input_dim, self.output_dim)
         if mode == "int8":
             embeddings, embeddings_scale = abs_max_quantize(
                 self._embeddings, axis=-1
             )
             embeddings_scale = ops.squeeze(embeddings_scale, axis=-1)
-            self._untrack_variable(self._embeddings)
             del self._embeddings
             if not self.tie_weights:
                 reverse_embeddings, reverse_embeddings_scale = abs_max_quantize(
@@ -261,24 +258,17 @@ def abs_max_quantize(inputs, axis):
                 reverse_embeddings_scale = ops.squeeze(
                     reverse_embeddings_scale, axis=0
                 )
-                self._untrack_variable(self.reverse_embeddings)
                 del self.reverse_embeddings
-            else:
-                reverse_embeddings = None
-                reverse_embeddings_scale = None
-            self._int8_build(
-                lambda shape, dtype: embeddings,
-                lambda shape, dtype: embeddings_scale,
-                lambda shape, dtype: reverse_embeddings,
-                lambda shape, dtype: reverse_embeddings_scale,
-            )
-        else:
-            raise self._quantization_mode_error(mode)
-        self._tracker.lock()
+        self.quantized_build(embeddings_shape, mode)
+        if mode == "int8":
+            self._embeddings.assign(embeddings)
+            self.embeddings_scale.assign(embeddings_scale)
+            if not self.tie_weights:
+                self.reverse_embeddings.assign(reverse_embeddings)
+                self.reverse_embeddings_scale.assign(reverse_embeddings_scale)
 
         if self.dtype_policy.quantization_mode is None:
             policy = keras.dtype_policies.get(
                 f"{mode}_from_{self.dtype_policy.name}"
             )
             self.dtype_policy = policy
-        gc.collect()
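
Note on the `_int8_build` change above: the layer now probes the parent method's signature with `inspect.signature`, so the same code runs whether or not the installed Keras accepts an `embeddings_shape` argument. A minimal standalone sketch of that dispatch pattern (the `Base`/`Child` classes and dims here are illustrative, not from this diff):

import inspect


class Base:
    # Newer-style signature that accepts the shape directly.
    def _int8_build(self, embeddings_shape=None):
        print("building int8 weights with shape", embeddings_shape)


class Child(Base):
    input_dim, output_dim = 100, 16  # hypothetical dims for the demo

    def _int8_build(self, embeddings_shape=None):
        # Probe the parent's signature so one code path supports both
        # the old (no-argument) and new (embeddings_shape) APIs.
        parent = super()._int8_build
        if "embeddings_shape" in inspect.signature(parent).parameters:
            if embeddings_shape is None:
                embeddings_shape = (self.input_dim, self.output_dim)
            parent(embeddings_shape=embeddings_shape)
        else:
            parent()  # older signature took no shape argument


Child()._int8_build()  # -> building int8 weights with shape (100, 16)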
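On the `quantize` path: instead of unlocking the tracker and smuggling precomputed values into `_int8_build` through initializer lambdas, the method now computes the int8 tensors up front, rebuilds the quantized variables via `quantized_build`, and `assign`s the values in. A hedged usage sketch, assuming this is KerasHub's `ReversibleEmbedding` (the import path and shapes are assumptions, not taken from this diff):

import numpy as np
import keras_hub

# Assumed public import path; dims are arbitrary for the demo.
layer = keras_hub.layers.ReversibleEmbedding(input_dim=1000, output_dim=64)
layer.build((2, 16))     # create the float embeddings
layer.quantize("int8")   # rebuild as int8 weights plus per-row scales

tokens = np.array([[1, 2, 3]])
hidden = layer(tokens)                # embedding lookup, dequantized on the fly
logits = layer(hidden, reverse=True)  # reverse projection back to vocab logits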