Skip to content

Commit f4f6971

Browse files
feat(multi_object_tracker): add object class filtering in tracking process (autowarefoundation#6607)
* feat: object class filter Signed-off-by: Taekjin LEE <taekjin.lee@tier4.jp> * fix: set a member private Signed-off-by: Taekjin LEE <taekjin.lee@tier4.jp> * fix: last filtered label is not useful, remove Signed-off-by: Taekjin LEE <taekjin.lee@tier4.jp> * style(pre-commit): autofix Signed-off-by: Taekjin LEE <taekjin.lee@tier4.jp> * fix: multiply gain for new class Signed-off-by: Taekjin LEE <taekjin.lee@tier4.jp> * style(pre-commit): autofix Signed-off-by: Taekjin LEE <taekjin.lee@tier4.jp> * chore: algorithm explanation Signed-off-by: Taekjin LEE <taekjin.lee@tier4.jp> * fix: revise the filtering process flow Signed-off-by: Taekjin LEE <taekjin.lee@tier4.jp> --------- Signed-off-by: Taekjin LEE <taekjin.lee@tier4.jp> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 86aaa08 commit f4f6971

File tree

4 files changed

+65
-2
lines changed

4 files changed

+65
-2
lines changed

perception/multi_object_tracker/include/multi_object_tracker/tracker/model/tracker_base.hpp

+2
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ class Tracker
4242
{
4343
classification_ = classification;
4444
}
45+
void updateClassification(
46+
const std::vector<autoware_auto_perception_msgs::msg::ObjectClassification> & classification);
4547

4648
private:
4749
unique_identifier_msgs::msg::UUID uuid_;

perception/multi_object_tracker/src/tracker/model/multiple_vehicle_tracker.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ bool MultipleVehicleTracker::measure(
4343
big_vehicle_tracker_.measure(object, time, self_transform);
4444
normal_vehicle_tracker_.measure(object, time, self_transform);
4545
if (object_recognition_utils::getHighestProbLabel(object.classification) != Label::UNKNOWN)
46-
setClassification(object.classification);
46+
updateClassification(object.classification);
4747
return true;
4848
}
4949

perception/multi_object_tracker/src/tracker/model/pedestrian_and_bicycle_tracker.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ bool PedestrianAndBicycleTracker::measure(
4343
pedestrian_tracker_.measure(object, time, self_transform);
4444
bicycle_tracker_.measure(object, time, self_transform);
4545
if (object_recognition_utils::getHighestProbLabel(object.classification) != Label::UNKNOWN)
46-
setClassification(object.classification);
46+
updateClassification(object.classification);
4747
return true;
4848
}
4949

perception/multi_object_tracker/src/tracker/model/tracker_base.cpp

+61
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,67 @@ bool Tracker::updateWithoutMeasurement()
5454
return true;
5555
}
5656

57+
void Tracker::updateClassification(
58+
const std::vector<autoware_auto_perception_msgs::msg::ObjectClassification> & classification)
59+
{
60+
// classification algorithm:
61+
// 0. Normalize the input classification
62+
// 1-1. Update the matched classification probability with a gain (ratio of 0.05)
63+
// 1-2. If the label is not found, add it to the classification list
64+
// 2. Remove the class with probability < remove_threshold (0.001)
65+
// 3. Normalize tracking classification
66+
67+
// Parameters
68+
// if the remove_threshold is too high (compare to the gain), the classification will be removed
69+
// immediately
70+
const double gain = 0.05;
71+
constexpr double remove_threshold = 0.001;
72+
73+
// Normalization function
74+
auto normalizeProbabilities =
75+
[](std::vector<autoware_auto_perception_msgs::msg::ObjectClassification> & classification) {
76+
double sum = 0.0;
77+
for (const auto & class_ : classification) {
78+
sum += class_.probability;
79+
}
80+
for (auto & class_ : classification) {
81+
class_.probability /= sum;
82+
}
83+
};
84+
85+
// Normalize the input
86+
auto classification_input = classification;
87+
normalizeProbabilities(classification_input);
88+
89+
// Update the matched classification probability with a gain
90+
for (const auto & new_class : classification_input) {
91+
bool found = false;
92+
for (auto & old_class : classification_) {
93+
if (new_class.label == old_class.label) {
94+
old_class.probability += new_class.probability * gain;
95+
found = true;
96+
break;
97+
}
98+
}
99+
// If the label is not found, add it to the classification list
100+
if (!found) {
101+
auto adding_class = new_class;
102+
adding_class.probability *= gain;
103+
classification_.push_back(adding_class);
104+
}
105+
}
106+
107+
// If the probability is less than the threshold, remove the class
108+
classification_.erase(
109+
std::remove_if(
110+
classification_.begin(), classification_.end(),
111+
[remove_threshold](const auto & class_) { return class_.probability < remove_threshold; }),
112+
classification_.end());
113+
114+
// Normalize tracking classification
115+
normalizeProbabilities(classification_);
116+
}
117+
57118
geometry_msgs::msg::PoseWithCovariance Tracker::getPoseWithCovariance(
58119
const rclcpp::Time & time) const
59120
{

0 commit comments

Comments
 (0)