Skip to content

Commit c2ce2b1

Browse files
tests ok
1 parent b1b9827 commit c2ce2b1

8 files changed

+269
-20
lines changed

Project.toml

+4-2
Original file line numberDiff line numberDiff line change
@@ -22,20 +22,22 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2222
Unicode = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
2323

2424
[weakdeps]
25-
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
2625
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
2726
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
27+
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
2828

2929
[extensions]
30-
CTBasePlots = "Plots"
3130
CTBaseLoadSave = ["JLD2", "JSON3"]
31+
CTBasePlots = "Plots"
3232

3333
[compat]
3434
DataStructures = "0.18"
3535
DifferentiationInterface = "0.5"
3636
DocStringExtensions = "0.9"
3737
ForwardDiff = "0.10"
3838
Interpolations = "0.15"
39+
JLD2 = "0.5"
40+
JSON3 = "1"
3941
MLStyle = "0.4"
4042
MacroTools = "0.5"
4143
Parameters = "0.12"

src/optimal_control_solution-setters.jl

+136
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,142 @@ function OptimalControlSolution(
224224
)
225225
end
226226

227+
228+
"""
229+
$(TYPEDSIGNATURES)
230+
231+
Build OCP functional solution from discrete solution (given as raw variables and multipliers plus some optional infos)
232+
"""
233+
function OptimalControlSolution(
234+
ocp::OptimalControlModel,
235+
T,
236+
X,
237+
U,
238+
v,
239+
P;
240+
objective = 0,
241+
iterations = 0,
242+
constraints_violation = 0,
243+
message = "No msg",
244+
stopping = nothing,
245+
success = nothing,
246+
constraints_types = (nothing, nothing, nothing, nothing, nothing),
247+
constraints_mult = (nothing, nothing, nothing, nothing, nothing),
248+
box_multipliers = (nothing, nothing, nothing, nothing, nothing, nothing),
249+
)
250+
dim_x = state_dimension(ocp)
251+
dim_u = control_dimension(ocp)
252+
dim_v = variable_dimension(ocp)
253+
254+
# check that time grid is strictly increasing
255+
# if not proceed with list of indexes as time grid
256+
if !issorted(T, lt = <=)
257+
println(
258+
"WARNING: time grid at solution is not strictly increasing, replacing with list of indices...",
259+
)
260+
println(T)
261+
dim_NLP_steps = length(T) - 1
262+
T = LinRange(0, dim_NLP_steps, dim_NLP_steps + 1)
263+
end
264+
265+
# variables: remove additional state for lagrange cost
266+
x = ctinterpolate(T, matrix2vec(X[:, 1:dim_x], 1))
267+
p = ctinterpolate(T[1:(end - 1)], matrix2vec(P[:, 1:dim_x], 1))
268+
u = ctinterpolate(T, matrix2vec(U[:, 1:dim_u], 1))
269+
270+
# force scalar output when dimension is 1
271+
fx = (dim_x == 1) ? deepcopy(t -> x(t)[1]) : deepcopy(t -> x(t))
272+
fu = (dim_u == 1) ? deepcopy(t -> u(t)[1]) : deepcopy(t -> u(t))
273+
fp = (dim_x == 1) ? deepcopy(t -> p(t)[1]) : deepcopy(t -> p(t))
274+
var = (dim_v == 1) ? v[1] : v
275+
276+
# misc infos
277+
infos = Dict{Symbol, Any}()
278+
infos[:constraints_violation] = constraints_violation
279+
280+
# nonlinear constraints and multipliers
281+
control_constraints = t -> ctinterpolate(T, matrix2vec(constraints_types[1], 1))(t)
282+
mult_control_constraints = t -> ctinterpolate(T, matrix2vec(constraints_mult[1], 1))(t)
283+
state_constraints = t -> ctinterpolate(T, matrix2vec(constraints_types[2], 1))(t)
284+
mult_state_constraints = t -> ctinterpolate(T, matrix2vec(constraints_mult[2], 1))(t)
285+
mixed_constraints = t -> ctinterpolate(T, matrix2vec(constraints_types[3], 1))(t)
286+
mult_mixed_constraints = t -> ctinterpolate(T, matrix2vec(constraints_mult[3], 1))(t)
287+
288+
# boundary and variable constraints
289+
boundary_constraints = constraints_types[4]
290+
mult_boundary_constraints = constraints_mult[4]
291+
variable_constraints = constraints_types[5]
292+
mult_variable_constraints = constraints_mult[5]
293+
294+
# box constraints multipliers
295+
mult_state_box_lower = t -> ctinterpolate(T, matrix2vec(box_multipliers[1][:, 1:dim_x], 1))(t)
296+
mult_state_box_upper = t -> ctinterpolate(T, matrix2vec(box_multipliers[2][:, 1:dim_x], 1))
297+
mult_control_box_lower = t -> ctinterpolate(T, matrix2vec(box_multipliers[3][:, 1:dim_u], 1))(t)
298+
mult_control_box_upper = t -> ctinterpolate(T, matrix2vec(box_multipliers[4][:, 1:dim_u], 1))
299+
mult_variable_box_lower, mult_variable_box_upper = box_multipliers[5], box_multipliers[6]
300+
301+
# build and return solution
302+
if is_variable_dependent(ocp)
303+
return OptimalControlSolution(
304+
ocp;
305+
state = fx,
306+
control = fu,
307+
objective = objective,
308+
costate = fp,
309+
time_grid = T,
310+
variable = var,
311+
iterations = iterations,
312+
stopping = stopping,
313+
message = message,
314+
success = success,
315+
infos = infos,
316+
control_constraints = control_constraints,
317+
state_constraints = state_constraints,
318+
mixed_constraints = mixed_constraints,
319+
boundary_constraints = boundary_constraints,
320+
variable_constraints = variable_constraints,
321+
mult_control_constraints = mult_control_constraints,
322+
mult_state_constraints = mult_state_constraints,
323+
mult_mixed_constraints = mult_mixed_constraints,
324+
mult_boundary_constraints = mult_boundary_constraints,
325+
mult_variable_constraints = mult_variable_constraints,
326+
mult_state_box_lower = mult_state_box_lower,
327+
mult_state_box_upper = mult_state_box_upper,
328+
mult_control_box_lower = mult_control_box_lower,
329+
mult_control_box_upper = mult_control_box_upper,
330+
mult_variable_box_lower = mult_variable_box_lower,
331+
mult_variable_box_upper = mult_variable_box_upper,
332+
)
333+
else
334+
return OptimalControlSolution(
335+
ocp;
336+
state = fx,
337+
control = fu,
338+
objective = objective,
339+
costate = fp,
340+
time_grid = T,
341+
iterations = iterations,
342+
stopping = stopping,
343+
message = message,
344+
success = success,
345+
infos = infos,
346+
control_constraints = control_constraints,
347+
state_constraints = state_constraints,
348+
mixed_constraints = mixed_constraints,
349+
boundary_constraints = boundary_constraints,
350+
mult_control_constraints = mult_control_constraints,
351+
mult_state_constraints = mult_state_constraints,
352+
mult_mixed_constraints = mult_mixed_constraints,
353+
mult_boundary_constraints = mult_boundary_constraints,
354+
mult_state_box_lower = mult_state_box_lower,
355+
mult_state_box_upper = mult_state_box_upper,
356+
mult_control_box_lower = mult_control_box_lower,
357+
mult_control_box_upper = mult_control_box_upper,
358+
)
359+
end
360+
end
361+
362+
227363
# setters
228364
#state!(sol::OptimalControlSolution, state::Function) = (sol.state = state; nothing)
229365
#control!(sol::OptimalControlSolution, control::Function) = (sol.control = control; nothing)

