Skip to content

Commit 3cbe70f

Browse files
committed
refactor(preprocessing/image): replace skimage downscale with TF
1 parent 90c85b8 commit 3cbe70f

File tree

1 file changed

+5
-32
lines changed

1 file changed

+5
-32
lines changed

boiling_learning/preprocessing/image.py

+5-32
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from skimage.measure import shannon_entropy
99
from skimage.metrics import normalized_mutual_information as _normalized_mutual_information
1010
from skimage.metrics import structural_similarity as ssim
11-
from skimage.transform import downscale_local_mean as _downscale
1211
from skimage.transform import resize
1312

1413
from boiling_learning.preprocessing.transformers import Operator
@@ -231,29 +230,6 @@ def downscale(
231230
) -> _VideoFrameOrFrames:
232231
# 4-D Tensor of shape [batch, height, width, channels] or 3-D Tensor of shape
233232
# [height, width, channels].
234-
235-
if isinstance(factors, int):
236-
height_factor, width_factor = factors, factors
237-
else:
238-
height_factor, width_factor = factors
239-
240-
CHANNEL_FACTOR = 1
241-
if image.ndim == 3:
242-
downscale_factors = (height_factor, width_factor, CHANNEL_FACTOR)
243-
elif image.ndim == 4:
244-
BATCH_FACTOR = 1
245-
downscale_factors = (BATCH_FACTOR, height_factor, width_factor, CHANNEL_FACTOR)
246-
else:
247-
raise RuntimeError(f'image must have either 3 or 4 dimensions, got {image.ndim}')
248-
249-
return typing.cast(_VideoFrameOrFrames, _downscale(image, downscale_factors))
250-
251-
252-
def _downscale_tf(
253-
image: _VideoFrameOrFrames, factors: Union[int, Tuple[int, int]]
254-
) -> _VideoFrameOrFrames:
255-
# 4-D Tensor of shape [batch, height, width, channels] or 3-D Tensor of shape
256-
# [height, width, channels].
257233
if image.ndim == 3:
258234
height = image.shape[0]
259235
width = image.shape[1]
@@ -263,17 +239,14 @@ def _downscale_tf(
263239
else:
264240
raise RuntimeError(f'image must have either 3 or 4 dimensions, got {image.ndim}')
265241

266-
if isinstance(factors, int):
267-
height_factor, width_factor = factors, factors
268-
else:
269-
height_factor, width_factor = factors
242+
height_factor, width_factor = (factors, factors) if isinstance(factors, int) else factors
243+
244+
new_height = round(height / height_factor)
245+
new_width = round(width / width_factor)
270246

271247
return typing.cast(
272248
_VideoFrameOrFrames,
273-
tf.image.resize(
274-
image,
275-
(height // height_factor, width // width_factor),
276-
).numpy(),
249+
tf.image.resize(image, (new_height, new_width), antialias=True).numpy(),
277250
)
278251

279252

0 commit comments

Comments (0)