Skip to content

Commit 302e71a

Browse files
committedMar 19, 2024
Fix inconsistent anchor boxes generation between train and inference
* Currently, `AnchorGenerator` and `Anchor` can actually generate different anchor boxes when input image size is not divisible by 2^max_level. This results in inconsistent training and inference box predictions. This change fix the inconsistency by calling only the Anchor class. * Fix the case when input image is not square. * Refactor the Anchor class to only generate and store `multilevel_boxes` since the flatten boxes are never used. PiperOrigin-RevId: 616997347
1 parent 505c718 commit 302e71a

File tree

4 files changed

+169
-105
lines changed

4 files changed

+169
-105
lines changed
 

‎official/vision/ops/anchor.py

Lines changed: 110 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222

2323
import tensorflow as tf, tf_keras
2424

25-
from official.vision.ops import anchor_generator
2625
from official.vision.ops import box_matcher
2726
from official.vision.ops import iou_similarity
2827
from official.vision.ops import target_gather
@@ -32,7 +31,38 @@
3231

3332

3433
class Anchor(object):
35-
"""Anchor class for anchor-based object detectors."""
34+
"""Anchor class for anchor-based object detectors.
35+
36+
Example:
37+
```python
38+
anchor_boxes = Anchor(
39+
min_level=3,
40+
max_level=4,
41+
num_scales=2,
42+
aspect_ratios=[0.5, 1., 2.],
43+
anchor_size=4.,
44+
image_size=[256, 256],
45+
).multilevel_boxes
46+
```
47+
48+
Attributes:
49+
min_level: integer number of minimum level of the output feature pyramid.
50+
max_level: integer number of maximum level of the output feature pyramid.
51+
num_scales: integer number representing intermediate scales added on each
52+
level. For instances, num_scales=2 adds one additional intermediate
53+
anchor scales [2^0, 2^0.5] on each level.
54+
aspect_ratios: list of float numbers representing the aspect ratio anchors
55+
added on each level. The number indicates the ratio of width to height.
56+
For instances, aspect_ratios=[1.0, 2.0, 0.5] adds three anchors on each
57+
scale level.
58+
anchor_size: float number representing the scale of size of the base
59+
anchor to the feature stride 2^level.
60+
image_size: a list of integer numbers or Tensors representing [height,
61+
width] of the input image size.
62+
multilevel_boxes: an OrderedDict from level to the generated anchor boxes of
63+
shape [height_l, width_l, num_anchors_per_location * 4].
64+
anchors_per_location: number of anchors per pixel location.
65+
"""
3666