src/optimal_control_solution-type.jl

+6-2
Original file line numberDiff line numberDiff line change
@@ -84,5 +84,9 @@ export export_ocp_solution
8484
export import_ocp_solution
8585

8686
# placeholders (see extension CTBaseLoadSave)
87-
function export_ocp_solution end
88-
function import_ocp_solution end
87+
function export_ocp_solution(args...; kwargs...)
88+
error("Requires JLD2 and JSON3 packages")
89+
end
90+
function import_ocp_solution(args...; kwargs...)
91+
error("Requires JLD2 and JSON3 packages")
92+
end

test/Project.toml

+2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
[deps]
22
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
33
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
4+
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
5+
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
46
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
57
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
68

test/runtests.jl

+1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ using Aqua
55
using CTBase
66
using DifferentiationInterface: AutoForwardDiff
77
using Plots
8+
using JLD2, JSON3
89
using Test
910

1011
# functions and types that are not exported

test/solution_test.jld2

10.4 KB
Binary file not shown.

test/solution_test.json

+108
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
{
2+
"time_grid": [
3+
0,
4+
0.1111111111111111,
5+
0.2222222222222222,
6+
0.3333333333333333,
7+
0.4444444444444444,
8+
0.5555555555555556,
9+
0.6666666666666666,
10+
0.7777777777777778,
11+
0.8888888888888888,
12+
1
13+
],
14+
"objective": 1,
15+
"control": [
16+
0,
17+
0.2222222222222222,
18+
0.4444444444444444,
19+
0.6666666666666666,
20+
0.8888888888888888,
21+
1.1111111111111112,
22+
1.3333333333333333,
23+
1.5555555555555556,
24+
1.7777777777777777,
25+
2
26+
],
27+
"costate": [
28+
[
29+
0,
30+
-1
31+
],
32+
[
33+
0.1111111111111111,
34+
-0.8888888888888888
35+
],
36+
[
37+
0.2222222222222222,
38+
-0.7777777777777778
39+
],
40+
[
41+
0.3333333333333333,
42+
-0.6666666666666667
43+
],
44+
[
45+
0.4444444444444444,
46+
-0.5555555555555556
47+
],
48+
[
49+
0.5555555555555556,
50+
-0.4444444444444444
51+
],
52+
[
53+
0.6666666666666666,
54+
-0.33333333333333337
55+
],
56+
[
57+
0.7777777777777778,
58+
-0.2222222222222222
59+
],
60+
[
61+
0.8888888888888888,
62+
-0.11111111111111116
63+
]
64+
],
65+
"variable": null,
66+
"state": [
67+
[
68+
0,
69+
1
70+
],
71+
[
72+
0.1111111111111111,
73+
1.1111111111111112
74+
],
75+
[
76+
0.2222222222222222,
77+
1.2222222222222223
78+
],
79+
[
80+
0.3333333333333333,
81+
1.3333333333333333
82+
],
83+
[
84+
0.4444444444444444,
85+
1.4444444444444444
86+
],
87+
[
88+
0.5555555555555556,
89+
1.5555555555555556
90+
],
91+
[
92+
0.6666666666666666,
93+
1.6666666666666665
94+
],
95+
[
96+
0.7777777777777778,
97+
1.7777777777777777
98+
],
99+
[
100+
0.8888888888888888,
101+
1.8888888888888888
102+
],
103+
[
104+
1,
105+
2
106+
]
107+
]
108+
}

