Skip to content

Commit 8f2c6a9

Browse files
No public description
PiperOrigin-RevId: 715200759
1 parent 65339fa commit 8f2c6a9

File tree

3 files changed

+59
-0
lines changed

3 files changed

+59
-0
lines changed

official/vision/configs/retinanet.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ class RetinaNetHead(hyperparams.Config):
137137
@dataclasses.dataclass
138138
class DetectionGenerator(hyperparams.Config):
139139
apply_nms: bool = True
140+
decode_boxes: bool = True
140141
pre_nms_top_k: int = 5000
141142
pre_nms_score_threshold: float = 0.05
142143
nms_iou_threshold: float = 0.5

official/vision/serving/detection.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,14 @@ def _normalize_coordinates(self, detections_dict, dict_keys, image_info):
136136

137137
return detections_dict
138138

139+
def _flatten_output(self, feature_map, feature_size=4):
140+
flatten_outputs = []
141+
for level_output in feature_map.values():
142+
flatten_outputs.append(
143+
tf.reshape(level_output, (self._batch_size, -1, feature_size))
144+
)
145+
return tf.concat(flatten_outputs, axis=1)
146+
139147
def preprocess(
140148
self, images: tf.Tensor
141149
) -> Tuple[tf.Tensor, Mapping[str, tf.Tensor], tf.Tensor]:
@@ -271,6 +279,18 @@ def serve(self, images: tf.Tensor):
271279
final_outputs['detection_outer_boxes'] = detections[
272280
'detection_outer_boxes'
273281
]
282+
elif (
283+
isinstance(self.params.task.model, configs.retinanet.RetinaNet)
284+
and not self.params.task.model.detection_generator.decode_boxes
285+
):
286+
final_outputs = {
287+
'raw_boxes': self._flatten_output(detections['box_outputs'], 4),
288+
'raw_scores': tf.sigmoid(
289+
self._flatten_output(
290+
detections['cls_outputs'], self.params.task.model.num_classes
291+
)
292+
),
293+
}
274294
else:
275295
# For RetinaNet model, apply export_config.
276296
if isinstance(self.params.task.model, configs.retinanet.RetinaNet):

official/vision/serving/detection_test.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def _get_detection_module(
3838
normalized_coordinates=False,
3939
nms_version='batched',
4040
output_intermediate_features=False,
41+
decode_boxes=True,
4142
):
4243
params = exp_factory.get_exp_config(experiment_name)
4344
params.task.model.outer_boxes_scale = outer_boxes_scale
@@ -48,6 +49,8 @@ def _get_detection_module(
4849
params.task.model.detection_generator.nms_version = nms_version
4950
if output_intermediate_features:
5051
params.task.export_config.output_intermediate_features = True
52+
if not decode_boxes:
53+
params.task.model.detection_generator.decode_boxes = False
5154
detection_module = detection.DetectionModule(
5255
params,
5356
batch_size=1,
@@ -232,6 +235,41 @@ def test_export_normalized_coordinates_no_nms(
232235
max_values.numpy(), tf.ones_like(max_values).numpy()
233236
)
234237

238+
@parameterized.parameters(
239+
'retinanet_mobile_coco',
240+
'retinanet_spinenet_coco',
241+
)
242+
def test_export_without_decoding_boxes(
243+
self,
244+
experiment_name,
245+
):
246+
input_type = 'tflite'
247+
tmp_dir = self.get_temp_dir()
248+
module = self._get_detection_module(
249+
experiment_name,
250+
input_type=input_type,
251+
apply_nms=False,
252+
decode_boxes=False,
253+
)
254+
255+
self._export_from_module(module, input_type, tmp_dir)
256+
257+
imported = tf.saved_model.load(tmp_dir)
258+
detection_fn = imported.signatures['serving_default']
259+
260+
images = self._get_dummy_input(
261+
input_type, batch_size=1, image_size=(640, 640)
262+
)
263+
outputs = detection_fn(tf.constant(images))
264+
265+
self.assertContainsSubset(
266+
{
267+
'raw_boxes',
268+
'raw_scores',
269+
},
270+
outputs.keys(),
271+
)
272+
235273

236274
if __name__ == '__main__':
237275
tf.test.main()

0 commit comments

Comments
 (0)