3767
def __init__(
3868
self,
@@ -43,57 +73,40 @@ def __init__(
4373
anchor_size,
4474
image_size,
4575
):
46-
"""Constructs multi-scale anchors.
47-
48-
Args:
49-
min_level: integer number of minimum level of the output feature pyramid.
50-
max_level: integer number of maximum level of the output feature pyramid.
51-
num_scales: integer number representing intermediate scales added on each
52-
level. For instances, num_scales=2 adds one additional intermediate
53-
anchor scales [2^0, 2^0.5] on each level.
54-
aspect_ratios: list of float numbers representing the aspect ratio anchors
55-
added on each level. The number indicates the ratio of width to height.
56-
For instances, aspect_ratios=[1.0, 2.0, 0.5] adds three anchors on each
57-
scale level.
58-
anchor_size: float number representing the scale of size of the base
59-
anchor to the feature stride 2^level.
60-
image_size: a list of integer numbers or Tensors representing [height,
61-
width] of the input image size.The image_size should be divided by the
62-
largest feature stride 2^max_level.
63-
"""
76+
"""Initializes the instance."""
6477
self.min_level = min_level
6578
self.max_level = max_level
6679
self.num_scales = num_scales
6780
self.aspect_ratios = aspect_ratios
6881
self.anchor_size = anchor_size
6982
self.image_size = image_size
70-
self.boxes = self._generate_boxes()
83+
self.multilevel_boxes = self._generate_multilevel_boxes()
7184

72-
def _generate_boxes(self) -> tf.Tensor:
85+
def _generate_multilevel_boxes(self) -> Dict[str, tf.Tensor]:
7386
"""Generates multi-scale anchor boxes.
7487

7588
Returns:
76-
a Tensor of shape [N, 4], representing anchor boxes of all levels
77-
concatenated together.
89+
An OrderedDict from level to anchor boxes of shape [height_l, width_l,
90+
num_anchors_per_location * 4].
7891
"""
79-
boxes_all = []
92+
multilevel_boxes = collections.OrderedDict()
8093
for level in range(self.min_level, self.max_level + 1):
8194
boxes_l = []
82-
feat_size = math.ceil(self.image_size[0] / 2**level)
83-
stride = tf.cast(self.image_size[0] / feat_size, tf.float32)
95+
feat_size_y = math.ceil(self.image_size[0] / 2**level)
96+
feat_size_x = math.ceil(self.image_size[1] / 2**level)
97+
stride_y = tf.cast(self.image_size[0] / feat_size_y, tf.float32)
98+
stride_x = tf.cast(self.image_size[1] / feat_size_x, tf.float32)
99+
x = tf.range(stride_x / 2, self.image_size[1], stride_x)
100+
y = tf.range(stride_y / 2, self.image_size[0], stride_y)
101+
xv, yv = tf.meshgrid(x, y)
84102
for scale in range(self.num_scales):
85103
for aspect_ratio in self.aspect_ratios:
86-
intermidate_scale = 2 ** (scale / float(self.num_scales))
87-
base_anchor_size = self.anchor_size * stride * intermidate_scale
104+
intermidate_scale = 2 ** (scale / self.num_scales)
105+
base_anchor_size = self.anchor_size * 2**level * intermidate_scale
88106
aspect_x = aspect_ratio**0.5
89107
aspect_y = aspect_ratio**-0.5
90108
half_anchor_size_x = base_anchor_size * aspect_x / 2.0
91109
half_anchor_size_y = base_anchor_size * aspect_y / 2.0
92-
x = tf.range(stride / 2, self.image_size[1], stride)
93-
y = tf.range(stride / 2, self.image_size[0], stride)
94-
xv, yv = tf.meshgrid(x, y)
95-
xv = tf.cast(tf.reshape(xv, [-1]), dtype=tf.float32)
96-
yv = tf.cast(tf.reshape(yv, [-1]), dtype=tf.float32)
97110
# Tensor shape Nx4.
98111
boxes = tf.stack(
99112
[
@@ -102,41 +115,18 @@ def _generate_boxes(self) -> tf.Tensor:
102115
yv + half_anchor_size_y,
103116
xv + half_anchor_size_x,
104117
],
105-
axis=1,
118+
axis=-1,
106119
)
107120
boxes_l.append(boxes)
108-
# Concat anchors on the same level to tensor shape NxAx4.
109-
boxes_l = tf.stack(boxes_l, axis=1)
110-
boxes_l = tf.reshape(boxes_l, [-1, 4])
111-
boxes_all.append(boxes_l)
112-
return tf.concat(boxes_all, axis=0)
113-
114-
def unpack_labels(self, labels: tf.Tensor) -> Dict[str, tf.Tensor]:
115-
"""Unpacks an array of labels into multi-scales labels."""
116-
unpacked_labels = collections.OrderedDict()
117-
count = 0
118-
for level in range(self.min_level, self.max_level + 1):
119-
feat_size_y = tf.cast(
120-
math.ceil(self.image_size[0] / 2**level), tf.int32
121-
)
122-
feat_size_x = tf.cast(
123-
math.ceil(self.image_size[1] / 2**level), tf.int32
124-
)
125-
steps = feat_size_y * feat_size_x * self.anchors_per_location
126-
unpacked_labels[str(level)] = tf.reshape(
127-
labels[count : count + steps], [feat_size_y, feat_size_x, -1]
128-
)
129-
count += steps
130-
return unpacked_labels
121+
# Concat anchors on the same level to tensor shape HxWx(Ax4).
122+
boxes_l = tf.concat(boxes_l, axis=-1)
123+
multilevel_boxes[str(level)] = boxes_l
124+
return multilevel_boxes
131125

132126
@property
133-
def anchors_per_location(self):
127+
def anchors_per_location(self) -> int:
134128
return self.num_scales * len(self.aspect_ratios)
135129

136-
@property
137-
def multilevel_boxes(self):
138-
return self.unpack_labels(self.boxes)
139-
140130

141131
class AnchorLabeler(object):
142132
"""Labeler for dense object detector."""
@@ -420,24 +410,68 @@ def label_anchors( # pytype: disable=signature-mismatch # overriding-parameter
420410
return score_targets_dict, box_targets_dict
421411

422412

413+
class AnchorGeneratorv2:
414+
"""Utility to generate anchors for a multiple feature maps.
415+
416+
Attributes:
417+
min_level: integer number of minimum level of the output feature pyramid.
418+
max_level: integer number of maximum level of the output feature pyramid.
419+
num_scales: integer number representing intermediate scales added on each
420+
level. For instances, num_scales=2 adds one additional intermediate
421+
anchor scales [2^0, 2^0.5] on each level.
422+
aspect_ratios: list of float numbers representing the aspect ratio anchors
423+
added on each level. The number indicates the ratio of width to height.
424+
For instances, aspect_ratios=[1.0, 2.0, 0.5] adds three anchors on each
425+
scale level.
426+
anchor_size: float number representing the scale of size of the base
427+
anchor to the feature stride 2^level.
428+
"""
429+
430+
def __init__(
431+
self,
432+
min_level,
433+
max_level,
434+
num_scales,
435+
aspect_ratios,
436+
anchor_size,
437+
):
438+
"""Initializes the instance."""
439+
self.min_level = min_level
440+
self.max_level = max_level
441+
self.num_scales = num_scales
442+
self.aspect_ratios = aspect_ratios
443+
self.anchor_size = anchor_size
444+
445+
def __call__(self, image_size):
446+
"""Generate multilevel anchor boxes.
447+
448+
Args:
449+
image_size: a list of integer numbers or Tensors representing [height,
450+
width] of the input image size.
451+
Returns:
452+
An ordered dictionary from level to anchor boxes of shape [height_l,
453+
width_l, num_anchors_per_location * 4].
454+
"""
455+
return Anchor(
456+
min_level=self.min_level,
457+
max_level=self.max_level,
458+
num_scales=self.num_scales,
459+
aspect_ratios=self.aspect_ratios,
460+
anchor_size=self.anchor_size,
461+
image_size=image_size,
462+
).multilevel_boxes
463+
464+
423465
def build_anchor_generator(
424466
min_level, max_level, num_scales, aspect_ratios, anchor_size
425467
):
426468
"""Build anchor generator from levels."""
427-
anchor_sizes = collections.OrderedDict()
428-
strides = collections.OrderedDict()
429-
scales = []
430-
for scale in range(num_scales):
431-
scales.append(2 ** (scale / float(num_scales)))
432-
for level in range(min_level, max_level + 1):
433-
stride = 2**level
434-
strides[str(level)] = stride
435-
anchor_sizes[str(level)] = anchor_size * stride
436-
anchor_gen = anchor_generator.AnchorGenerator(
437-
anchor_sizes=anchor_sizes,
438-
scales=scales,
469+
anchor_gen = AnchorGeneratorv2(
470+
min_level=min_level,
471+
max_level=max_level,
472+
num_scales=num_scales,
439473
aspect_ratios=aspect_ratios,
440-
strides=strides,
474+
anchor_size=anchor_size,
441475
)
442476
return anchor_gen
443477

‎official/vision/ops/anchor_generator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def __call__(self, image_size):
109109
return tf.reshape(result, [shape[0], shape[1], shape[2] * shape[3]])
110110

111111

112-
class AnchorGenerator():
112+
class AnchorGeneratorv1():
113113
"""Utility to generate anchors for a multiple feature maps.
114114

115115
Example:

‎official/vision/ops/anchor_generator_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def testAnchorGeneration(self, min_level, max_level, aspect_ratios,
7777
levels = range(min_level, max_level + 1)
7878
anchor_sizes = [2**(level + 1) for level in levels]
7979
strides = [2**level for level in levels]
80-
anchor_gen = anchor_generator.AnchorGenerator(
80+
anchor_gen = anchor_generator.AnchorGeneratorv1(
8181
anchor_sizes=anchor_sizes,
8282
scales=[1.],
8383
aspect_ratios=aspect_ratios,
@@ -98,7 +98,7 @@ def testAnchorGenerationClipped(self, min_level, max_level, aspect_ratios,
9898
levels = range(min_level, max_level + 1)
9999
anchor_sizes = [2**(level + 1) for level in levels]
100100
strides = [2**level for level in levels]
101-
anchor_gen = anchor_generator.AnchorGenerator(
101+
anchor_gen = anchor_generator.AnchorGeneratorv1(
102102
anchor_sizes=anchor_sizes,
103103
scales=[1.],
104104
aspect_ratios=aspect_ratios,
@@ -122,7 +122,7 @@ def testAnchorGenerationDict(self, min_level, max_level, aspect_ratios,
122122
levels = range(min_level, max_level + 1)
123123
anchor_sizes = dict((str(level), 2**(level + 1)) for level in levels)
124124
strides = dict((str(level), 2**level) for level in levels)
125-
anchor_gen = anchor_generator.AnchorGenerator(
125+
anchor_gen = anchor_generator.AnchorGeneratorv1(
126126
anchor_sizes=anchor_sizes,
127127
scales=[1.],
128128
aspect_ratios=aspect_ratios,

‎official/vision/ops/anchor_test.py

Lines changed: 55 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -58,40 +58,52 @@ def testAnchorRpnSample(self, num_anchors, num_positives,
5858
self.assertEqual(negatives, expected_negatives)
5959

6060
@parameterized.parameters(
61-
# Single scale anchor.
62-
(5, 5, 1, [1.0], 2.0,
63-
[[-16, -16, 48, 48], [-16, 16, 48, 80],
64-
[16, -16, 80, 48], [16, 16, 80, 80]]),
65-
# Multi scale anchor.
66-
(5, 6, 1, [1.0], 2.0,
67-
[[-16, -16, 48, 48], [-16, 16, 48, 80],
68-
[16, -16, 80, 48], [16, 16, 80, 80], [-32, -32, 96, 96]]),
69-
# # Multi aspect ratio anchor.
70-
(6, 6, 1, [1.0, 4.0, 0.25], 2.0,
71-
[[-32, -32, 96, 96], [-0, -96, 64, 160], [-96, -0, 160, 64]]),
72-
61+
# Single scale anchor
62+
(5, 5, 1, [1.0], 2.0, [64, 64],
63+
{'5': [[[-16, -16, 48, 48], [-16, 16, 48, 80]],
64+
[[16, -16, 80, 48], [16, 16, 80, 80]]]}),
65+
# Multi scale anchor
66+
(5, 6, 1, [1.0], 2.0, [64, 64],
67+
{'5': [[[-16, -16, 48, 48], [-16, 16, 48, 80]],
68+
[[16, -16, 80, 48], [16, 16, 80, 80]]],
69+
'6': [[[-32, -32, 96, 96]]]}),
70+
# Multi aspect ratio anchor
71+
(6, 6, 1, [1.0, 4.0, 0.25], 2.0, [64, 64],
72+
{'6': [[[-32, -32, 96, 96, -0, -96, 64, 160, -96, -0, 160, 64]]]}),
73+
# Intermidate scales
74+
(5, 5, 2, [1.0], 1.0, [32, 32],
75+
{'5': [[[0, 0, 32, 32,
76+
16 - 16 * 2**0.5, 16 - 16 * 2**0.5,
77+
16 + 16 * 2**0.5, 16 + 16 * 2**0.5]]]}),
78+
# Non-square
79+
(5, 5, 1, [1.0], 1.0, [64, 32],
80+
{'5': [[[0, 0, 32, 32]],
81+
[[32, 0, 64, 32]]]}),
82+
# Indivisible by 2^level
83+
(5, 5, 1, [1.0], 1.0, [40, 32],
84+
{'5': [[[-6, 0, 26, 32]],
85+
[[14, 0, 46, 32]]]}),
7386
)
7487
def testAnchorGeneration(self, min_level, max_level, num_scales,
75-
aspect_ratios, anchor_size, expected_boxes):
76-
image_size = [64, 64]
88+
aspect_ratios, anchor_size, image_size,
89+
expected_boxes):
7790
anchors = anchor.Anchor(min_level, max_level, num_scales, aspect_ratios,
7891
anchor_size, image_size)
79-
boxes = anchors.boxes.numpy()
80-
self.assertEqual(expected_boxes, boxes.tolist())
92+
self.assertAllClose(expected_boxes, anchors.multilevel_boxes)
8193

8294
@parameterized.parameters(
8395
# Single scale anchor.
8496
(5, 5, 1, [1.0], 2.0,
85-
[[-16, -16, 48, 48], [-16, 16, 48, 80],
86-
[16, -16, 80, 48], [16, 16, 80, 80]]),
97+
{'5': [[[-16, -16, 48, 48], [-16, 16, 48, 80]],
98+
[[16, -16, 80, 48], [16, 16, 80, 80]]]}),
8799
# Multi scale anchor.
88100
(5, 6, 1, [1.0], 2.0,
89-
[[-16, -16, 48, 48], [-16, 16, 48, 80],
90-
[16, -16, 80, 48], [16, 16, 80, 80], [-32, -32, 96, 96]]),
91-
# # Multi aspect ratio anchor.
101+
{'5': [[[-16, -16, 48, 48], [-16, 16, 48, 80]],
102+
[[16, -16, 80, 48], [16, 16, 80, 80]]],
103+
'6': [[[-32, -32, 96, 96]]]}),
104+
# Multi aspect ratio anchor.
92105
(6, 6, 1, [1.0, 4.0, 0.25], 2.0,
93-
[[-32, -32, 96, 96], [-0, -96, 64, 160], [-96, -0, 160, 64]]),
94-
106+
{'6': [[[-32, -32, 96, 96, -0, -96, 64, 160, -96, -0, 160, 64]]]}),
95107
)
96108
def testAnchorGenerationWithImageSizeAsTensor(self,
97109
min_level,
@@ -103,8 +115,25 @@ def testAnchorGenerationWithImageSizeAsTensor(self,
103115
image_size = tf.constant([64, 64], tf.int32)
104116
anchors = anchor.Anchor(min_level, max_level, num_scales, aspect_ratios,
105117
anchor_size, image_size)
106-
boxes = anchors.boxes.numpy()
107-
self.assertEqual(expected_boxes, boxes.tolist())
118+
self.assertAllClose(expected_boxes, anchors.multilevel_boxes)
119+
120+
@parameterized.parameters(
121+
(6, 8, 2, [1.0, 2.0, 0.5], 3.0, [320, 256]),
122+
)
123+
def testAnchorGenerationAreCentered(self, min_level, max_level, num_scales,
124+
aspect_ratios, anchor_size, image_size):
125+
anchors = anchor.Anchor(min_level, max_level, num_scales, aspect_ratios,
126+
anchor_size, image_size)
127+
multilevel_boxes = anchors.multilevel_boxes
128+
image_size = np.array(image_size)
129+
for boxes in multilevel_boxes.values():
130+
boxes = boxes.numpy()
131+
box_centers = boxes.mean(axis=0).mean(axis=0)
132+
box_centers = [
133+
(box_centers[0] + box_centers[2]) / 2,
134+
(box_centers[1] + box_centers[3]) / 2,
135+
]
136+
self.assertAllClose(image_size / 2, box_centers)
108137

109138
@parameterized.parameters(
110139
(3, 6, 2, [1.0], 2.0, False),
@@ -164,6 +193,7 @@ def testLabelAnchors(self, min_level, max_level, num_scales, aspect_ratios,
164193
(3, 7, [.5, 1., 2.], 2, 8, (256, 256)),
165194
(3, 8, [1.], 3, 32, (512, 512)),
166195
(3, 3, [1.], 2, 4, (32, 32)),
196+
(4, 8, [.5, 1., 2.], 2, 3, (320, 256)),
167197
)
168198
def testEquivalentResult(self, min_level, max_level, aspect_ratios,
169199
num_scales, anchor_size, image_size):

0 commit comments

Comments
 (0)