Skip to content

Commit c152dfe

Browse files
authored
Merge branch 'master' into ar/NonLinearProgram
2 parents e668506 + 95927fa commit c152dfe

File tree

2 files changed

+51
-1
lines changed

2 files changed

+51
-1
lines changed

src/moi_wrapper.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -587,6 +587,9 @@ end
587587

588588
function empty_input_sensitivities!(model::Optimizer)
589589
empty!(model.input_cache)
590+
if model.diff !== nothing
591+
empty!(model.diff.model.input_cache)
592+
end
590593
return
591594
end
592595

test/parameters.jl

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -671,7 +671,6 @@ function test_diff_errors_POI()
671671
return
672672
end
673673

674-
675674
function test_diff_errors()
676675
model = Model(
677676
() -> DiffOpt.diff_optimizer(
@@ -719,7 +718,55 @@ function test_diff_errors()
719718
DiffOpt.ReverseConstraintFunction(),
720719
cons,
721720
)
721+
return
722+
end
723+
724+
function is_empty(cache::DiffOpt.InputCache)
725+
return isempty(cache.dx) &&
726+
isempty(cache.scalar_constraints) &&
727+
isempty(cache.vector_constraints) &&
728+
cache.objective === nothing
729+
end
730+
731+
# Credit to @klamike
732+
function test_empty_cache()
733+
m = Model(
734+
() -> DiffOpt.diff_optimizer(
735+
HiGHS.Optimizer;
736+
with_parametric_opt_interface = true,
737+
),
738+
)
739+
@variable(m, x)
740+
@variable(m, p Parameter(1.0))
741+
@variable(m, q Parameter(2.0))
742+
@constraint(m, x p)
743+
@constraint(m, x q)
744+
@objective(m, Min, x)
745+
optimize!(m)
746+
@assert is_solved_and_feasible(m)
747+
748+
function get_sensitivity(m, xᵢ, pᵢ)
749+
DiffOpt.empty_input_sensitivities!(m)
750+
@test is_empty(unsafe_backend(m).optimizer.input_cache)
751+
if !isnothing(unsafe_backend(m).optimizer.diff) &&
752+
!isnothing(unsafe_backend(m).optimizer.diff.model.input_cache)
753+
@test is_empty(unsafe_backend(m).optimizer.diff.model.input_cache)
754+
end
755+
MOI.set(
756+
m,
757+
DiffOpt.ForwardConstraintSet(),
758+
ParameterRef(pᵢ),
759+
Parameter(1.0),
760+
)
761+
DiffOpt.forward_differentiate!(m)
762+
return MOI.get(m, DiffOpt.ForwardVariablePrimal(), xᵢ)
763+
end
722764

765+
sp1 = get_sensitivity(m, x, p)
766+
sp2 = get_sensitivity(m, x, q)
767+
sp3 = get_sensitivity(m, x, p)
768+
@test sp1 sp3
769+
@test sp2 sp3
723770
return
724771
end
725772

0 commit comments

Comments
 (0)