Skip to content

Commit 90b5b4e

Browse files
fix
1 parent ee6ac13 commit 90b5b4e

File tree

5 files changed

+96
-43
lines changed

5 files changed

+96
-43
lines changed

common/autoware_osqp_interface/design/osqp_interface-design.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ The interface can be used in several ways:
5454

5555
```cpp
5656
std::tuple<std::vector<double>, std::vector<double>> result = osqp_interface.optimize();
57-
std::vector<double> param = std::get<0>(result);
57+
std::vector<double> param = result.primal_solution;
5858
double x_0 = param[0];
5959
double x_1 = param[1];
6060
```

common/autoware_osqp_interface/include/autoware/osqp_interface/csc_matrix_conv.hpp

+38
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,44 @@ struct OSQP_INTERFACE_PUBLIC CSC_Matrix
3333
std::vector<c_int> m_row_idxs;
3434
/// Vector of 'val' indices where each column starts. Ex: [0, 2, 4] (Eigen: 'outer')
3535
std::vector<c_int> m_col_idxs;
36+
37+
friend std::ostream & operator<<(std::ostream & os, const CSC_Matrix & matrix)
38+
{
39+
os << "CSC_Matrix: {\n";
40+
os << "\tm_vals: [";
41+
42+
// Iterator-based loop for m_vals
43+
for (auto it = std::begin(matrix.m_vals); it != std::end(matrix.m_vals); ++it) {
44+
os << *it; // Print the current element (dereference iterator)
45+
if (std::next(it) != std::end(matrix.m_vals)) { // Check if not the last element
46+
os << ", ";
47+
}
48+
}
49+
os << "],\n";
50+
51+
os << "\tm_row_idxs: [";
52+
// Iterator-based loop for m_row_idxs
53+
for (auto it = std::begin(matrix.m_row_idxs); it != std::end(matrix.m_row_idxs); ++it) {
54+
os << *it;
55+
if (std::next(it) != std::end(matrix.m_row_idxs)) {
56+
os << ", ";
57+
}
58+
}
59+
os << "],\n";
60+
61+
os << "\tm_col_idxs: [";
62+
// Iterator-based loop for m_col_idxs
63+
for (auto it = std::begin(matrix.m_col_idxs); it != std::end(matrix.m_col_idxs); ++it) {
64+
os << *it;
65+
if (std::next(it) != std::end(matrix.m_col_idxs)) {
66+
os << ", ";
67+
}
68+
}
69+
os << "]\n";
70+
71+
os << "}\n";
72+
return os;
73+
}
3674
};
3775

3876
/// \brief Calculate CSC matrix from Eigen matrix

common/autoware_osqp_interface/include/autoware/osqp_interface/osqp_interface.hpp

