Skip to content

Commit 018deed

Browse files
committed
[WIP] Integrate with POI to improve UX
1 parent 821e14e commit 018deed

8 files changed

+1206
-10
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
1212
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1313
MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee"
1414
MathOptSetDistances = "3b969827-a86c-476c-9527-bb6f1a8fbad5"
15+
ParametricOptInterface = "0ce4ce61-57bf-432b-a095-efac525d185e"
1516
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1617

1718
[compat]

src/DiffOpt.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,15 @@ import LazyArrays
1212
import LinearAlgebra
1313
import MathOptInterface as MOI
1414
import MathOptSetDistances as MOSD
15+
import ParametricOptInterface as POI
1516
import SparseArrays
1617

1718
include("utils.jl")
1819
include("product_of_sets.jl")
1920
include("diff_opt.jl")
2021
include("moi_wrapper.jl")
2122
include("jump_moi_overloads.jl")
23+
include("parameters.jl")
2224

2325
include("copy_dual.jl")
2426
include("bridges.jl")
@@ -38,6 +40,16 @@ function add_all_model_constructors(model)
3840
return
3941
end
4042

43+
function add_conic_model_constructor(model)
44+
add_model_constructor(model, ConicProgram.Model)
45+
return
46+
end
47+
48+
function add_quadratic_model_constructor(model)
49+
add_model_constructor(model, QuadraticProgram.Model)
50+
return
51+
end
52+
4153
export diff_optimizer
4254

4355
end # module

src/diff_opt.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,17 @@ The output solution differentials can be queried with the attribute
5858
"""
5959
function forward_differentiate! end
6060

61+
"""
62+
empty_input_sensitivities!(model::MOI.ModelLike)
63+
64+
Empty the input sensitivities of the model.
65+
Sets to zero all the sensitivities set by the user with method such as:
66+
- `MOI.set(model, DiffOpt.ReverseVariablePrimal(), variable_index, value)`
67+
- `MOI.set(model, DiffOpt.ForwardObjectiveFunction(), expression)`
68+
- `MOI.set(model, DiffOpt.ForwardConstraintFunction(), index, expression)`
69+
"""
70+
function empty_input_sensitivities! end
71+
6172
"""
6273
ForwardObjectiveFunction <: MOI.AbstractModelAttribute
6374

src/jump_moi_overloads.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,11 @@ function forward_differentiate!(model::JuMP.Model)
307307
return forward_differentiate!(JuMP.backend(model))
308308
end
309309

310+
function empty_input_sensitivities!(model::JuMP.Model)
311+
empty_input_sensitivities!(JuMP.backend(model))
312+
return
313+
end
314+
310315
# MOI.Utilities
311316

312317
function reverse_differentiate!(model::MOI.Utilities.CachingOptimizer)
@@ -317,6 +322,11 @@ function forward_differentiate!(model::MOI.Utilities.CachingOptimizer)
317322
return forward_differentiate!(model.optimizer)
318323
end
319324

325+
function empty_input_sensitivities!(model::MOI.Utilities.CachingOptimizer)
326+
empty_input_sensitivities!(model.optimizer)
327+
return
328+
end
329+
320330
# MOIB
321331

322332
function reverse_differentiate!(model::MOI.Bridges.AbstractBridgeOptimizer)
@@ -326,3 +336,8 @@ end
326336
function forward_differentiate!(model::MOI.Bridges.AbstractBridgeOptimizer)
327337
return forward_differentiate!(model.model)
328338
end
339+
340+
function empty_input_sensitivities!(model::MOI.Bridges.AbstractBridgeOptimizer)
341+
empty_input_sensitivities!(model.model)
342+
return
343+
end

src/moi_wrapper.jl

Lines changed: 67 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,37 @@
44
# in the LICENSE.md file or at https://opensource.org/licenses/MIT.
55

