Skip to content

Commit 79b9e98

Browse files
MInnertensorflower-gardener
authored andcommitted
No public description
PiperOrigin-RevId: 615203446
1 parent 12b8e1d commit 79b9e98

File tree

1 file changed

+21
-13
lines changed

1 file changed

+21
-13
lines changed

official/vision/serving/image_classification.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,7 @@ def _build_model(self):
3333
model_config=self.params.task.model,
3434
l2_regularizer=None)
3535

36-
def _build_inputs(self, image):
37-
"""Builds classification model inputs for serving."""
38-
# Center crops and resizes image.
39-
if isinstance(image, tf.RaggedTensor):
40-
image = image.to_tensor()
41-
image = tf.cast(image, dtype=tf.float32)
36+
def _crop_and_resize(self, image):
4237
if self.params.task.train_data.aug_crop:
4338
image = preprocess_ops.center_crop_image(image)
4439

@@ -48,6 +43,21 @@ def _build_inputs(self, image):
4843
image = tf.reshape(
4944
image, [self._input_image_size[0], self._input_image_size[1], 3])
5045

46+
return image
47+
48+
def _build_inputs(self, image):
49+
"""Builds classification model inputs for serving."""
50+
# Center crops and resizes image.
51+
if isinstance(image, tf.RaggedTensor):
52+
image = image.to_tensor()
53+
image = tf.cast(image, dtype=tf.float32)
54+
55+
# For these input types, decode_image already performs cropping.
56+
if not (
57+
self._input_type in ['tf_example', 'image_bytes']
58+
and len(self._input_image_size) == 2):
59+
image = self._crop_and_resize(image)
60+
5161
# Normalizes image with mean and std pixel values.
5262
image = preprocess_ops.normalize_image(
5363
image, offset=preprocess_ops.MEAN_RGB, scale=preprocess_ops.STDDEV_RGB)
@@ -72,14 +82,12 @@ def _decode_image(self, encoded_image_bytes: str) -> tf.Tensor:
7282
encoded_image_bytes, channels=self._num_channels
7383
)
7484
image_tensor.set_shape((None, None, self._num_channels))
75-
# Resize image to input_size to support varible image resolutions in a
76-
# batch for tf_example and image_bytes input type.
77-
image_tensor = tf.image.resize(
78-
tf.cast(image_tensor, tf.float32),
79-
self._input_image_size,
80-
method=tf.image.ResizeMethod.BILINEAR,
81-
)
85+
# Crop the image inside the same loop as decoding an image
86+
# if there could be several images of different sizes in the batch.
87+
image_tensor = tf.cast(image_tensor, dtype=tf.float32)
88+
image_tensor = self._crop_and_resize(image_tensor)
8289
image_tensor = tf.cast(image_tensor, tf.uint8)
90+
return image_tensor
8391
else:
8492
# Convert raw bytes into a tensor and reshape it, if not 2D input.
8593
image_tensor = tf.io.decode_raw(encoded_image_bytes, out_type=tf.uint8)

0 commit comments

Comments
 (0)