14
14
15
15
#include " interpolation/spline_interpolation.hpp"
16
16
17
- #include < vector >
17
+ #include " interpolation/interpolation_utils.hpp "
18
18
19
- namespace
20
- {
21
- // solve Ax = d
22
- // where A is tridiagonal matrix
23
- // [b_0 c_0 ... ]
24
- // [a_0 b_1 c_1 ... O ]
25
- // A = [ ... ]
26
- // [ O ... a_N-3 b_N-2 c_N-2]
27
- // [ ... a_N-2 b_N-1]
28
- struct TDMACoef
19
+ #include < algorithm>
20
+
21
+ Eigen::VectorXd solve_tridiagonal_matrix_algorithm (
22
+ const Eigen::Ref<const Eigen::VectorXd> & a, const Eigen::Ref<const Eigen::VectorXd> & b,
23
+ const Eigen::Ref<const Eigen::VectorXd> & c, const Eigen::Ref<const Eigen::VectorXd> & d)
29
24
{
30
- explicit TDMACoef (const size_t num_row)
31
- {
32
- a.resize (num_row - 1 );
33
- b.resize (num_row);
34
- c.resize (num_row - 1 );
35
- d.resize (num_row);
25
+ auto n = d.size ();
26
+
27
+ if (n == 1 ) {
28
+ return d.array () / b.array ();
36
29
}
37
30
38
- std::vector<double > a;
39
- std::vector<double > b;
40
- std::vector<double > c;
41
- std::vector<double > d;
42
- };
31
+ Eigen::VectorXd c_prime = Eigen::VectorXd::Zero (n);
32
+ Eigen::VectorXd d_prime = Eigen::VectorXd::Zero (n);
33
+ Eigen::VectorXd x = Eigen::VectorXd::Zero (n);
43
34
44
- inline std::vector<double > solveTridiagonalMatrixAlgorithm (const TDMACoef & tdma_coef)
45
- {
46
- const auto & a = tdma_coef.a ;
47
- const auto & b = tdma_coef.b ;
48
- const auto & c = tdma_coef.c ;
49
- const auto & d = tdma_coef.d ;
50
-
51
- const size_t num_row = b.size ();
52
-
53
- std::vector<double > x (num_row);
54
- if (num_row != 1 ) {
55
- // calculate p and q
56
- std::vector<double > p;
57
- std::vector<double > q;
58
- p.push_back (-c[0 ] / b[0 ]);
59
- q.push_back (d[0 ] / b[0 ]);
60
-
61
- for (size_t i = 1 ; i < num_row; ++i) {
62
- const double den = b[i] + a[i - 1 ] * p[i - 1 ];
63
- p.push_back (-c[i - 1 ] / den);
64
- q.push_back ((d[i] - a[i - 1 ] * q[i - 1 ]) / den);
65
- }
66
-
67
- // calculate solution
68
- x[num_row - 1 ] = q[num_row - 1 ];
69
-
70
- for (size_t i = 1 ; i < num_row; ++i) {
71
- const size_t j = num_row - 1 - i;
72
- x[j] = p[j] * x[j + 1 ] + q[j];
73
- }
74
- } else {
75
- x.push_back (d[0 ] / b[0 ]);
35
+ // Forward sweep
36
+ c_prime (0 ) = c (0 ) / b (0 );
37
+ d_prime (0 ) = d (0 ) / b (0 );
38
+
39
+ for (auto i = 1 ; i < n; i++) {
40
+ double m = 1.0 / (b (i) - a (i - 1 ) * c_prime (i - 1 ));
41
+ c_prime (i) = i < n - 1 ? c (i) * m : 0 ;
42
+ d_prime (i) = (d (i) - a (i - 1 ) * d_prime (i - 1 )) * m;
43
+ }
44
+
45
+ // Back substitution
46
+ x (n - 1 ) = d_prime (n - 1 );
47
+
48
+ for (auto i = n - 2 ; i >= 0 ; i--) {
49
+ x (i) = d_prime (i) - c_prime (i) * x (i + 1 );
76
50
}
77
51
78
52
return x;
79
53
}
80
- } // namespace
81
54
82
55
namespace interpolation
83
56
{
@@ -101,73 +74,74 @@ void SplineInterpolation::calcSplineCoefficients(
101
74
// throw exceptions for invalid arguments
102
75
interpolation_utils::validateKeysAndValues (base_keys, base_values);
103
76
104
- const size_t num_base = base_keys.size (); // N+1
105
-
106
- std::vector<double > diff_keys; // N
107
- std::vector<double > diff_values; // N
108
- for (size_t i = 0 ; i < num_base - 1 ; ++i) {
109
- diff_keys.push_back (base_keys.at (i + 1 ) - base_keys.at (i));
110
- diff_values.push_back (base_values.at (i + 1 ) - base_values.at (i));
111
- }
112
-
113
- std::vector<double > v = {0.0 };
114
- if (num_base > 2 ) {
115
- // solve tridiagonal matrix algorithm
116
- TDMACoef tdma_coef (num_base - 2 ); // N-1
117
-
118
- for (size_t i = 0 ; i < num_base - 2 ; ++i) {
119
- tdma_coef.b [i] = 2 * (diff_keys[i] + diff_keys[i + 1 ]);
120
- if (i != num_base - 3 ) {
121
- tdma_coef.a [i] = diff_keys[i + 1 ];
122
- tdma_coef.c [i] = diff_keys[i + 1 ];
123
- }
124
- tdma_coef.d [i] =
125
- 6.0 * (diff_values[i + 1 ] / diff_keys[i + 1 ] - diff_values[i] / diff_keys[i]);
126
- }
127
-
128
- const std::vector<double > tdma_res = solveTridiagonalMatrixAlgorithm (tdma_coef);
129
-
130
- // calculate v
131
- v.insert (v.end (), tdma_res.begin (), tdma_res.end ());
132
- }
133
- v.push_back (0.0 );
134
-
135
- // calculate a, b, c, d of spline coefficients
136
- multi_spline_coef_ = interpolation::MultiSplineCoef{num_base - 1 }; // N
137
- for (size_t i = 0 ; i < num_base - 1 ; ++i) {
138
- multi_spline_coef_.a [i] = (v[i + 1 ] - v[i]) / 6.0 / diff_keys[i];
139
- multi_spline_coef_.b [i] = v[i] / 2.0 ;
140
- multi_spline_coef_.c [i] =
141
- diff_values[i] / diff_keys[i] - diff_keys[i] * (2 * v[i] + v[i + 1 ]) / 6.0 ;
142
- multi_spline_coef_.d [i] = base_values[i];
77
+ Eigen::VectorXd x = Eigen::Map<const Eigen::VectorXd>(
78
+ base_keys.data (), static_cast <Eigen::Index>(base_keys.size ()));
79
+ Eigen::VectorXd y = Eigen::Map<const Eigen::VectorXd>(
80
+ base_values.data (), static_cast <Eigen::Index>(base_values.size ()));
81
+
82
+ const auto n = x.size ();
83
+
84
+ if (n == 2 ) {
85
+ a_ = Eigen::VectorXd::Zero (1 );
86
+ b_ = Eigen::VectorXd::Zero (1 );
87
+ c_ = Eigen::VectorXd::Zero (1 );
88
+ d_ = Eigen::VectorXd::Zero (1 );
89
+ c_[0 ] = (y[1 ] - y[0 ]) / (x[1 ] - x[0 ]);
90
+ d_[0 ] = y[0 ];
91
+ base_keys_ = base_keys;
92
+ return ;
143
93
}
144
94
95
+ // Create Tridiagonal matrix
96
+ Eigen::VectorXd v (n);
97
+ Eigen::VectorXd h = x.segment (1 , n - 1 ) - x.segment (0 , n - 1 );
98
+ Eigen::VectorXd a = h.segment (1 , n - 3 );
99
+ Eigen::VectorXd b = 2 * (h.segment (0 , n - 2 ) + h.segment (1 , n - 2 ));
100
+ Eigen::VectorXd c = h.segment (1 , n - 3 );
101
+ Eigen::VectorXd y_diff = y.segment (1 , n - 1 ) - y.segment (0 , n - 1 );
102
+ Eigen::VectorXd d = 6 * (y_diff.segment (1 , n - 2 ).array () / h.tail (n - 2 ).array () -
103
+ y_diff.segment (0 , n - 2 ).array () / h.head (n - 2 ).array ());
104
+
105
+ // Solve tridiagonal matrix
106
+ v.segment (1 , n - 2 ) = solve_tridiagonal_matrix_algorithm (a, b, c, d);
107
+ v[0 ] = 0 ;
108
+ v[n - 1 ] = 0 ;
109
+
110
+ // Calculate spline coefficients
111
+ a_ = (v.tail (n - 1 ) - v.head (n - 1 )).array () / 6.0 / (x.tail (n - 1 ) - x.head (n - 1 )).array ();
112
+ b_ = v.segment (0 , n - 1 ) / 2.0 ;
113
+ c_ = (y.tail (n - 1 ) - y.head (n - 1 )).array () / (x.tail (n - 1 ) - x.head (n - 1 )).array () -
114
+ (x.tail (n - 1 ) - x.head (n - 1 )).array () *
115
+ (2 * v.segment (0 , n - 1 ).array () + v.segment (1 , n - 1 ).array ()) / 6.0 ;
116
+ d_ = y.head (n - 1 );
145
117
base_keys_ = base_keys;
146
118
}
147
119
120
+ Eigen::Index SplineInterpolation::get_index (double key) const
121
+ {
122
+ auto it = std::lower_bound (base_keys_.begin (), base_keys_.end (), key);
123
+ return std::clamp (
124
+ static_cast <int >(std::distance (base_keys_.begin (), it)) - 1 , 0 ,
125
+ static_cast <int >(base_keys_.size ()) - 2 );
126
+ }
127
+
148
128
std::vector<double > SplineInterpolation::getSplineInterpolatedValues (
149
129
const std::vector<double > & query_keys) const
150
130
{
151
131
// throw exceptions for invalid arguments
152
132
interpolation_utils::validateKeys (base_keys_, query_keys);
153
133
154
- const auto & a = multi_spline_coef_.a ;
155
- const auto & b = multi_spline_coef_.b ;
156
- const auto & c = multi_spline_coef_.c ;
157
- const auto & d = multi_spline_coef_.d ;
134
+ std::vector<double > interpolated_values;
135
+ interpolated_values.reserve (query_keys.size ());
158
136
159
- std::vector<double > res;
160
- size_t j = 0 ;
161
- for (const auto & query_key : query_keys) {
162
- while (base_keys_.at (j + 1 ) < query_key) {
163
- ++j;
164
- }
165
-
166
- const double ds = query_key - base_keys_.at (j);
167
- res.push_back (d.at (j) + (c.at (j) + (b.at (j) + a.at (j) * ds) * ds) * ds);
137
+ for (const auto & key : query_keys) {
138
+ const auto idx = get_index (key);
139
+ const auto dx = key - base_keys_[idx];
140
+ interpolated_values.emplace_back (
141
+ a_[idx] * dx * dx * dx + b_[idx] * dx * dx + c_[idx] * dx + d_[idx]);
168
142
}
169
143
170
- return res ;
144
+ return interpolated_values ;
171
145
}
172
146
173
147
std::vector<double > SplineInterpolation::getSplineInterpolatedDiffValues (
@@ -176,20 +150,14 @@ std::vector<double> SplineInterpolation::getSplineInterpolatedDiffValues(
176
150
// throw exceptions for invalid arguments
177
151
interpolation_utils::validateKeys (base_keys_, query_keys);
178
152
179
- const auto & a = multi_spline_coef_.a ;
180
- const auto & b = multi_spline_coef_.b ;
181
- const auto & c = multi_spline_coef_.c ;
182
-
183
- std::vector<double > res;
184
- size_t j = 0 ;
185
- for (const auto & query_key : query_keys) {
186
- while (base_keys_.at (j + 1 ) < query_key) {
187
- ++j;
188
- }
153
+ std::vector<double > interpolated_diff_values;
154
+ interpolated_diff_values.reserve (query_keys.size ());
189
155
190
- const double ds = query_key - base_keys_.at (j);
191
- res.push_back (c.at (j) + (2.0 * b.at (j) + 3.0 * a.at (j) * ds) * ds);
156
+ for (const auto & key : query_keys) {
157
+ const auto idx = get_index (key);
158
+ const auto dx = key - base_keys_[idx];
159
+ interpolated_diff_values.emplace_back (3 * a_[idx] * dx * dx + 2 * b_[idx] * dx + c_[idx]);
192
160
}
193
161
194
- return res ;
162
+ return interpolated_diff_values ;
195
163
}
0 commit comments