 from keras_hub.src.utils.preset_utils import get_preset_saver
 from keras_hub.src.utils.python_utils import classproperty
 from keras_hub.src.utils.tensor_utils import check_bounding_box_support
+from keras_hub.src.utils.tensor_utils import in_tf_function
 from keras_hub.src.utils.tensor_utils import preprocessing_function


@@ -270,33 +271,45 @@ def call(self, inputs):
         else:
             x = inputs
         if self.scale is not None:
-            x = x * self._expand_non_channel_dims(self.scale, x)
+            # If we are scaling, always cast to the compute dtype. We can't
+            # leave things as an int type if we are scaling to [0, 1].
+            scale = self._expand_non_channel_dims(self.scale, x)
+            x, scale = self._convert_types(x, scale, self.compute_dtype)
+            x = x * scale
         if self.offset is not None:
-            x = x + self._expand_non_channel_dims(self.offset, x)
+            offset = self._expand_non_channel_dims(self.offset, x)
+            x, offset = self._convert_types(x, offset, x.dtype)
+            x = x + offset
         if isinstance(inputs, dict):
             inputs["images"] = x
         else:
             inputs = x
         return inputs

     def _expand_non_channel_dims(self, value, inputs):
+        """Expand non-channel dims so value is broadcastable with inputs."""
         unbatched = len(ops.shape(inputs)) == 3
         channels_first = self.data_format == "channels_first"
         if unbatched:
             broadcast_dims = (1, 2) if channels_first else (0, 1)
         else:
             broadcast_dims = (0, 2, 3) if channels_first else (0, 1, 2)
-        # If inputs are not a tensor type, return a numpy array.
-        # This might happen when running under tf.data.
-        if ops.is_tensor(inputs):
-            # preprocessing decorator moves tensors to cpu in torch backend and
-            # processed on CPU, and then converted back to the appropriate
-            # device (potentially GPU) after preprocessing.
-            if keras.backend.backend() == "torch" and self.image_size is None:
-                return ops.expand_dims(value, broadcast_dims).cpu()
-            return ops.expand_dims(value, broadcast_dims)
-        else:
-            return np.expand_dims(value, broadcast_dims)
+        # A numpy value will work with backend-native ops or with tf.data.
+        return np.expand_dims(value, broadcast_dims)
+
+    def _convert_types(self, x, y, dtype):
+        """Make sure x and y have the same dtype and are on the same device."""
+        if in_tf_function():
+            # This could happen on any backend if we are running in tf.data.
+            import tensorflow as tf
+
+            return tf.cast(x, dtype), tf.cast(y, dtype)
+        x = ops.cast(x, dtype)
+        y = ops.cast(y, dtype)
+        if keras.backend.backend() == "torch":
+            # Place on the same device as x (the image).
+            y = y.to(x.device)
+        return x, y

     def get_config(self):
         config = super().get_config()
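To see why the explicit cast in `call` matters, here is a minimal standalone sketch (plain NumPy, outside the library; `expand_non_channel_dims` is an illustrative copy of the broadcasting rule above, not the library API). Integer images cannot stay integer once we rescale to `[0, 1]`, and TensorFlow in particular raises on mixed int/float arithmetic, hence casting both operands to the compute dtype first:

```python
import numpy as np

def expand_non_channel_dims(value, x, channels_first=False):
    # Same broadcasting rule as the diff: expand a per-channel scale or
    # offset so it lines up with a 3D (unbatched) or 4D (batched) image.
    if x.ndim == 3:
        dims = (1, 2) if channels_first else (0, 1)
    else:
        dims = (0, 2, 3) if channels_first else (0, 1, 2)
    return np.expand_dims(value, dims)

images = np.full((2, 224, 224, 3), 255, dtype="uint8")  # batched, channels_last
scale = np.array([1 / 255.0, 1 / 255.0, 1 / 255.0])     # per-channel rescale

# Cast both sides to the compute dtype before multiplying, as in the new
# `call` body, so the result lands in [0, 1] as float32.
x = images.astype("float32") * expand_non_channel_dims(scale, images).astype("float32")
print(x.shape, x.dtype, x.max())  # (2, 224, 224, 3) float32 1.0
```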
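Similarly, a hedged sketch of what the torch branch of `_convert_types` guards against, assuming torch is installed (`convert_types` below is an illustrative stand-in, not the library helper): the numpy constant produced by `_expand_non_channel_dims` must be cast and placed on the image's device before eager torch arithmetic; on a CPU-only machine the `.to(x.device)` is a no-op.

```python
import numpy as np
import torch

def convert_types(x, y, dtype):
    # Unify dtypes first, then move y (the numpy scale/offset constant)
    # onto the same device as x (the image), which may live on a GPU.
    x = x.to(dtype)
    y = torch.as_tensor(y).to(dtype).to(x.device)
    return x, y

images = torch.randint(0, 256, (2, 3, 8, 8), dtype=torch.uint8)
scale = np.expand_dims(np.full(3, 1 / 255.0), (0, 2, 3))  # channels_first

images, scale = convert_types(images, scale, torch.float32)
out = images * scale  # same dtype, same device: no promotion or device errors
```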