+16-5
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,17 @@ namespace autoware::osqp_interface
3232
{
3333
constexpr c_float INF = 1e30;
3434

35+
36+
struct OSQPResult
37+
{
38+
std::vector<double> primal_solution;
39+
std::vector<double> lagrange_multipliers;
40+
int polish_status;
41+
int solution_status;
42+
int iteration_status;
43+
int exit_flag;
44+
};
45+
3546
/**
3647
* Implementation of a native C++ interface for the OSQP solver.
3748
*
@@ -52,7 +63,7 @@ class OSQP_INTERFACE_PUBLIC OSQPInterface
5263
int64_t m_exitflag;
5364

5465
// Runs the solver on the stored problem.
55-
std::tuple<std::vector<double>, std::vector<double>, int64_t, int64_t, int64_t> solve();
66+
OSQPResult solve();
5667

5768
static void OSQPWorkspaceDeleter(OSQPWorkspace * ptr) noexcept;
5869

@@ -93,10 +104,10 @@ class OSQP_INTERFACE_PUBLIC OSQPInterface
93104
/// \details std::tuple<std::vector<double>, std::vector<double>> result;
94105
/// \details result = osqp_interface.optimize();
95106
/// \details 4. Access the optimized parameters.
96-
/// \details std::vector<float> param = std::get<0>(result);
107+
/// \details std::vector<float> param = result.primal_solution;
97108
/// \details double x_0 = param[0];
98109
/// \details double x_1 = param[1];
99-
std::tuple<std::vector<double>, std::vector<double>, int64_t, int64_t, int64_t> optimize();
110+
OSQPResult optimize();
100111

101112
/// \brief Solves convex quadratic programs (QPs) using the OSQP solver.
102113
/// \return The function returns a tuple containing the solution as two float vectors.
@@ -111,10 +122,10 @@ class OSQP_INTERFACE_PUBLIC OSQPInterface
111122
/// \details std::tuple<std::vector<double>, std::vector<double>> result;
112123
/// \details result = osqp_interface.optimize(P, A, q, l, u, 1e-6);
113124
/// \details 4. Access the optimized parameters.
114-
/// \details std::vector<float> param = std::get<0>(result);
125+
/// \details std::vector<float> param = result.primal_solution;
115126
/// \details double x_0 = param[0];
116127
/// \details double x_1 = param[1];
117-
std::tuple<std::vector<double>, std::vector<double>, int64_t, int64_t, int64_t> optimize(
128+
OSQPResult optimize(
118129
const Eigen::MatrixXd & P, const Eigen::MatrixXd & A, const std::vector<double> & q,
119130
const std::vector<double> & l, const std::vector<double> & u);
120131

common/autoware_osqp_interface/src/osqp_interface.cpp

+15-10
Original file line numberDiff line numberDiff line change
@@ -359,12 +359,12 @@ int64_t OSQPInterface::initializeProblem(
359359
return m_exitflag;
360360
}
361361

362-
std::tuple<std::vector<double>, std::vector<double>, int64_t, int64_t, int64_t>
362+
OSQPResult
363363
OSQPInterface::solve()
364364
{
365365
// Solve Problem
366-
osqp_solve(m_work.get());
367-
366+
int32_t exit_flag = osqp_solve(m_work.get());
367+
368368
/********************
369369
* EXTRACT SOLUTION
370370
********************/
@@ -378,24 +378,29 @@ OSQPInterface::solve()
378378
int64_t status_iteration = m_work->info->iter;
379379

380380
// Result tuple
381-
std::tuple<std::vector<double>, std::vector<double>, int64_t, int64_t, int64_t> result =
382-
std::make_tuple(
383-
sol_primal, sol_lagrange_multiplier, status_polish, status_solution, status_iteration);
381+
OSQPResult result;
382+
383+
result.primal_solution = sol_primal;
384+
result.lagrange_multipliers = sol_lagrange_multiplier;
385+
result.polish_status = status_polish;
386+
result.solution_status = status_solution;
387+
result.iteration_status = status_iteration;
388+
result.exit_flag = exit_flag;
384389

385390
m_latest_work_info = *(m_work->info);
386391

387392
return result;
388393
}
389394

390-
std::tuple<std::vector<double>, std::vector<double>, int64_t, int64_t, int64_t>
395+
OSQPResult
391396
OSQPInterface::optimize()
392397
{
393398
// Run the solver on the stored problem representation.
394-
std::tuple<std::vector<double>, std::vector<double>, int64_t, int64_t, int64_t> result = solve();
399+
OSQPResult result = solve();
395400
return result;
396401
}
397402

