Commit 2a905c9

No public description

PiperOrigin-RevId: 615388881

1 parent 79b9e98 commit 2a905c9

File tree

5 files changed: +201 -62 lines changed

official/vision/configs/semantic_segmentation.py

Lines changed: 4 additions & 0 deletions

@@ -91,6 +91,10 @@ class DataConfig(cfg.DataConfig):
   )
   additional_dense_features: List[DenseFeatureConfig] = dataclasses.field(
       default_factory=list)
+  # If `centered_crop` is set to True, the resized crop (if smaller than the
+  # padded size) is placed in the center of the image. The default behaviour
+  # is to place it at the top-left corner.
+  centered_crop: bool = False
 
 
 @dataclasses.dataclass
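Usage note: with this change, centered padding can be enabled straight from the data config. A minimal sketch, assuming the usual dataclass construction; the input path and sizes below are illustrative placeholders, not part of this commit:

    from official.vision.configs import semantic_segmentation as seg_cfg

    # Hypothetical config; only `centered_crop` is introduced by this commit.
    data_config = seg_cfg.DataConfig(
        input_path='/path/to/train*.tfrecord',  # placeholder path
        is_training=True,
        output_size=[512, 512],                 # placeholder size
        centered_crop=True,  # center the resized crop inside the padded image
    )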

official/vision/dataloaders/segmentation_input.py

Lines changed: 117 additions & 54 deletions
@@ -25,48 +25,54 @@
 class Decoder(decoder.Decoder):
   """A tf.Example decoder for segmentation task."""
 
-  def __init__(self,
-               image_feature=config_lib.DenseFeatureConfig(),
-               additional_dense_features=None):
+  def __init__(
+      self,
+      image_feature=config_lib.DenseFeatureConfig(),
+      additional_dense_features=None,
+  ):
     self._keys_to_features = {
-        'image/encoded':
-            tf.io.FixedLenFeature((), tf.string, default_value=''),
-        'image/height':
-            tf.io.FixedLenFeature((), tf.int64, default_value=0),
-        'image/width':
-            tf.io.FixedLenFeature((), tf.int64, default_value=0),
-        'image/segmentation/class/encoded':
-            tf.io.FixedLenFeature((), tf.string, default_value=''),
-        image_feature.feature_name:
-            tf.io.FixedLenFeature((), tf.string, default_value='')
+        'image/encoded': tf.io.FixedLenFeature((), tf.string, default_value=''),
+        'image/height': tf.io.FixedLenFeature((), tf.int64, default_value=0),
+        'image/width': tf.io.FixedLenFeature((), tf.int64, default_value=0),
+        'image/segmentation/class/encoded': tf.io.FixedLenFeature(
+            (), tf.string, default_value=''
+        ),
+        image_feature.feature_name: tf.io.FixedLenFeature(
+            (), tf.string, default_value=''
+        ),
     }
     if additional_dense_features:
       for feature in additional_dense_features:
         self._keys_to_features[feature.feature_name] = tf.io.FixedLenFeature(
-            (), tf.string, default_value='')
+            (), tf.string, default_value=''
+        )
 
   def decode(self, serialized_example):
-    return tf.io.parse_single_example(serialized_example,
-                                      self._keys_to_features)
+    return tf.io.parse_single_example(
+        serialized_example, self._keys_to_features
+    )
 
 
 class Parser(parser.Parser):
   """Parser to parse an image and its annotations into a dictionary of tensors."""
 
-  def __init__(self,
-               output_size,
-               crop_size=None,
-               resize_eval_groundtruth=True,
-               gt_is_matting_map=False,
-               groundtruth_padded_size=None,
-               ignore_label=255,
-               aug_rand_hflip=False,
-               preserve_aspect_ratio=True,
-               aug_scale_min=1.0,
-               aug_scale_max=1.0,
-               dtype='float32',
-               image_feature=config_lib.DenseFeatureConfig(),
-               additional_dense_features=None):
+  def __init__(
+      self,
+      output_size,
+      crop_size=None,
+      resize_eval_groundtruth=True,
+      gt_is_matting_map=False,
+      groundtruth_padded_size=None,
+      ignore_label=255,
+      aug_rand_hflip=False,
+      preserve_aspect_ratio=True,
+      aug_scale_min=1.0,
+      aug_scale_max=1.0,
+      dtype='float32',
+      image_feature=config_lib.DenseFeatureConfig(),
+      additional_dense_features=None,
+      centered_crop=False,
+  ):
     """Initializes parameters for parsing annotations in the dataset.
 
     Args:
@@ -100,13 +106,18 @@ def __init__(self,
         dataset mean/stddev.
       additional_dense_features: `list` of DenseFeatureConfig for additional
         dense features.
+      centered_crop: If `centered_crop` is set to True, the resized crop (if
+        smaller than the padded size) is placed in the center of the image.
+        The default behaviour is to place it at the top-left corner.
     """
     self._output_size = output_size
     self._crop_size = crop_size
     self._resize_eval_groundtruth = resize_eval_groundtruth
     if (not resize_eval_groundtruth) and (groundtruth_padded_size is None):
-      raise ValueError('groundtruth_padded_size ([height, width]) needs to be'
-                       'specified when resize_eval_groundtruth is False.')
+      raise ValueError(
+          'groundtruth_padded_size ([height, width]) needs to be '
+          'specified when resize_eval_groundtruth is False.'
+      )
     self._gt_is_matting_map = gt_is_matting_map
     self._groundtruth_padded_size = groundtruth_padded_size
     self._ignore_label = ignore_label
@@ -122,40 +133,53 @@ def __init__(self,
 
     self._image_feature = image_feature
     self._additional_dense_features = additional_dense_features
+    self._centered_crop = centered_crop
+    if self._centered_crop and not self._resize_eval_groundtruth:
+      raise ValueError(
+          'centered_crop is only supported when resize_eval_groundtruth is'
+          ' True.'
+      )
 
   def _prepare_image_and_label(self, data):
     """Prepare normalized image and label."""
     height = data['image/height']
     width = data['image/width']
 
     label = tf.io.decode_image(
-        data['image/segmentation/class/encoded'], channels=1)
+        data['image/segmentation/class/encoded'], channels=1
+    )
     label = tf.reshape(label, (1, height, width))
     label = tf.cast(label, tf.float32)
 
     image = tf.io.decode_image(
         data[self._image_feature.feature_name],
         channels=self._image_feature.num_channels,
-        dtype=tf.uint8)
+        dtype=tf.uint8,
+    )
     image = tf.reshape(image, (height, width, self._image_feature.num_channels))
     # Normalizes the image feature with mean and std values, which are divided
     # by 255 because an uint8 image are re-scaled automatically. Images other
     # than uint8 type will be wrongly normalized.
     image = preprocess_ops.normalize_image(
-        image, [mean / 255.0 for mean in self._image_feature.mean],
-        [stddev / 255.0 for stddev in self._image_feature.stddev])
+        image,
+        [mean / 255.0 for mean in self._image_feature.mean],
+        [stddev / 255.0 for stddev in self._image_feature.stddev],
+    )
 
     if self._additional_dense_features:
       input_list = [image]
       for feature_cfg in self._additional_dense_features:
         feature = tf.io.decode_image(
             data[feature_cfg.feature_name],
             channels=feature_cfg.num_channels,
-            dtype=tf.uint8)
+            dtype=tf.uint8,
+        )
         feature = tf.reshape(feature, (height, width, feature_cfg.num_channels))
         feature = preprocess_ops.normalize_image(
-            feature, [mean / 255.0 for mean in feature_cfg.mean],
-            [stddev / 255.0 for stddev in feature_cfg.stddev])
+            feature,
+            [mean / 255.0 for mean in feature_cfg.mean],
+            [stddev / 255.0 for stddev in feature_cfg.stddev],
+        )
         input_list.append(feature)
       concat_input = tf.concat(input_list, axis=2)
     else:
@@ -164,7 +188,8 @@ def _prepare_image_and_label(self, data):
     if not self._preserve_aspect_ratio:
       label = tf.reshape(label, [data['image/height'], data['image/width'], 1])
       concat_input = tf.image.resize(
-          concat_input, self._output_size, method='bilinear')
+          concat_input, self._output_size, method='bilinear'
+      )
       label = tf.image.resize(label, self._output_size, method='nearest')
       label = tf.reshape(label[:, :, -1], [1] + self._output_size)
 
@@ -195,14 +220,16 @@ def _parse_train_data(self, data):
 
       image_mask = tf.concat([image, label], axis=2)
       image_mask_crop = tf.image.random_crop(
-          image_mask, self._crop_size + [tf.shape(image_mask)[-1]])
+          image_mask, self._crop_size + [tf.shape(image_mask)[-1]]
+      )
       image = image_mask_crop[:, :, :-1]
       label = tf.reshape(image_mask_crop[:, :, -1], [1] + self._crop_size)
 
     # Flips image randomly during training.
     if self._aug_rand_hflip:
       image, _, label = preprocess_ops.random_horizontal_flip(
-          image, masks=label)
+          image, masks=label
+      )
 
     train_image_size = self._crop_size if self._crop_size else self._output_size
     # Resizes and crops image.
@@ -211,7 +238,9 @@ def _parse_train_data(self, data):
         train_image_size,
         train_image_size,
         aug_scale_min=self._aug_scale_min,
-        aug_scale_max=self._aug_scale_max)
+        aug_scale_max=self._aug_scale_max,
+        centered_crop=self._centered_crop,
+    )
 
     # Resizes and crops boxes.
     image_scale = image_info[2, :]
@@ -221,11 +250,17 @@
     # The label is first offset by +1 and then padded with 0.
     label += 1
     label = tf.expand_dims(label, axis=3)
-    label = preprocess_ops.resize_and_crop_masks(label, image_scale,
-                                                 train_image_size, offset)
+    label = preprocess_ops.resize_and_crop_masks(
+        label,
+        image_scale,
+        train_image_size,
+        offset,
+        centered_crop=self._centered_crop,
+    )
     label -= 1
     label = tf.where(
-        tf.equal(label, -1), self._ignore_label * tf.ones_like(label), label)
+        tf.equal(label, -1), self._ignore_label * tf.ones_like(label), label
+    )
     label = tf.squeeze(label, axis=0)
     valid_mask = tf.not_equal(label, self._ignore_label)
 
@@ -255,30 +290,58 @@ def _parse_eval_data(self, data):
 
     # Resizes and crops image.
     image, image_info = preprocess_ops.resize_and_crop_image(
-        image, self._output_size, self._output_size)
+        image,
+        self._output_size,
+        self._output_size,
+        centered_crop=self._centered_crop,
+    )
 
     if self._resize_eval_groundtruth:
       # Resizes eval masks to match input image sizes. In that case, mean IoU
       # is computed on output_size not the original size of the images.
       image_scale = image_info[2, :]
       offset = image_info[3, :]
-      label = preprocess_ops.resize_and_crop_masks(label, image_scale,
-                                                   self._output_size, offset)
+      label = preprocess_ops.resize_and_crop_masks(
+          label,
+          image_scale,
+          self._output_size,
+          offset,
+          centered_crop=self._centered_crop,
+      )
     else:
-      label = tf.image.pad_to_bounding_box(label, 0, 0,
-                                           self._groundtruth_padded_size[0],
-                                           self._groundtruth_padded_size[1])
+      if self._centered_crop:
+        label_size = tf.cast(tf.shape(label)[0:2], tf.int32)
+        label = tf.image.pad_to_bounding_box(
+            label,
+            tf.maximum(
+                (self._groundtruth_padded_size[0] - label_size[0]) // 2, 0
+            ),
+            tf.maximum(
+                (self._groundtruth_padded_size[1] - label_size[1]) // 2, 0
+            ),
+            self._groundtruth_padded_size[0],
+            self._groundtruth_padded_size[1],
+        )
+      else:
+        label = tf.image.pad_to_bounding_box(
+            label,
+            0,
+            0,
+            self._groundtruth_padded_size[0],
+            self._groundtruth_padded_size[1],
+        )
 
     label -= 1
     label = tf.where(
-        tf.equal(label, -1), self._ignore_label * tf.ones_like(label), label)
+        tf.equal(label, -1), self._ignore_label * tf.ones_like(label), label
+    )
     label = tf.squeeze(label, axis=0)
 
     valid_mask = tf.not_equal(label, self._ignore_label)
     labels = {
         'masks': label,
         'valid_masks': valid_mask,
-        'image_info': image_info
+        'image_info': image_info,
     }
 
     # Cast image as self._dtype
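Usage note: the Parser now threads `centered_crop` through both the train and eval paths. A minimal construction sketch, assuming defaults for the remaining arguments (the size is a placeholder):

    from official.vision.dataloaders import segmentation_input

    parser = segmentation_input.Parser(
        output_size=[512, 512],  # placeholder size
        # Must stay True: the new constructor guard raises a ValueError when
        # centered_crop=True and resize_eval_groundtruth=False.
        resize_eval_groundtruth=True,
        centered_crop=True,
    )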

official/vision/ops/preprocess_ops.py

Lines changed: 38 additions & 7 deletions
@@ -168,6 +168,7 @@ def resize_and_crop_image(
     seed=1,
     method=tf.image.ResizeMethod.BILINEAR,
     keep_aspect_ratio=True,
+    centered_crop=False,
 ):
   """Resizes the input image to output size (RetinaNet style).
 
@@ -195,6 +196,9 @@
     seed: seed for random scale jittering.
     method: function to resize input image to scaled image.
     keep_aspect_ratio: whether or not to keep the aspect ratio when resizing.
+    centered_crop: If `centered_crop` is set to True, the resized crop (if
+      smaller than the padded size) is placed in the center of the image. The
+      default behaviour is to place it at the top-left corner.
 
   Returns:
     output_image: `Tensor` of shape [height, width, 3] where [height, width]
@@ -266,9 +270,19 @@
 
   output_image = scaled_image
   if padded_size is not None:
-    output_image = tf.image.pad_to_bounding_box(
-        scaled_image, 0, 0, padded_size[0], padded_size[1]
-    )
+    if centered_crop:
+      scaled_image_size = tf.cast(tf.shape(scaled_image)[0:2], tf.int32)
+      output_image = tf.image.pad_to_bounding_box(
+          scaled_image,
+          tf.maximum((padded_size[0] - scaled_image_size[0]) // 2, 0),
+          tf.maximum((padded_size[1] - scaled_image_size[1]) // 2, 0),
+          padded_size[0],
+          padded_size[1],
+      )
+    else:
+      output_image = tf.image.pad_to_bounding_box(
+          scaled_image, 0, 0, padded_size[0], padded_size[1]
+      )
 
   image_info = tf.stack([
       image_size,
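Usage note: a hedged sketch of calling the updated helper directly; the image shape and sizes are assumptions. With keep_aspect_ratio=True, a 300x400 image scales to 384x512 for a desired size of 512x512, and the centered pad then offsets it by 128 rows and 64 columns inside the 640x640 canvas:

    import tensorflow as tf
    from official.vision.ops import preprocess_ops

    image = tf.zeros([300, 400, 3], tf.float32)  # placeholder image
    output_image, image_info = preprocess_ops.resize_and_crop_image(
        image,
        desired_size=[512, 512],
        padded_size=[640, 640],
        centered_crop=True,  # pad evenly on both sides of the resized image
    )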
@@ -686,7 +700,9 @@ def resize_and_crop_boxes(boxes, image_scale, output_size, offset):
   return boxes
 
 
-def resize_and_crop_masks(masks, image_scale, output_size, offset):
+def resize_and_crop_masks(
+    masks, image_scale, output_size, offset, centered_crop: bool = False
+):
   """Resizes boxes to output size with scale and offset.
 
   Args:
@@ -697,6 +713,9 @@ def resize_and_crop_masks(masks, image_scale, output_size, offset):
       output image size.
     offset: 2D `Tensor` representing top-left corner [y0, x0] to crop scaled
       boxes.
+    centered_crop: If `centered_crop` is set to True, the resized crop (if
+      smaller than the padded size) is placed in the center of the image. The
+      default behaviour is to place it at the top-left corner.
 
   Returns:
     masks: `Tensor` of shape [N, H, W, C] representing the scaled masks.
@@ -719,6 +738,7 @@ def resize_and_crop_masks(masks, image_scale, output_size, offset):
   scaled_masks = tf.image.resize(
       masks, scaled_size, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR
   )
+
   offset = tf.cast(offset, tf.int32)
   scaled_masks = scaled_masks[
       :,
@@ -727,9 +747,20 @@
       :,
   ]
 
-  output_masks = tf.image.pad_to_bounding_box(
-      scaled_masks, 0, 0, output_size[0], output_size[1]
-  )
+  if centered_crop:
+    scaled_mask_size = tf.cast(tf.shape(scaled_masks)[1:3], tf.int32)
+    output_masks = tf.image.pad_to_bounding_box(
+        scaled_masks,
+        tf.maximum((output_size[0] - scaled_mask_size[0]) // 2, 0),
+        tf.maximum((output_size[1] - scaled_mask_size[1]) // 2, 0),
+        output_size[0],
+        output_size[1],
+    )
+  else:
+    output_masks = tf.image.pad_to_bounding_box(
+        scaled_masks, 0, 0, output_size[0], output_size[1]
+    )
+
   # Remove padding.
   output_masks = output_masks[1::]
   return output_masks
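Note: both padding branches use the same centering arithmetic: half the leftover padding per axis, clamped at zero so a scaled image that already fills the padded size keeps the top-left placement. A standalone sketch of the computation:

    import tensorflow as tf

    # A 400x300 scaled image padded to 512x512 is offset by half the slack
    # on each axis; tf.maximum clamps negative slack to 0.
    scaled_size = tf.constant([400, 300], tf.int32)
    padded_size = tf.constant([512, 512], tf.int32)
    offset_y = tf.maximum((padded_size[0] - scaled_size[0]) // 2, 0)
    offset_x = tf.maximum((padded_size[1] - scaled_size[1]) // 2, 0)
    print(int(offset_y), int(offset_x))  # -> 56 106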
