@@ -33,12 +33,7 @@ def _build_model(self):
33
33
model_config = self .params .task .model ,
34
34
l2_regularizer = None )
35
35
36
- def _build_inputs (self , image ):
37
- """Builds classification model inputs for serving."""
38
- # Center crops and resizes image.
39
- if isinstance (image , tf .RaggedTensor ):
40
- image = image .to_tensor ()
41
- image = tf .cast (image , dtype = tf .float32 )
36
+ def _crop_and_resize (self , image ):
42
37
if self .params .task .train_data .aug_crop :
43
38
image = preprocess_ops .center_crop_image (image )
44
39
@@ -48,6 +43,21 @@ def _build_inputs(self, image):
48
43
image = tf .reshape (
49
44
image , [self ._input_image_size [0 ], self ._input_image_size [1 ], 3 ])
50
45
46
+ return image
47
+
48
+ def _build_inputs (self , image ):
49
+ """Builds classification model inputs for serving."""
50
+ # Center crops and resizes image.
51
+ if isinstance (image , tf .RaggedTensor ):
52
+ image = image .to_tensor ()
53
+ image = tf .cast (image , dtype = tf .float32 )
54
+
55
+ # For these input types, decode_image already performs cropping.
56
+ if not (
57
+ self ._input_type in ['tf_example' , 'image_bytes' ]
58
+ and len (self ._input_image_size ) == 2 ):
59
+ image = self ._crop_and_resize (image )
60
+
51
61
# Normalizes image with mean and std pixel values.
52
62
image = preprocess_ops .normalize_image (
53
63
image , offset = preprocess_ops .MEAN_RGB , scale = preprocess_ops .STDDEV_RGB )
@@ -72,14 +82,12 @@ def _decode_image(self, encoded_image_bytes: str) -> tf.Tensor:
72
82
encoded_image_bytes , channels = self ._num_channels
73
83
)
74
84
image_tensor .set_shape ((None , None , self ._num_channels ))
75
- # Resize image to input_size to support varible image resolutions in a
76
- # batch for tf_example and image_bytes input type.
77
- image_tensor = tf .image .resize (
78
- tf .cast (image_tensor , tf .float32 ),
79
- self ._input_image_size ,
80
- method = tf .image .ResizeMethod .BILINEAR ,
81
- )
85
+ # Crop the image inside the same loop as decoding an image
86
+ # if there could be several images of different sizes in the batch.
87
+ image_tensor = tf .cast (image_tensor , dtype = tf .float32 )
88
+ image_tensor = self ._crop_and_resize (image_tensor )
82
89
image_tensor = tf .cast (image_tensor , tf .uint8 )
90
+ return image_tensor
83
91
else :
84
92
# Convert raw bytes into a tensor and reshape it, if not 2D input.
85
93
image_tensor = tf .io .decode_raw (encoded_image_bytes , out_type = tf .uint8 )
0 commit comments