398-
std::tuple<std::vector<double>, std::vector<double>, int64_t, int64_t, int64_t>
403+
OSQPResult
399404
OSQPInterface::optimize(
400405
const Eigen::MatrixXd & P, const Eigen::MatrixXd & A, const std::vector<double> & q,
401406
const std::vector<double> & l, const std::vector<double> & u)
@@ -404,7 +409,7 @@ OSQPInterface::optimize(
404409
initializeProblem(P, A, q, l, u);
405410

406411
// Run the solver on the stored problem representation.
407-
std::tuple<std::vector<double>, std::vector<double>, int64_t, int64_t, int64_t> result = solve();
412+
OSQPResult result = solve();
408413

409414
m_work.reset();
410415
m_work_initialized = false;

common/autoware_osqp_interface/test/test_osqp_interface.cpp

+26-27
Original file line numberDiff line numberDiff line change
@@ -44,25 +44,24 @@ TEST(TestOsqpInterface, BasicQp)
4444
using autoware::osqp_interface::calCSCMatrixTrapezoidal;
4545
using autoware::osqp_interface::CSC_Matrix;
4646

47-
auto check_result =
48-
[](const std::tuple<std::vector<double>, std::vector<double>, int, int, int> & result) {
49-
EXPECT_EQ(std::get<2>(result), 1); // polish succeeded
50-
EXPECT_EQ(std::get<3>(result), 1); // solution succeeded
51-
52-
static const auto ep = 1.0e-8;
53-
54-
const auto prime_val = std::get<0>(result);
55-
ASSERT_EQ(prime_val.size(), size_t(2));
56-
EXPECT_NEAR(prime_val[0], 0.3, ep);
57-
EXPECT_NEAR(prime_val[1], 0.7, ep);
58-
59-
const auto dual_val = std::get<1>(result);
60-
ASSERT_EQ(dual_val.size(), size_t(4));
61-
EXPECT_NEAR(dual_val[0], -2.9, ep);
62-
EXPECT_NEAR(dual_val[1], 0.0, ep);
63-
EXPECT_NEAR(dual_val[2], 0.2, ep);
64-
EXPECT_NEAR(dual_val[3], 0.0, ep);
65-
};
47+
auto check_result = [](const autoware::osqp_interface::OSQPResult & result) {
48+
EXPECT_EQ(result.polish_status, 1); // polish succeeded
49+
EXPECT_EQ(result.solution_status, 1); // solution succeeded
50+
51+
static const auto ep = 1.0e-8;
52+
53+
const auto prime_val = result.primal_solution;
54+
ASSERT_EQ(prime_val.size(), size_t(2));
55+
EXPECT_NEAR(prime_val[0], 0.3, ep);
56+
EXPECT_NEAR(prime_val[1], 0.7, ep);
57+
58+
const auto dual_val = result.lagrange_multipliers;
59+
ASSERT_EQ(dual_val.size(), size_t(4));
60+
EXPECT_NEAR(dual_val[0], -2.9, ep);
61+
EXPECT_NEAR(dual_val[1], 0.0, ep);
62+
EXPECT_NEAR(dual_val[2], 0.2, ep);
63+
EXPECT_NEAR(dual_val[3], 0.0, ep);
64+
};
6665

6766
const Eigen::MatrixXd P = (Eigen::MatrixXd(2, 2) << 4, 1, 1, 2).finished();
6867
const Eigen::MatrixXd A = (Eigen::MatrixXd(4, 2) << 1, 1, 1, 0, 0, 1, 0, 1).finished();
@@ -73,20 +72,20 @@ TEST(TestOsqpInterface, BasicQp)
7372
{
7473
// Define problem during optimization
7574
autoware::osqp_interface::OSQPInterface osqp;
76-
std::tuple<std::vector<double>, std::vector<double>, int, int, int> result =
75+
autoware::osqp_interface::OSQPResult result =
7776
osqp.optimize(P, A, q, l, u);
7877
check_result(result);
7978
}
8079

8180
{
8281
// Define problem during initialization
8382
autoware::osqp_interface::OSQPInterface osqp(P, A, q, l, u, 1e-6);
84-
std::tuple<std::vector<double>, std::vector<double>, int, int, int> result = osqp.optimize();
83+
autoware::osqp_interface::OSQPResult result = osqp.optimize();
8584
check_result(result);
8685
}
8786

8887
{
89-
std::tuple<std::vector<double>, std::vector<double>, int, int, int> result;
88+
autoware::osqp_interface::OSQPResult result;
9089
// Dummy initial problem
9190
Eigen::MatrixXd P_ini = Eigen::MatrixXd::Zero(2, 2);
9291
Eigen::MatrixXd A_ini = Eigen::MatrixXd::Zero(4, 2);
@@ -107,12 +106,12 @@ TEST(TestOsqpInterface, BasicQp)
107106
CSC_Matrix P_csc = calCSCMatrixTrapezoidal(P);
108107
CSC_Matrix A_csc = calCSCMatrix(A);
109108
autoware::osqp_interface::OSQPInterface osqp(P_csc, A_csc, q, l, u, 1e-6);
110-
std::tuple<std::vector<double>, std::vector<double>, int, int, int> result = osqp.optimize();
109+
autoware::osqp_interface::OSQPResult result = osqp.optimize();
111110
check_result(result);
112111
}
113112

114113
{
115-
std::tuple<std::vector<double>, std::vector<double>, int, int, int> result;
114+
autoware::osqp_interface::OSQPResult result;
116115
// Dummy initial problem with csc matrix
117116
CSC_Matrix P_ini_csc = calCSCMatrixTrapezoidal(Eigen::MatrixXd::Zero(2, 2));
118117
CSC_Matrix A_ini_csc = calCSCMatrix(Eigen::MatrixXd::Zero(4, 2));
@@ -132,7 +131,7 @@ TEST(TestOsqpInterface, BasicQp)
132131

133132
// add warm startup
134133
{
135-
std::tuple<std::vector<double>, std::vector<double>, int, int, int> result;
134+
autoware::osqp_interface::OSQPResult result;
136135
// Dummy initial problem with csc matrix
137136
CSC_Matrix P_ini_csc = calCSCMatrixTrapezoidal(Eigen::MatrixXd::Zero(2, 2));
138137
CSC_Matrix A_ini_csc = calCSCMatrix(Eigen::MatrixXd::Zero(4, 2));
@@ -150,8 +149,8 @@ TEST(TestOsqpInterface, BasicQp)
150149
check_result(result);
151150

152151
osqp.updateCheckTermination(1);
153-
const auto primal_val = std::get<0>(result);
154-
const auto dual_val = std::get<1>(result);
152+
const auto primal_val = result.primal_solution;
153+
const auto dual_val = result.lagrange_multipliers;
155154
for (size_t i = 0; i < primal_val.size(); ++i) {
156155
std::cerr << primal_val.at(i) << std::endl;
157156
}

0 commit comments

Comments
 (0)