Skip to content

Commit 95927fa

Browse files
Fix empty_input_sensitivities! (#273)
* Fix empty_input_sensitivities! * fix empty and create isempty * import Base.isempty * fix access * format * Update src/DiffOpt.jl Co-authored-by: Joaquim <joaquimdgarcia@gmail.com> * remove isempty cache * update format --------- Co-authored-by: Joaquim <joaquimdgarcia@gmail.com>
1 parent e4740c5 commit 95927fa

File tree

2 files changed

+52
-0
lines changed

2 files changed

+52
-0
lines changed

src/moi_wrapper.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -577,6 +577,9 @@ end
577577

578578
function empty_input_sensitivities!(model::Optimizer)
579579
empty!(model.input_cache)
580+
if model.diff !== nothing
581+
empty!(model.diff.model.input_cache)
582+
end
580583
return
581584
end
582585

test/parameters.jl

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -671,6 +671,55 @@ function test_diff_errors()
671671
return
672672
end
673673

674+
function is_empty(cache::DiffOpt.InputCache)
675+
return isempty(cache.dx) &&
676+
isempty(cache.scalar_constraints) &&
677+
isempty(cache.vector_constraints) &&
678+
cache.objective === nothing
679+
end
680+
681+
# Credit to @klamike
682+
function test_empty_cache()
683+
m = Model(
684+
() -> DiffOpt.diff_optimizer(
685+
HiGHS.Optimizer;
686+
with_parametric_opt_interface = true,
687+
),
688+
)
689+
@variable(m, x)
690+
@variable(m, p Parameter(1.0))
691+
@variable(m, q Parameter(2.0))
692+
@constraint(m, x p)
693+
@constraint(m, x q)
694+
@objective(m, Min, x)
695+
optimize!(m)
696+
@assert is_solved_and_feasible(m)
697+
698+
function get_sensitivity(m, xᵢ, pᵢ)
699+
DiffOpt.empty_input_sensitivities!(m)
700+
@test is_empty(unsafe_backend(m).optimizer.input_cache)
701+
if !isnothing(unsafe_backend(m).optimizer.diff) &&
702+
!isnothing(unsafe_backend(m).optimizer.diff.model.input_cache)
703+
@test is_empty(unsafe_backend(m).optimizer.diff.model.input_cache)
704+
end
705+
MOI.set(
706+
m,
707+
DiffOpt.ForwardConstraintSet(),
708+
ParameterRef(pᵢ),
709+
Parameter(1.0),
710+
)
711+
DiffOpt.forward_differentiate!(m)
712+
return MOI.get(m, DiffOpt.ForwardVariablePrimal(), xᵢ)
713+
end
714+
715+
sp1 = get_sensitivity(m, x, p)
716+
sp2 = get_sensitivity(m, x, q)
717+
sp3 = get_sensitivity(m, x, p)
718+
@test sp1 sp3
719+
@test sp2 sp3
720+
return
721+
end
722+
674723
end # module
675724

676725
TestParameters.runtests()

0 commit comments

Comments
 (0)