66
"""
7-
diff_optimizer(optimizer_constructor)::Optimizer
7+
DiffMethod
8+
9+
An enum to define the differentiation method.
10+
11+
## Values
12+
13+
Possible values are:
14+
15+
* [`DIFF_AUTOMATIC`]: Automatic differentiation: tries all differentiation methods and chooses the first that works. This can be slower than manually choosing a method.
16+
* [`DIFF_CONIC`]: Conic optimization based differentiation: works for conic programs or programs that can be transformed into conic programs by JuMP.
17+
* [`DIFF_QUADRATIC`]: Quadratic optimization based differentiation: works for quadratic programs or programs that can be transformed into conic programs by JuMP.
18+
"""
19+
@enum(DiffMethod, DIFF_AUTOMATIC, DIFF_CONIC, DIFF_QUADRATIC)
20+
21+
@doc(
22+
"Automatic differentiation: tries all differentiation methods and chooses the first that works. This can be slower than manually choosing a method.",
23+
DIFF_AUTOMATIC
24+
)
25+
26+
@doc(
27+
"Conic optimization based differentiation: works for conic programs or programs that can be transformed into conic programs by JuMP.",
28+
DIFF_CONIC
29+
)
30+
31+
@doc(
32+
"Quadratic optimization based differentiation: works for quadratic programs or programs that can be transformed into conic programs by JuMP.",
33+
DIFF_QUADRATIC
34+
)
35+
36+
"""
37+
diff_optimizer(optimizer_constructor)
838
939
Creates a `DiffOpt.Optimizer`, which is an MOI layer with an internal optimizer
1040
and other utility methods. Results (primal, dual and slack values) are obtained
@@ -21,19 +51,33 @@ julia> x = model.add_variable(model)
2151
julia> model.add_constraint(model, ...)
2252
```
2353
"""
24-
function diff_optimizer(optimizer_constructor)::Optimizer
54+
function diff_optimizer(
55+
optimizer_constructor;
56+
method = DIFF_AUTOMATIC,
57+
with_parametric_opt_interface::Bool = false,
58+
with_bridge_type = Float64,
59+
with_cache::Bool = true,
60+
)
2561
optimizer =
26-
MOI.instantiate(optimizer_constructor; with_bridge_type = Float64)
62+
MOI.instantiate(optimizer_constructor; with_bridge_type = with_bridge_type)
2763
# When we do `MOI.copy_to(diff, optimizer)` we need to efficiently `MOI.get`
2864
# the model information from `optimizer`. However, 1) `optimizer` may not
2965
# implement some getters or it may be inefficient and 2) the getters may be
3066
# unimplemented or inefficient through some bridges.
3167
# For this reason we add a cache layer, the same cache JuMP adds.
32-
caching_opt = MOI.Utilities.CachingOptimizer(
33-
MOI.Utilities.UniversalFallback(MOI.Utilities.Model{Float64}()),
34-
optimizer,
35-
)
36-
return Optimizer(caching_opt)
68+
caching_opt = if with_cache
69+
MOI.Utilities.CachingOptimizer(
70+
MOI.Utilities.UniversalFallback(MOI.Utilities.Model{Float64}()),
71+
optimizer,
72+
)
73+
else
74+
optimizer
75+
end
76+
if with_parametric_opt_interface
77+
return POI.Optimizer(Optimizer(caching_opt, method = method))
78+
else
79+
return Optimizer(caching_opt, method = method)
80+
end
3781
end
3882

3983
mutable struct Optimizer{OT<:MOI.ModelLike} <: MOI.AbstractOptimizer
@@ -49,10 +93,18 @@ mutable struct Optimizer{OT<:MOI.ModelLike} <: MOI.AbstractOptimizer
4993
# sensitivity input cache using MOI like sparse format
5094
input_cache::InputCache
5195

52-
function Optimizer(optimizer::OT) where {OT<:MOI.ModelLike}
96+
function Optimizer(optimizer::OT; method = DIFF_AUTOMATIC) where {OT<:MOI.ModelLike}
5397
output =
5498
new{OT}(optimizer, Any[], nothing, nothing, nothing, InputCache())
55-
add_all_model_constructors(output)
99+
if method == DIFF_CONIC
100+
add_conic_model_constructor(output)
101+
elseif method == DIFF_QUADRATIC
102+
add_quadratic_model_constructor(output)
103+
elseif method == DIFF_AUTOMATIC
104+
add_all_model_constructors(output)
105+
else
106+
add_model_constructor(output, method)
107+
end
56108
return output
57109
end
58110
end
@@ -552,6 +604,11 @@ function forward_differentiate!(model::Optimizer)
552604
return forward_differentiate!(diff)
553605
end
554606

607+
function empty_input_sensitivities!(model::Optimizer)
608+
empty!(model.input_cache)
609+
return
610+
end
611+
555612
function _instantiate_with_bridges(model_constructor)
556613
model = MOI.Bridges.LazyBridgeOptimizer(MOI.instantiate(model_constructor))
557614
# We don't add any variable bridge here because:

0 commit comments

Comments
 (0)