Skip to content

Commit e668506

Browse files
expose factorization through as MOI.AbstractModelAttribute
1 parent 06f9110 commit e668506

File tree

5 files changed

+100
-72
lines changed

5 files changed

+100
-72
lines changed

src/NonLinearProgram/NonLinearProgram.jl

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -372,7 +372,7 @@ get_num_params(model::Model) = get_num_params(model.model)
372372

373373
function _cache_evaluator!(model::Model)
374374
form = model.model
375-
# Retrieve and sort primal variables by index
375+
# Retrieve and sort primal variables by NLP index
376376
params = sort(all_params(form); by = x -> x.value)
377377
primal_vars = sort(all_primal_vars(form); by = x -> x.value)
378378
num_primal = length(primal_vars)
@@ -389,7 +389,7 @@ function _cache_evaluator!(model::Model)
389389
num_low = length(has_low)
390390
num_up = length(has_up)
391391

392-
# Create unified dual mapping
392+
# Create unified dual mapping from constraint index to NLP index
393393
dual_mapping = Vector{Int}(undef, form.num_constraints)
394394
for (ci, cni) in form.constraints_2_nlp_index
395395
dual_mapping[ci.value] = cni.value
@@ -437,9 +437,6 @@ end
437437
function DiffOpt.forward_differentiate!(
438438
model::Model;
439439
tol = 1e-6,
440-
st = 1e-6,
441-
max_corrections = 50,
442-
allow_inertia_correction = true,
443440
)
444441
model.diff_time = @elapsed begin
445442
cache = _cache_evaluator!(model)
@@ -448,7 +445,7 @@ function DiffOpt.forward_differentiate!(
448445
Δp = zeros(length(cache.params))
449446
for (i, var_idx) in enumerate(cache.params)
450447
ky = form.var2ci[var_idx]
451-
if haskey(model.input_cache.dp, ky)
448+
if haskey(model.input_cache.dp, ky) # only for set sensitivities
452449
Δp[i] = model.input_cache.dp[ky]
453450
end
454451
end
@@ -457,9 +454,6 @@ function DiffOpt.forward_differentiate!(
457454
Δs = compute_sensitivity(
458455
model;
459456
tol = tol,
460-
st = st,
461-
max_corrections = max_corrections,
462-
allow_inertia_correction = allow_inertia_correction,
463457
)
464458

465459
# Extract primal and dual sensitivities
@@ -477,9 +471,6 @@ end
477471
function DiffOpt.reverse_differentiate!(
478472
model::Model;
479473
tol = 1e-6,
480-
st = 1e-6,
481-
max_corrections = 50,
482-
allow_inertia_correction = true,
483474
)
484475
model.diff_time = @elapsed begin
485476
cache = _cache_evaluator!(model)
@@ -489,9 +480,6 @@ function DiffOpt.reverse_differentiate!(
489480
Δs = compute_sensitivity(
490481
model;
491482
tol = tol,
492-
st = st,
493-
max_corrections = max_corrections,
494-
allow_inertia_correction = allow_inertia_correction,
495483
)
496484
num_primal = length(cache.primal_vars)
497485
# Fetch primal sensitivities

src/NonLinearProgram/nlp_utilities.jl

Lines changed: 3 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -388,44 +388,6 @@ function build_sensitivity_matrices(
388388
return M, N
389389
end
390390

391-
"""
392-
inertia_corrector_factorization(M::SparseMatrixCSC, num_w, num_cons; st=1e-6, max_corrections=50)
393-
394-
Inertia correction for the factorization of the KKT matrix. Sparse version.
395-
"""
396-
function inertia_corrector_factorization(
397-
M::SparseMatrixCSC,
398-
num_w,
399-
num_cons;
400-
st = 1e-6,
401-
max_corrections = 50,
402-
allow_inertia_correction = true,
403-
)
404-
# Factorization
405-
K = lu(M; check = false)
406-
# Inertia correction
407-
status = K.status
408-
num_c = 0
409-
diag_mat = ones(size(M, 1))
410-
diag_mat[num_w+1:num_w+num_cons] .= -1
411-
diag_mat = SparseArrays.spdiagm(diag_mat)
412-
if status == 1
413-
@assert allow_inertia_correction "Inertia correction needed but not allowed"
414-
@info "Inertia correction needed"
415-
end
416-
while status == 1 && num_c < max_corrections
417-
M = M + st * diag_mat
418-
K = lu(M; check = false)
419-
status = K.status
420-
num_c += 1
421-
end
422-
if status != 0
423-
@warn "Inertia correction failed"
424-
return nothing
425-
end
426-
return K
427-
end
428-
429391
"""
430392
compute_derivatives_no_relax(model::Model, cons::Vector{MOI.Nonlinear.ConstraintIndex},
431393
_X::AbstractVector, _V_L::AbstractVector, _X_L::AbstractVector, _V_U::AbstractVector, _X_U::AbstractVector, leq_locations::Vector{Z}, geq_locations::Vector{Z}, ineq_locations::Vector{Z},
@@ -447,9 +409,6 @@ function compute_derivatives_no_relax(
447409
ineq_locations::Vector{Z},
448410
has_up::Vector{Z},
449411
has_low::Vector{Z};
450-
st = 1e-6,
451-
max_corrections = 50,
452-
allow_inertia_correction = true,
453412
) where {Z<:Integer}
454413
M, N = build_sensitivity_matrices(
455414
model,
@@ -470,13 +429,10 @@ function compute_derivatives_no_relax(
470429
num_vars = get_num_primal_vars(model)
471430
num_cons = get_num_constraints(model)
472431
num_ineq = length(ineq_locations)
473-
K = inertia_corrector_factorization(
432+
K = model.input_cache.factorization(
474433
M,
475434
num_vars + num_ineq,
476-
num_cons;
477-
st = st,
478-
max_corrections = max_corrections,
479-
allow_inertia_correction = allow_inertia_correction,
435+
num_cons
480436
) # Factorization
481437
if isnothing(K)
482438
return zeros(size(M, 1), size(N, 2)), K, N
@@ -499,9 +455,6 @@ Compute the sensitivity of the solution given sensitivity of the parameters (Δp
499455
function compute_sensitivity(
500456
model::Model;
501457
tol = 1e-6,
502-
st = 1e-6,
503-
max_corrections = 50,
504-
allow_inertia_correction = true,
505458
)
506459
# Solution and bounds
507460
X,
@@ -529,10 +482,7 @@ function compute_sensitivity(
529482
geq_locations,
530483
ineq_locations,
531484
has_up,
532-
has_low;
533-
st = st,
534-
max_corrections = max_corrections,
535-
allow_inertia_correction = allow_inertia_correction,
485+
has_low
536486
)
537487
## Adjust signs based on JuMP convention
538488
num_vars = get_num_primal_vars(model)

src/diff_opt.jl

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,50 @@
1111

1212
const MOIDD = MOI.Utilities.DoubleDicts
1313

14+
"""
15+
LuFactorizationWithInertiaCorrection{T<:Real}
16+
17+
A callable struct to store the parameters for the inertia correction in the
18+
Lu-factorization. If no inertia correction is needed, it only performs the LU
19+
factorization.
20+
"""
21+
struct LuFactorizationWithInertiaCorrection{T<:Real} <: Function
22+
st::T
23+
max_corrections::Int
24+
end
25+
function LuFactorizationWithInertiaCorrection(; st::T = 1e-6, max_corrections::Int = 50) where T
26+
return LuFactorizationWithInertiaCorrection{T}(st, max_corrections)
27+
end
28+
29+
function (lu_struct::LuFactorizationWithInertiaCorrection)(
30+
M::SparseArrays.SparseMatrixCSC,
31+
num_w,
32+
num_cons
33+
)
34+
# Factorization
35+
K = SparseArrays.lu(M; check = false)
36+
# Inertia correction
37+
status = K.status
38+
if status == 1
39+
@info "Inertia correction needed"
40+
num_c = 0
41+
diag_mat = ones(size(M, 1))
42+
diag_mat[num_w+1:num_w+num_cons] .= -1
43+
diag_mat = SparseArrays.spdiagm(diag_mat)
44+
while status == 1 && num_c < lu_struct.max_corrections
45+
M = M + lu_struct.st * diag_mat
46+
K = lu(M; check = false)
47+
status = K.status
48+
num_c += 1
49+
end
50+
if status != 0
51+
@warn "Inertia correction failed"
52+
return nothing
53+
end
54+
end
55+
return K
56+
end
57+
1458
Base.@kwdef mutable struct InputCache
1559
dx::Dict{MOI.VariableIndex,Float64} = Dict{MOI.VariableIndex,Float64}()# dz for QP
1660
dp::Dict{MOI.ConstraintIndex,Float64} = Dict{MOI.ConstraintIndex,Float64}() # Specifically for NonLinearProgram
@@ -28,6 +72,7 @@ Base.@kwdef mutable struct InputCache
2872
vector_constraints::MOIDD.DoubleDict{MOI.VectorAffineFunction{Float64}} =
2973
MOIDD.DoubleDict{MOI.VectorAffineFunction{Float64}}() # also includes G for QPs
3074
objective::Union{Nothing,MOI.AbstractScalarFunction} = nothing
75+
factorization::Function = LuFactorizationWithInertiaCorrection()
3176
end
3277

3378
function Base.empty!(cache::InputCache)
@@ -37,6 +82,7 @@ function Base.empty!(cache::InputCache)
3782
empty!(cache.scalar_constraints)
3883
empty!(cache.vector_constraints)
3984
cache.objective = nothing
85+
cache.factorization = LuFactorizationWithInertiaCorrection()
4086
return
4187
end
4288

@@ -92,6 +138,29 @@ where `x` and `y` are the relevant `MOI.VariableIndex`.
92138
"""
93139
struct ForwardObjectiveFunction <: MOI.AbstractModelAttribute end
94140

141+
"""
142+
MFactorization <: MOI.AbstractModelAttribute
143+
144+
A `MOI.AbstractModelAttribute` to set which factorization function to use for the
145+
implict function diferentiation needed to compute the sensitivities for
146+
`NonLinearProgram` models.
147+
148+
The function will be called with the following signature:
149+
```julia
150+
function factorization(M::SparseMatrixCSC{T<Real}, # The matrix to factorize
151+
num_w::Int, # Number of primal and slack variables (can be ignored - useful for inertia correction)
152+
num_cons::Int, # The number of constraints (can be ignored - useful for inertia correction)
153+
)
154+
```
155+
156+
Can be set by the user to use a custom factorization function:
157+
158+
```julia
159+
MOI.set(model, DiffOpt.MFactorization(), factorization)
160+
```
161+
"""
162+
struct MFactorization <: MOI.AbstractModelAttribute end
163+
95164
"""
96165
ForwardConstraintFunction <: MOI.AbstractConstraintAttribute
97166
@@ -346,6 +415,15 @@ function MOI.set(model::AbstractModel, ::ForwardObjectiveFunction, objective)
346415
return
347416
end
348417

418+
function MOI.set(
419+
model::AbstractModel,
420+
::MFactorization,
421+
factorization::Function,
422+
)
423+
model.input_cache.factorization = factorization
424+
return
425+
end
426+
349427
function MOI.set(
350428
model::AbstractModel,
351429
::ReverseVariablePrimal,

src/jump_moi_overloads.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,14 @@ function MOI.set(
2121
return MOI.set(model, attr, JuMP.moi_function(func))
2222
end
2323

24+
function MOI.set(
25+
model::JuMP.Model,
26+
attr::MFactorization,
27+
factorization::Function,
28+
)
29+
return MOI.set(JuMP.backend(model), attr, factorization)
30+
end
31+
2432
function MOI.set(
2533
model::JuMP.Model,
2634
attr::ForwardObjectiveFunction,

src/moi_wrapper.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -512,21 +512,22 @@ function MOI.set(model::Optimizer, ::ModelConstructor, model_constructor)
512512
return
513513
end
514514

515-
function reverse_differentiate!(model::Optimizer; kwargs...)
515+
function reverse_differentiate!(model::Optimizer)
516516
st = MOI.get(model.optimizer, MOI.TerminationStatus())
517517
if !in(st, (MOI.LOCALLY_SOLVED, MOI.OPTIMAL))
518518
error(
519519
"Trying to compute the reverse differentiation on a model with termination status $(st)",
520520
)
521521
end
522522
diff = _diff(model)
523+
MOI.set(diff, MFactorization(), model.input_cache.factorization)
523524
for (vi, value) in model.input_cache.dx
524525
MOI.set(diff, ReverseVariablePrimal(), model.index_map[vi], value)
525526
end
526527
for (vi, value) in model.input_cache.dy
527528
MOI.set(diff, ReverseConstraintDual(), model.index_map[vi], value)
528529
end
529-
return reverse_differentiate!(diff; kwargs...)
530+
return reverse_differentiate!(diff)
530531
end
531532

532533
function _copy_forward_in_constraint(diff, index_map, con_map, constraints)
@@ -541,14 +542,15 @@ function _copy_forward_in_constraint(diff, index_map, con_map, constraints)
541542
return
542543
end
543544

544-
function forward_differentiate!(model::Optimizer; kwargs...)
545+
function forward_differentiate!(model::Optimizer)
545546
st = MOI.get(model.optimizer, MOI.TerminationStatus())
546547
if !in(st, (MOI.LOCALLY_SOLVED, MOI.OPTIMAL))
547548
error(
548549
"Trying to compute the forward differentiation on a model with termination status $(st)",
549550
)
550551
end
551552
diff = _diff(model)
553+
MOI.set(diff, MFactorization(), model.input_cache.factorization)
552554
if model.input_cache.objective !== nothing
553555
MOI.set(
554556
diff,
@@ -580,7 +582,7 @@ function forward_differentiate!(model::Optimizer; kwargs...)
580582
diff.model.input_cache.dp[model.index_map[vi]] = value
581583
end
582584
end
583-
return forward_differentiate!(diff; kwargs...)
585+
return forward_differentiate!(diff)
584586
end
585587

586588
function empty_input_sensitivities!(model::Optimizer)
@@ -673,6 +675,8 @@ end
673675

674676
MOI.supports(::Optimizer, ::ForwardObjectiveFunction) = true
675677

678+
MOI.supports(::Optimizer, ::MFactorization) = true
679+
676680
function MOI.get(model::Optimizer, ::ForwardObjectiveFunction)
677681
return model.input_cache.objective
678682
end

0 commit comments

Comments
 (0)