class Decoder(decoder.Decoder):
  """A tf.Example decoder for segmentation task."""

-  def __init__(self,
-               image_feature=config_lib.DenseFeatureConfig(),
-               additional_dense_features=None):
+  def __init__(
+      self,
+      image_feature=config_lib.DenseFeatureConfig(),
+      additional_dense_features=None,
+  ):
    self._keys_to_features = {
-        'image/encoded':
-            tf.io.FixedLenFeature((), tf.string, default_value=''),
-        'image/height':
-            tf.io.FixedLenFeature((), tf.int64, default_value=0),
-        'image/width':
-            tf.io.FixedLenFeature((), tf.int64, default_value=0),
-        'image/segmentation/class/encoded':
-            tf.io.FixedLenFeature((), tf.string, default_value=''),
-        image_feature.feature_name:
-            tf.io.FixedLenFeature((), tf.string, default_value='')
+        'image/encoded': tf.io.FixedLenFeature((), tf.string, default_value=''),
+        'image/height': tf.io.FixedLenFeature((), tf.int64, default_value=0),
+        'image/width': tf.io.FixedLenFeature((), tf.int64, default_value=0),
+        'image/segmentation/class/encoded': tf.io.FixedLenFeature(
+            (), tf.string, default_value=''
+        ),
+        image_feature.feature_name: tf.io.FixedLenFeature(
+            (), tf.string, default_value=''
+        ),
    }

    if additional_dense_features:
      for feature in additional_dense_features:
        self._keys_to_features[feature.feature_name] = tf.io.FixedLenFeature(
-            (), tf.string, default_value='')
+            (), tf.string, default_value=''
+        )

  def decode(self, serialized_example):
-    return tf.io.parse_single_example(serialized_example,
-                                      self._keys_to_features)
+    return tf.io.parse_single_example(
+        serialized_example, self._keys_to_features
+    )


class Parser(parser.Parser):
  """Parser to parse an image and its annotations into a dictionary of tensors."""

-  def __init__(self,
-               output_size,
-               crop_size=None,
-               resize_eval_groundtruth=True,
-               gt_is_matting_map=False,
-               groundtruth_padded_size=None,
-               ignore_label=255,
-               aug_rand_hflip=False,
-               preserve_aspect_ratio=True,
-               aug_scale_min=1.0,
-               aug_scale_max=1.0,
-               dtype='float32',
-               image_feature=config_lib.DenseFeatureConfig(),
-               additional_dense_features=None):
+  def __init__(
+      self,
+      output_size,
+      crop_size=None,
+      resize_eval_groundtruth=True,
+      gt_is_matting_map=False,
+      groundtruth_padded_size=None,
+      ignore_label=255,
+      aug_rand_hflip=False,
+      preserve_aspect_ratio=True,
+      aug_scale_min=1.0,
+      aug_scale_max=1.0,
+      dtype='float32',
+      image_feature=config_lib.DenseFeatureConfig(),
+      additional_dense_features=None,
+      centered_crop=False,
+  ):
    """Initializes parameters for parsing annotations in the dataset.

    Args:
@@ -100,13 +106,18 @@ def __init__(self,
        dataset mean/stddev.
      additional_dense_features: `list` of DenseFeatureConfig for additional
        dense features.
+      centered_crop: If `centered_crop` is set to True, then the resized crop
+        (if smaller than the padded size) is placed in the center of the
+        image. By default it is placed at the top-left corner.
    """
    self._output_size = output_size
    self._crop_size = crop_size
    self._resize_eval_groundtruth = resize_eval_groundtruth
    if (not resize_eval_groundtruth) and (groundtruth_padded_size is None):
-      raise ValueError('groundtruth_padded_size ([height, width]) needs to be'
-                       'specified when resize_eval_groundtruth is False.')
+      raise ValueError(
+          'groundtruth_padded_size ([height, width]) needs to be '
+          'specified when resize_eval_groundtruth is False.'
+      )
    self._gt_is_matting_map = gt_is_matting_map
    self._groundtruth_padded_size = groundtruth_padded_size
    self._ignore_label = ignore_label
@@ -122,40 +133,53 @@ def __init__(self,
    self._image_feature = image_feature
    self._additional_dense_features = additional_dense_features
+    self._centered_crop = centered_crop
+    if self._centered_crop and not self._resize_eval_groundtruth:
+      raise ValueError(
+          'centered_crop is only supported when resize_eval_groundtruth is'
+          ' True.'
+      )

  def _prepare_image_and_label(self, data):
    """Prepare normalized image and label."""
    height = data['image/height']
    width = data['image/width']

    label = tf.io.decode_image(
-        data['image/segmentation/class/encoded'], channels=1)
+        data['image/segmentation/class/encoded'], channels=1
+    )
    label = tf.reshape(label, (1, height, width))
    label = tf.cast(label, tf.float32)

    image = tf.io.decode_image(
        data[self._image_feature.feature_name],
        channels=self._image_feature.num_channels,
-        dtype=tf.uint8)
+        dtype=tf.uint8,
+    )
    image = tf.reshape(image, (height, width, self._image_feature.num_channels))
    # Normalizes the image feature with mean and std values, which are divided
    # by 255 because a uint8 image is re-scaled automatically. Images other
    # than uint8 type will be wrongly normalized.
    image = preprocess_ops.normalize_image(
-        image, [mean / 255.0 for mean in self._image_feature.mean],
-        [stddev / 255.0 for stddev in self._image_feature.stddev])
+        image,
+        [mean / 255.0 for mean in self._image_feature.mean],
+        [stddev / 255.0 for stddev in self._image_feature.stddev],
+    )

    if self._additional_dense_features:
      input_list = [image]
      for feature_cfg in self._additional_dense_features:
        feature = tf.io.decode_image(
            data[feature_cfg.feature_name],
            channels=feature_cfg.num_channels,
-            dtype=tf.uint8)
+            dtype=tf.uint8,
+        )
        feature = tf.reshape(feature, (height, width, feature_cfg.num_channels))
        feature = preprocess_ops.normalize_image(
-            feature, [mean / 255.0 for mean in feature_cfg.mean],
-            [stddev / 255.0 for stddev in feature_cfg.stddev])
+            feature,
+            [mean / 255.0 for mean in feature_cfg.mean],
+            [stddev / 255.0 for stddev in feature_cfg.stddev],
+        )
        input_list.append(feature)
      concat_input = tf.concat(input_list, axis=2)
    else:
@@ -164,7 +188,8 @@ def _prepare_image_and_label(self, data):
    if not self._preserve_aspect_ratio:
      label = tf.reshape(label, [data['image/height'], data['image/width'], 1])
      concat_input = tf.image.resize(
-          concat_input, self._output_size, method='bilinear')
+          concat_input, self._output_size, method='bilinear'
+      )
      label = tf.image.resize(label, self._output_size, method='nearest')
      label = tf.reshape(label[:, :, -1], [1] + self._output_size)

@@ -195,14 +220,16 @@ def _parse_train_data(self, data):

      image_mask = tf.concat([image, label], axis=2)
      image_mask_crop = tf.image.random_crop(
-          image_mask, self._crop_size + [tf.shape(image_mask)[-1]])
+          image_mask, self._crop_size + [tf.shape(image_mask)[-1]]
+      )
      image = image_mask_crop[:, :, :-1]
      label = tf.reshape(image_mask_crop[:, :, -1], [1] + self._crop_size)

    # Flips image randomly during training.
    if self._aug_rand_hflip:
      image, _, label = preprocess_ops.random_horizontal_flip(
-          image, masks=label)
+          image, masks=label
+      )

    train_image_size = self._crop_size if self._crop_size else self._output_size
    # Resizes and crops image.
@@ -211,7 +238,9 @@ def _parse_train_data(self, data):
        train_image_size,
        train_image_size,
        aug_scale_min=self._aug_scale_min,
-        aug_scale_max=self._aug_scale_max)
+        aug_scale_max=self._aug_scale_max,
+        centered_crop=self._centered_crop,
+    )

    # Resizes and crops boxes.
    image_scale = image_info[2, :]
@@ -221,11 +250,17 @@ def _parse_train_data(self, data):
    # The label is first offset by +1 and then padded with 0.
    label += 1
    label = tf.expand_dims(label, axis=3)
-    label = preprocess_ops.resize_and_crop_masks(label, image_scale,
-                                                 train_image_size, offset)
+    label = preprocess_ops.resize_and_crop_masks(
+        label,
+        image_scale,
+        train_image_size,
+        offset,
+        centered_crop=self._centered_crop,
+    )
    label -= 1
    label = tf.where(
-        tf.equal(label, -1), self._ignore_label * tf.ones_like(label), label)
+        tf.equal(label, -1), self._ignore_label * tf.ones_like(label), label
+    )
    label = tf.squeeze(label, axis=0)
    valid_mask = tf.not_equal(label, self._ignore_label)

@@ -255,30 +290,58 @@ def _parse_eval_data(self, data):

    # Resizes and crops image.
    image, image_info = preprocess_ops.resize_and_crop_image(
-        image, self._output_size, self._output_size)
+        image,
+        self._output_size,
+        self._output_size,
+        centered_crop=self._centered_crop,
+    )

    if self._resize_eval_groundtruth:
      # Resizes eval masks to match input image sizes. In that case, mean IoU
      # is computed on output_size not the original size of the images.
      image_scale = image_info[2, :]
      offset = image_info[3, :]
-      label = preprocess_ops.resize_and_crop_masks(label, image_scale,
-                                                   self._output_size, offset)
+      label = preprocess_ops.resize_and_crop_masks(
+          label,
+          image_scale,
+          self._output_size,
+          offset,
+          centered_crop=self._centered_crop,
+      )
    else:
-      label = tf.image.pad_to_bounding_box(label, 0, 0,
-                                           self._groundtruth_padded_size[0],
-                                           self._groundtruth_padded_size[1])
+      if self._centered_crop:
+        label_size = tf.cast(tf.shape(label)[0:2], tf.int32)
+        label = tf.image.pad_to_bounding_box(
+            label,
+            tf.maximum(
+                (self._groundtruth_padded_size[0] - label_size[0]) // 2, 0
+            ),
+            tf.maximum(
+                (self._groundtruth_padded_size[1] - label_size[1]) // 2, 0
+            ),
+            self._groundtruth_padded_size[0],
+            self._groundtruth_padded_size[1],
+        )
+      else:
+        label = tf.image.pad_to_bounding_box(
+            label,
+            0,
+            0,
+            self._groundtruth_padded_size[0],
+            self._groundtruth_padded_size[1],
+        )

    label -= 1
    label = tf.where(
-        tf.equal(label, -1), self._ignore_label * tf.ones_like(label), label)
+        tf.equal(label, -1), self._ignore_label * tf.ones_like(label), label
+    )
    label = tf.squeeze(label, axis=0)

    valid_mask = tf.not_equal(label, self._ignore_label)
    labels = {
        'masks': label,
        'valid_masks': valid_mask,
-        'image_info': image_info
+        'image_info': image_info,
    }

    # Cast image as self._dtype
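
For intuition, the effect of the new `centered_crop` flag can be reproduced in isolation. The sketch below mirrors the offset arithmetic of the `_parse_eval_data` padding branch in the diff above; it is a minimal standalone example, and the sizes are made-up values, not taken from this change.

import tensorflow as tf

# Example values only: a resized label smaller than the padded size.
padded_size = [512, 512]                     # stands in for groundtruth_padded_size
label = tf.zeros([400, 300, 1], tf.float32)  # resized label in HWC layout
label_size = tf.cast(tf.shape(label)[0:2], tf.int32)

# Default behavior: anchor the label at the top-left corner.
top_left = tf.image.pad_to_bounding_box(
    label, 0, 0, padded_size[0], padded_size[1]
)

# centered_crop=True: pad offsets are half the leftover space, so the label
# lands in the center of the padded canvas.
offset_h = tf.maximum((padded_size[0] - label_size[0]) // 2, 0)  # (512-400)//2 = 56
offset_w = tf.maximum((padded_size[1] - label_size[1]) // 2, 0)  # (512-300)//2 = 106
centered = tf.image.pad_to_bounding_box(
    label, offset_h, offset_w, padded_size[0], padded_size[1]
)

print(top_left.shape, centered.shape)  # (512, 512, 1) (512, 512, 1)

The `tf.maximum(..., 0)` guard keeps the offsets valid when the label already fills the padded size. The train and resize paths delegate the same choice to `preprocess_ops.resize_and_crop_image` and `resize_and_crop_masks` through their `centered_crop` keyword, as the hunks above show.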