Skip to content

Commit 247331d

Browse files
Support custom anchors with different anchors per location by level
PiperOrigin-RevId: 611616542
1 parent 3fb2306 commit 247331d

File tree

5 files changed

+157
-34
lines changed

5 files changed

+157
-34
lines changed

official/vision/modeling/factory.py

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
"""Factory methods to build models."""
1616

17-
from typing import Optional
17+
from typing import Mapping, Optional
1818

1919
import tensorflow as tf, tf_keras
2020

@@ -262,9 +262,28 @@ def build_retinanet(
262262
model_config: retinanet_cfg.RetinaNet,
263263
l2_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
264264
backbone: Optional[tf_keras.Model] = None,
265-
decoder: Optional[tf_keras.Model] = None
265+
decoder: Optional[tf_keras.Model] = None,
266+
num_anchors_per_location: int | dict[str, int] | None = None,
267+
anchor_boxes: Mapping[str, tf.Tensor] | None = None,
266268
) -> tf_keras.Model:
267-
"""Builds RetinaNet model."""
269+
"""Builds a RetinaNet model.
270+
271+
Args:
272+
input_specs: The InputSpec of the input image tensor to the model.
273+
model_config: The RetinaNet model configuration to build from.
274+
l2_regularizer: Optional l2 regularizer to use for building the backbone,
275+
decorder, and head.
276+
backbone: Optional instance of the backbone model.
277+
decoder: Optional instance of the decoder model.
278+
num_anchors_per_location: Optional number of anchors per pixel location for
279+
building the RetinaNetHead. If an `int`, the same number is used for all
280+
levels. If a `dict`, it specifies the number at each level. If `none`, it
281+
uses `len(aspect_ratios) * num_scales` from the anchor config by default.
282+
anchor_boxes: Optional fixed multilevel anchor boxes for inference.
283+
284+
Returns:
285+
RetinaNet model.
286+
"""
268287
norm_activation_config = model_config.norm_activation
269288
if not backbone:
270289
backbone = backbones.factory.build_backbone(
@@ -282,7 +301,7 @@ def build_retinanet(
282301

283302
head_config = model_config.head
284303
generator_config = model_config.detection_generator
285-
num_anchors_per_location = (
304+
num_anchors_per_location = num_anchors_per_location or (
286305
len(model_config.anchor.aspect_ratios) * model_config.anchor.num_scales)
287306

288307
head = dense_prediction_heads.RetinaNetHead(
@@ -333,16 +352,26 @@ def build_retinanet(
333352
box_coder_weights=generator_config.box_coder_weights,
334353
)
335354

355+
num_scales = None
356+
aspect_ratios = None
357+
anchor_size = None
358+
if anchor_boxes is None:
359+
num_scales = model_config.anchor.num_scales
360+
aspect_ratios = model_config.anchor.aspect_ratios
361+
anchor_size = model_config.anchor.anchor_size
362+
336363
model = retinanet_model.RetinaNetModel(
337364
backbone,
338365
decoder,
339366
head,
340367
detection_generator_obj,
368+
anchor_boxes=anchor_boxes,
341369
min_level=model_config.min_level,
342370
max_level=model_config.max_level,
343-
num_scales=model_config.anchor.num_scales,
344-
aspect_ratios=model_config.anchor.aspect_ratios,
345-
anchor_size=model_config.anchor.anchor_size)
371+
num_scales=num_scales,
372+
aspect_ratios=aspect_ratios,
373+
anchor_size=anchor_size,
374+
)
346375
return model
347376

348377

official/vision/modeling/factory_test.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,15 @@
1313
# limitations under the License.
1414

1515
"""Tests for factory.py."""
16+
import collections
1617

1718
# Import libraries
1819
from absl.testing import parameterized
1920
import tensorflow as tf, tf_keras
2021

2122
from official.vision.configs import backbones
2223
from official.vision.configs import backbones_3d
24+
from official.vision.configs import decoders
2325
from official.vision.configs import image_classification as classification_cfg
2426
from official.vision.configs import maskrcnn as maskrcnn_cfg
2527
from official.vision.configs import retinanet as retinanet_cfg
@@ -123,6 +125,47 @@ def test_builder(self, backbone_type, input_size, has_att_heads):
123125
),
124126
)
125127

128+
def test_build_model_with_custom_anchors_can_run(self):
129+
image_size = (16, 16)
130+
input_specs = tf_keras.layers.InputSpec(shape=[None, *image_size, 3])
131+
model_config = retinanet_cfg.RetinaNet(
132+
num_classes=5,
133+
min_level=3,
134+
max_level=4,
135+
decoder=decoders.Decoder(type='identity'),
136+
head=retinanet_cfg.RetinaNetHead(
137+
num_convs=0, share_level_convs=False,
138+
)
139+
)
140+
anchor_boxes = collections.OrderedDict()
141+
anchor_boxes['3'] = tf.constant(
142+
[
143+
[[3, 4, 5, 6], [3, 4, 5, 6]],
144+
[[3, 4, 5, 6], [3, 4, 5, 6]],
145+
],
146+
dtype=tf.float32,
147+
)
148+
anchor_boxes['4'] = tf.constant(
149+
[[[3, 4, 5, 6, 3, 4, 5, 6]]], dtype=tf.float32
150+
)
151+
model = factory.build_retinanet(
152+
input_specs=input_specs,
153+
model_config=model_config,
154+
anchor_boxes=anchor_boxes,
155+
num_anchors_per_location={'3': 1, '4': 2},
156+
)
157+
test_input = tf.zeros([2, *image_size, 3])
158+
outputs = model.call(test_input)
159+
self.assertIn('box_outputs', outputs)
160+
self.assertIn('3', outputs['box_outputs'])
161+
self.assertIn('4', outputs['box_outputs'])
162+
self.assertAllEqual(
163+
outputs['box_outputs']['3'].numpy().shape, [2, 2, 2, 4 * 1]
164+
)
165+
self.assertAllEqual(
166+
outputs['box_outputs']['4'].numpy().shape, [2, 1, 1, 4 * 2]
167+
)
168+
126169

127170
class VideoClassificationModelBuilderTest(parameterized.TestCase,
128171
tf.test.TestCase):

official/vision/modeling/heads/dense_prediction_heads.py

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def __init__(
3333
min_level: int,
3434
max_level: int,
3535
num_classes: int,
36-
num_anchors_per_location: int,
36+
num_anchors_per_location: int | dict[str, int],
3737
num_convs: int = 4,
3838
num_filters: int = 256,
3939
attribute_heads: Optional[List[Dict[str, Any]]] = None,
@@ -55,7 +55,9 @@ def __init__(
5555
min_level: An `int` number of minimum feature level.
5656
max_level: An `int` number of maximum feature level.
5757
num_classes: An `int` number of classes to predict.
58-
num_anchors_per_location: An `int` number of anchors per pixel location.
58+
num_anchors_per_location: Number of anchors per pixel location. If an
59+
`int`, the same number is used for all levels. If a `dict`, it specifies
60+
the number at each level.
5961
num_convs: An `int` number that represents the number of the intermediate
6062
conv layers before the prediction.
6163
num_filters: An `int` number that represents the number of filters of the
@@ -134,15 +136,21 @@ def __init__(
134136
}
135137

136138
self._classifier_kwargs = {
137-
'filters': (
138-
self._config_dict['num_classes']
139-
* self._config_dict['num_anchors_per_location']
140-
),
141139
'kernel_size': 3,
142140
'padding': 'same',
143141
'bias_initializer': tf.constant_initializer(-np.log((1 - 0.01) / 0.01)),
144142
'bias_regularizer': self._config_dict['bias_regularizer'],
145143
}
144+
if isinstance(self._config_dict['num_anchors_per_location'], dict):
145+
self._classifier_kwargs['filters'] = {
146+
level: v * self._config_dict['num_classes']
147+
for level, v in self._config_dict['num_anchors_per_location'].items()
148+
}
149+
else:
150+
self._classifier_kwargs['filters'] = (
151+
self._config_dict['num_classes']
152+
* self._config_dict['num_anchors_per_location']
153+
)
146154
if self._config_dict['use_separable_conv']:
147155
self._classifier_kwargs.update({
148156
'depthwise_initializer': tf_keras.initializers.RandomNormal(
@@ -161,15 +169,21 @@ def __init__(
161169
})
162170

163171
self._box_regressor_kwargs = {
164-
'filters': (
165-
self._config_dict['num_params_per_anchor']
166-
* self._config_dict['num_anchors_per_location']
167-
),
168172
'kernel_size': 3,
169173
'padding': 'same',
170174
'bias_initializer': tf.zeros_initializer(),
171175
'bias_regularizer': self._config_dict['bias_regularizer'],
172176
}
177+
if isinstance(self._config_dict['num_anchors_per_location'], dict):
178+
self._box_regressor_kwargs['filters'] = {
179+
level: v * self._config_dict['num_params_per_anchor']
180+
for level, v in self._config_dict['num_anchors_per_location'].items()
181+
}
182+
else:
183+
self._box_regressor_kwargs['filters'] = (
184+
self._config_dict['num_params_per_anchor']
185+
* self._config_dict['num_anchors_per_location']
186+
)
173187
if self._config_dict['use_separable_conv']:
174188
self._box_regressor_kwargs.update({
175189
'depthwise_initializer': tf_keras.initializers.RandomNormal(
@@ -341,9 +355,16 @@ def _build_prediction_tower(
341355
for level in range(
342356
self._config_dict['min_level'], self._config_dict['max_level'] + 1
343357
):
344-
predictor_kwargs = self._conv_kwargs_new_kernel_init(predictor_kwargs)
358+
predictor_kwargs_level = predictor_kwargs.copy()
359+
if isinstance(predictor_kwargs_level['filters'], dict):
360+
predictor_kwargs_level['filters'] = predictor_kwargs_level['filters'][
361+
str(level)
362+
]
363+
predictor_kwargs_level = self._conv_kwargs_new_kernel_init(
364+
predictor_kwargs_level
365+
)
345366
predictors.append(
346-
conv_op(name=f'{predictor_name}-{level}', **predictor_kwargs)
367+
conv_op(name=f'{predictor_name}-{level}', **predictor_kwargs_level)
347368
)
348369

349370
return convs, norms, predictors

official/vision/modeling/heads/dense_prediction_heads_test.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,28 @@ def test_forward_shared_prediction_tower_with_share_classification_heads(
181181
}
182182
retinanet_head(features)
183183

184+
def test_forward_with_num_anchors_per_location_by_level(self):
185+
bs = 2
186+
retinanet_head = dense_prediction_heads.RetinaNetHead(
187+
min_level=3,
188+
max_level=4,
189+
num_classes=7,
190+
num_anchors_per_location={'3': 2, '4': 5},
191+
num_convs=0,
192+
num_filters=123,
193+
attribute_heads=None,
194+
share_level_convs=False,
195+
)
196+
features = {
197+
'3': np.random.rand(bs, 32, 32, 11),
198+
'4': np.random.rand(bs, 16, 16, 13),
199+
}
200+
scores, boxes, _ = retinanet_head(features)
201+
self.assertAllEqual(scores['3'].numpy().shape, [bs, 32, 32, 2 * 7])
202+
self.assertAllEqual(boxes['3'].numpy().shape, [bs, 32, 32, 2 * 4])
203+
self.assertAllEqual(scores['4'].numpy().shape, [bs, 16, 16, 5 * 7])
204+
self.assertAllEqual(boxes['4'].numpy().shape, [bs, 16, 16, 5 * 4])
205+
184206
def test_serialize_deserialize(self):
185207
retinanet_head = dense_prediction_heads.RetinaNetHead(
186208
min_level=3,

official/vision/modeling/retinanet_model.py

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
"""RetinaNet."""
16+
import collections
1617
from typing import Any, Mapping, List, Optional, Union, Sequence
1718

1819
# Import libraries
@@ -30,6 +31,7 @@ def __init__(self,
3031
decoder: tf_keras.Model,
3132
head: tf_keras.layers.Layer,
3233
detection_generator: tf_keras.layers.Layer,
34+
anchor_boxes: Mapping[str, tf.Tensor] | None = None,
3335
min_level: Optional[int] = None,
3436
max_level: Optional[int] = None,
3537
num_scales: Optional[int] = None,
@@ -43,6 +45,12 @@ def __init__(self,
4345
decoder: `tf_keras.Model` a decoder network.
4446
head: `RetinaNetHead`, the RetinaNet head.
4547
detection_generator: the detection generator.
48+
anchor_boxes: a dict of tensors which includes multilevel anchors.
49+
- key: `str`, the level of the multilevel predictions.
50+
- values: `Tensor`, the anchor coordinates of a particular feature
51+
level, whose shape is [height_l, width_l, 4 *
52+
num_anchors_per_location_l].
53+
If provided, these anchors will be used for inference (training=False).
4654
min_level: Minimum level in output feature maps.
4755
max_level: Maximum level in output feature maps.
4856
num_scales: A number representing intermediate scales added
@@ -72,11 +80,12 @@ def __init__(self,
7280
self._decoder = decoder
7381
self._head = head
7482
self._detection_generator = detection_generator
83+
self._anchor_boxes = anchor_boxes
7584

7685
def call(self,
7786
images: Union[tf.Tensor, Sequence[tf.Tensor]],
7887
image_shape: Optional[tf.Tensor] = None,
79-
anchor_boxes: Optional[Mapping[str, tf.Tensor]] = None,
88+
anchor_boxes: Mapping[str, tf.Tensor] | None = None,
8089
output_intermediate_features: bool = False,
8190
training: bool = None) -> Mapping[str, tf.Tensor]:
8291
"""Forward pass of the RetinaNet model.
@@ -91,10 +100,8 @@ def call(self,
91100
this is the actual image shape excluding paddings. For example, images
92101
in the batch may be resized into different shapes before padding to the
93102
fixed size.
94-
anchor_boxes: a dict of tensors which includes multilevel anchors.
95-
- key: `str`, the level of the multilevel predictions.
96-
- values: `Tensor`, the anchor coordinates of a particular feature
97-
level, whose shape is [height_l, width_l, num_anchors_per_location].
103+
anchor_boxes: the anchor boxes to use for inference (training=False) if
104+
not provided in the init.
98105
output_intermediate_features: `bool` indicating whether to return the
99106
intermediate feature maps generated by backbone and decoder.
100107
training: `bool`, indicating whether it is in training mode.
@@ -131,18 +138,23 @@ def call(self,
131138

132139
# Dense prediction. `raw_attributes` can be empty.
133140
raw_scores, raw_boxes, raw_attributes = self.head(features)
141+
outputs.update({
142+
'cls_outputs': raw_scores,
143+
'box_outputs': raw_boxes,
144+
})
134145

135146
if training:
136-
outputs.update({
137-
'cls_outputs': raw_scores,
138-
'box_outputs': raw_boxes,
139-
})
140147
if raw_attributes:
141148
outputs.update({'attribute_outputs': raw_attributes})
142149
return outputs
143150
else:
144-
# Generate anchor boxes for this batch if not provided.
145-
if anchor_boxes is None:
151+
if self._anchor_boxes is not None:
152+
batch_size = tf.shape(raw_boxes[str(self._config_dict['min_level'])])[0]
153+
anchor_boxes = collections.OrderedDict()
154+
for level, boxes in self._anchor_boxes.items():
155+
anchor_boxes[level] = tf.tile(boxes[None, ...], [batch_size, 1, 1, 1])
156+
elif anchor_boxes is None:
157+
# Generate anchor boxes for this batch if not provided.
146158
if isinstance(images, Sequence):
147159
primary_images = images[0]
148160
elif isinstance(images, tf.Tensor):
@@ -169,10 +181,6 @@ def call(self,
169181
final_results = self.detection_generator(raw_boxes, raw_scores,
170182
anchor_boxes, image_shape,
171183
raw_attributes)
172-
outputs.update({
173-
'cls_outputs': raw_scores,
174-
'box_outputs': raw_boxes,
175-
})
176184

177185
def _update_decoded_results():
178186
outputs.update({

0 commit comments

Comments
 (0)