Skip to content

Commit 4f97fc0

Browse files
1222-takeshiyhisakiyn-mrse
authored
fix(interpolation): fix spline bug (#1649)
* fix spline bug Signed-off-by: Y.Hisaki <yhisaki31@gmail.com> * revert: "feat: apply pose_instability_risk to localization_diag" (#1613) Revert "feat: apply pose_instability_risk to localization_diag (#1560)" This reverts commit 4bcb436. * revert: "feat(lane_departure_checker): add parameter of footprint_extra_margin" (#1614) Revert "feat(lane_departure_checker): add parameter of footprint_extra_margin…" This reverts commit 0245366. --------- Signed-off-by: Y.Hisaki <yhisaki31@gmail.com> Co-authored-by: Y.Hisaki <yhisaki31@gmail.com> Co-authored-by: Yuma Nihei <yuma.nihei@tier4.jp>
1 parent d01e8b3 commit 4f97fc0

File tree

4 files changed

+101
-145
lines changed

4 files changed

+101
-145
lines changed

common/interpolation/include/interpolation/spline_interpolation.hpp

+8-24
Original file line numberDiff line numberDiff line change
@@ -15,35 +15,13 @@
1515
#ifndef INTERPOLATION__SPLINE_INTERPOLATION_HPP_
1616
#define INTERPOLATION__SPLINE_INTERPOLATION_HPP_
1717

18-
#include "interpolation/interpolation_utils.hpp"
19-
#include "tier4_autoware_utils/geometry/geometry.hpp"
18+
#include <Eigen/Dense>
2019

21-
#include <algorithm>
2220
#include <cmath>
23-
#include <iostream>
24-
#include <numeric>
2521
#include <vector>
2622

2723
namespace interpolation
2824
{
29-
// NOTE: X(s) = a_i (s - s_i)^3 + b_i (s - s_i)^2 + c_i (s - s_i) + d_i : (i = 0, 1, ... N-1)
30-
struct MultiSplineCoef
31-
{
32-
MultiSplineCoef() = default;
33-
34-
explicit MultiSplineCoef(const size_t num_spline)
35-
{
36-
a.resize(num_spline);
37-
b.resize(num_spline);
38-
c.resize(num_spline);
39-
d.resize(num_spline);
40-
}
41-
42-
std::vector<double> a;
43-
std::vector<double> b;
44-
std::vector<double> c;
45-
std::vector<double> d;
46-
};
4725

4826
// static spline interpolation functions
4927
std::vector<double> slerp(
@@ -84,8 +62,14 @@ class SplineInterpolation
8462
std::vector<double> getSplineInterpolatedDiffValues(const std::vector<double> & query_keys) const;
8563

8664
private:
65+
Eigen::VectorXd a_;
66+
Eigen::VectorXd b_;
67+
Eigen::VectorXd c_;
68+
Eigen::VectorXd d_;
69+
8770
std::vector<double> base_keys_;
88-
interpolation::MultiSplineCoef multi_spline_coef_;
71+
72+
Eigen::Index get_index(double key) const;
8973
};
9074

9175
#endif // INTERPOLATION__SPLINE_INTERPOLATION_HPP_

common/interpolation/include/interpolation/spline_interpolation_points_2d.hpp

+3
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616
#define INTERPOLATION__SPLINE_INTERPOLATION_POINTS_2D_HPP_
1717

1818
#include "interpolation/spline_interpolation.hpp"
19+
#include "tier4_autoware_utils/geometry/geometry.hpp"
20+
21+
#include <geometry_msgs/msg/point.hpp>
1922

2023
#include <vector>
2124

common/interpolation/package.xml

+1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
<license>Apache License 2.0</license>
1010
<buildtool_depend>ament_cmake_auto</buildtool_depend>
1111

12+
<depend>eigen</depend>
1213
<depend>tier4_autoware_utils</depend>
1314

1415
<test_depend>ament_lint_auto</test_depend>

common/interpolation/src/spline_interpolation.cpp

+89-121
Original file line numberDiff line numberDiff line change
@@ -14,70 +14,43 @@
1414

1515
#include "interpolation/spline_interpolation.hpp"
1616

17-
#include <vector>
17+
#include "interpolation/interpolation_utils.hpp"
1818

19-
namespace
20-
{
21-
// solve Ax = d
22-
// where A is tridiagonal matrix
23-
// [b_0 c_0 ... ]
24-
// [a_0 b_1 c_1 ... O ]
25-
// A = [ ... ]
26-
// [ O ... a_N-3 b_N-2 c_N-2]
27-
// [ ... a_N-2 b_N-1]
28-
struct TDMACoef
19+
#include <algorithm>
20+
21+
Eigen::VectorXd solve_tridiagonal_matrix_algorithm(
22+
const Eigen::Ref<const Eigen::VectorXd> & a, const Eigen::Ref<const Eigen::VectorXd> & b,
23+
const Eigen::Ref<const Eigen::VectorXd> & c, const Eigen::Ref<const Eigen::VectorXd> & d)
2924
{
30-
explicit TDMACoef(const size_t num_row)
31-
{
32-
a.resize(num_row - 1);
33-
b.resize(num_row);
34-
c.resize(num_row - 1);
35-
d.resize(num_row);
25+
auto n = d.size();
26+
27+
if (n == 1) {
28+
return d.array() / b.array();
3629
}
3730

38-
std::vector<double> a;
39-
std::vector<double> b;
40-
std::vector<double> c;
41-
std::vector<double> d;
42-
};
31+
Eigen::VectorXd c_prime = Eigen::VectorXd::Zero(n);
32+
Eigen::VectorXd d_prime = Eigen::VectorXd::Zero(n);
33+
Eigen::VectorXd x = Eigen::VectorXd::Zero(n);
4334

44-
inline std::vector<double> solveTridiagonalMatrixAlgorithm(const TDMACoef & tdma_coef)
45-
{
46-
const auto & a = tdma_coef.a;
47-
const auto & b = tdma_coef.b;
48-
const auto & c = tdma_coef.c;
49-
const auto & d = tdma_coef.d;
50-
51-
const size_t num_row = b.size();
52-
53-
std::vector<double> x(num_row);
54-
if (num_row != 1) {
55-
// calculate p and q
56-
std::vector<double> p;
57-
std::vector<double> q;
58-
p.push_back(-c[0] / b[0]);
59-
q.push_back(d[0] / b[0]);
60-
61-
for (size_t i = 1; i < num_row; ++i) {
62-
const double den = b[i] + a[i - 1] * p[i - 1];
63-
p.push_back(-c[i - 1] / den);
64-
q.push_back((d[i] - a[i - 1] * q[i - 1]) / den);
65-
}
66-
67-
// calculate solution
68-
x[num_row - 1] = q[num_row - 1];
69-
70-
for (size_t i = 1; i < num_row; ++i) {
71-
const size_t j = num_row - 1 - i;
72-
x[j] = p[j] * x[j + 1] + q[j];
73-
}
74-
} else {
75-
x.push_back(d[0] / b[0]);
35+
// Forward sweep
36+
c_prime(0) = c(0) / b(0);
37+
d_prime(0) = d(0) / b(0);
38+
39+
for (auto i = 1; i < n; i++) {
40+
double m = 1.0 / (b(i) - a(i - 1) * c_prime(i - 1));
41+
c_prime(i) = i < n - 1 ? c(i) * m : 0;
42+
d_prime(i) = (d(i) - a(i - 1) * d_prime(i - 1)) * m;
43+
}
44+
45+
// Back substitution
46+
x(n - 1) = d_prime(n - 1);
47+
48+
for (auto i = n - 2; i >= 0; i--) {
49+
x(i) = d_prime(i) - c_prime(i) * x(i + 1);
7650
}
7751

7852
return x;
7953
}
80-
} // namespace
8154

8255
namespace interpolation
8356
{
@@ -101,73 +74,74 @@ void SplineInterpolation::calcSplineCoefficients(
10174
// throw exceptions for invalid arguments
10275
interpolation_utils::validateKeysAndValues(base_keys, base_values);
10376

104-
const size_t num_base = base_keys.size(); // N+1
105-
106-
std::vector<double> diff_keys; // N
107-
std::vector<double> diff_values; // N
108-
for (size_t i = 0; i < num_base - 1; ++i) {
109-
diff_keys.push_back(base_keys.at(i + 1) - base_keys.at(i));
110-
diff_values.push_back(base_values.at(i + 1) - base_values.at(i));
111-
}
112-
113-
std::vector<double> v = {0.0};
114-
if (num_base > 2) {
115-
// solve tridiagonal matrix algorithm
116-
TDMACoef tdma_coef(num_base - 2); // N-1
117-
118-
for (size_t i = 0; i < num_base - 2; ++i) {
119-
tdma_coef.b[i] = 2 * (diff_keys[i] + diff_keys[i + 1]);
120-
if (i != num_base - 3) {
121-
tdma_coef.a[i] = diff_keys[i + 1];
122-
tdma_coef.c[i] = diff_keys[i + 1];
123-
}
124-
tdma_coef.d[i] =
125-
6.0 * (diff_values[i + 1] / diff_keys[i + 1] - diff_values[i] / diff_keys[i]);
126-
}
127-
128-
const std::vector<double> tdma_res = solveTridiagonalMatrixAlgorithm(tdma_coef);
129-
130-
// calculate v
131-
v.insert(v.end(), tdma_res.begin(), tdma_res.end());
132-
}
133-
v.push_back(0.0);
134-
135-
// calculate a, b, c, d of spline coefficients
136-
multi_spline_coef_ = interpolation::MultiSplineCoef{num_base - 1}; // N
137-
for (size_t i = 0; i < num_base - 1; ++i) {
138-
multi_spline_coef_.a[i] = (v[i + 1] - v[i]) / 6.0 / diff_keys[i];
139-
multi_spline_coef_.b[i] = v[i] / 2.0;
140-
multi_spline_coef_.c[i] =
141-
diff_values[i] / diff_keys[i] - diff_keys[i] * (2 * v[i] + v[i + 1]) / 6.0;
142-
multi_spline_coef_.d[i] = base_values[i];
77+
Eigen::VectorXd x = Eigen::Map<const Eigen::VectorXd>(
78+
base_keys.data(), static_cast<Eigen::Index>(base_keys.size()));
79+
Eigen::VectorXd y = Eigen::Map<const Eigen::VectorXd>(
80+
base_values.data(), static_cast<Eigen::Index>(base_values.size()));
81+
82+
const auto n = x.size();
83+
84+
if (n == 2) {
85+
a_ = Eigen::VectorXd::Zero(1);
86+
b_ = Eigen::VectorXd::Zero(1);
87+
c_ = Eigen::VectorXd::Zero(1);
88+
d_ = Eigen::VectorXd::Zero(1);
89+
c_[0] = (y[1] - y[0]) / (x[1] - x[0]);
90+
d_[0] = y[0];
91+
base_keys_ = base_keys;
92+
return;
14393
}
14494

95+
// Create Tridiagonal matrix
96+
Eigen::VectorXd v(n);
97+
Eigen::VectorXd h = x.segment(1, n - 1) - x.segment(0, n - 1);
98+
Eigen::VectorXd a = h.segment(1, n - 3);
99+
Eigen::VectorXd b = 2 * (h.segment(0, n - 2) + h.segment(1, n - 2));
100+
Eigen::VectorXd c = h.segment(1, n - 3);
101+
Eigen::VectorXd y_diff = y.segment(1, n - 1) - y.segment(0, n - 1);
102+
Eigen::VectorXd d = 6 * (y_diff.segment(1, n - 2).array() / h.tail(n - 2).array() -
103+
y_diff.segment(0, n - 2).array() / h.head(n - 2).array());
104+
105+
// Solve tridiagonal matrix
106+
v.segment(1, n - 2) = solve_tridiagonal_matrix_algorithm(a, b, c, d);
107+
v[0] = 0;
108+
v[n - 1] = 0;
109+
110+
// Calculate spline coefficients
111+
a_ = (v.tail(n - 1) - v.head(n - 1)).array() / 6.0 / (x.tail(n - 1) - x.head(n - 1)).array();
112+
b_ = v.segment(0, n - 1) / 2.0;
113+
c_ = (y.tail(n - 1) - y.head(n - 1)).array() / (x.tail(n - 1) - x.head(n - 1)).array() -
114+
(x.tail(n - 1) - x.head(n - 1)).array() *
115+
(2 * v.segment(0, n - 1).array() + v.segment(1, n - 1).array()) / 6.0;
116+
d_ = y.head(n - 1);
145117
base_keys_ = base_keys;
146118
}
147119

120+
Eigen::Index SplineInterpolation::get_index(double key) const
121+
{
122+
auto it = std::lower_bound(base_keys_.begin(), base_keys_.end(), key);
123+
return std::clamp(
124+
static_cast<int>(std::distance(base_keys_.begin(), it)) - 1, 0,
125+
static_cast<int>(base_keys_.size()) - 2);
126+
}
127+
148128
std::vector<double> SplineInterpolation::getSplineInterpolatedValues(
149129
const std::vector<double> & query_keys) const
150130
{
151131
// throw exceptions for invalid arguments
152132
interpolation_utils::validateKeys(base_keys_, query_keys);
153133

154-
const auto & a = multi_spline_coef_.a;
155-
const auto & b = multi_spline_coef_.b;
156-
const auto & c = multi_spline_coef_.c;
157-
const auto & d = multi_spline_coef_.d;
134+
std::vector<double> interpolated_values;
135+
interpolated_values.reserve(query_keys.size());
158136

159-
std::vector<double> res;
160-
size_t j = 0;
161-
for (const auto & query_key : query_keys) {
162-
while (base_keys_.at(j + 1) < query_key) {
163-
++j;
164-
}
165-
166-
const double ds = query_key - base_keys_.at(j);
167-
res.push_back(d.at(j) + (c.at(j) + (b.at(j) + a.at(j) * ds) * ds) * ds);
137+
for (const auto & key : query_keys) {
138+
const auto idx = get_index(key);
139+
const auto dx = key - base_keys_[idx];
140+
interpolated_values.emplace_back(
141+
a_[idx] * dx * dx * dx + b_[idx] * dx * dx + c_[idx] * dx + d_[idx]);
168142
}
169143

170-
return res;
144+
return interpolated_values;
171145
}
172146

173147
std::vector<double> SplineInterpolation::getSplineInterpolatedDiffValues(
@@ -176,20 +150,14 @@ std::vector<double> SplineInterpolation::getSplineInterpolatedDiffValues(
176150
// throw exceptions for invalid arguments
177151
interpolation_utils::validateKeys(base_keys_, query_keys);
178152

179-
const auto & a = multi_spline_coef_.a;
180-
const auto & b = multi_spline_coef_.b;
181-
const auto & c = multi_spline_coef_.c;
182-
183-
std::vector<double> res;
184-
size_t j = 0;
185-
for (const auto & query_key : query_keys) {
186-
while (base_keys_.at(j + 1) < query_key) {
187-
++j;
188-
}
153+
std::vector<double> interpolated_diff_values;
154+
interpolated_diff_values.reserve(query_keys.size());
189155

190-
const double ds = query_key - base_keys_.at(j);
191-
res.push_back(c.at(j) + (2.0 * b.at(j) + 3.0 * a.at(j) * ds) * ds);
156+
for (const auto & key : query_keys) {
157+
const auto idx = get_index(key);
158+
const auto dx = key - base_keys_[idx];
159+
interpolated_diff_values.emplace_back(3 * a_[idx] * dx * dx + 2 * b_[idx] * dx + c_[idx]);
192160
}
193161

194-
return res;
162+
return interpolated_diff_values;
195163
}

0 commit comments

Comments
 (0)