test/test_solution.jl

+12-16
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@ function test_solution()
1212
end
1313

1414
times = range(0, 1, 10)
15-
x = t -> t
15+
x = t -> [t, t+1]
1616
u = t -> 2t
17-
p = t -> t
17+
p = t -> [t, t-1]
1818
obj = 1
1919
sol = OptimalControlSolution(
2020
ocp;
@@ -33,6 +33,12 @@ function test_solution()
3333
@test all(control_discretized(sol) .== u.(times))
3434
@test all(costate_discretized(sol) .== p.(times))
3535

36+
# test export / read solution in JSON format (NB. requires time grid in solution !)
37+
println(sol.time_grid)
38+
export_ocp_solution(sol; filename_prefix = "solution_test", format = :JSON)
39+
sol_reloaded = import_ocp_solution(ocp; filename_prefix = "solution_test", format = :JSON)
40+
@test sol.objective == sol_reloaded.objective
41+
3642
# NonFixed ocp
3743
@def ocp begin
3844
v R, variable
@@ -45,7 +51,7 @@ function test_solution()
4551
(0.5u(t)^2) min
4652
end
4753

48-
x = t -> t
54+
x = t -> [t, t+1]
4955
u = t -> 2t
5056
obj = 1
5157
v = 1
@@ -55,19 +61,9 @@ function test_solution()
5561
@test typeof(sol) == OptimalControlSolution
5662
@test_throws UndefKeywordError OptimalControlSolution(ocp; x, u, obj)
5763

58-
5964
# test save / load solution in JLD2 format
60-
@testset verbose = true showtiming = true ":save_load :JLD2" begin
61-
export_ocp_solution(sol; filename_prefix = "solution_test")
62-
sol_reloaded = import_ocp_solution(ocp; filename_prefix = "solution_test")
63-
@test sol.objective == sol_reloaded.objective
64-
end
65-
66-
# test export / read solution in JSON format
67-
@testset verbose = true showtiming = true ":export_read :JSON" begin
68-
export_ocp_solution(sol; filename_prefix = "solution_test", format = :JSON)
69-
sol_reloaded = import_ocp_solution(ocp; filename_prefix = "solution_test", format = :JSON)
70-
@test sol.objective == sol_reloaded.objective
71-
end
65+
export_ocp_solution(sol; filename_prefix = "solution_test")
66+
sol_reloaded = import_ocp_solution(ocp; filename_prefix = "solution_test")
67+
@test sol.objective == sol_reloaded.objective
7268

7369
end

0 commit comments

Comments
 (0)