Skip to content

Commit cd7248f

Browse files
authored
[Nonlinear.ReverseAD] disable linearity detection for subexpression if no hessian (#2738)
1 parent 2bd236f commit cd7248f

File tree

3 files changed

+22
-2
lines changed

3 files changed

+22
-2
lines changed

src/Nonlinear/ReverseAD/mathoptinterface_api.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ function MOI.initialize(d::NLPEvaluator, requested_features::Vector{Symbol})
6969
d.data.expressions[k],
7070
d.subexpression_linearity,
7171
moi_index_to_consecutive_index,
72+
d.want_hess,
7273
)
7374
d.subexpressions[k] = subex
7475
d.subexpression_linearity[k] = subex.linearity

src/Nonlinear/ReverseAD/types.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,17 @@ struct _SubexpressionStorage
1818
expr::Nonlinear.Expression,
1919
subexpression_linearity,
2020
moi_index_to_consecutive_index,
21+
want_hess::Bool,
2122
)
2223
nodes =
2324
_replace_moi_variables(expr.nodes, moi_index_to_consecutive_index)
2425
adj = Nonlinear.adjacency_matrix(nodes)
2526
N = length(nodes)
26-
linearity = _classify_linearity(nodes, adj, subexpression_linearity)
27+
linearity = if want_hess
28+
_classify_linearity(nodes, adj, subexpression_linearity)[1]
29+
else
30+
NONLINEAR
31+
end
2732
return new(
2833
nodes,
2934
adj,
@@ -32,7 +37,7 @@ struct _SubexpressionStorage
3237
zeros(N), # partials_storage,
3338
zeros(N), # reverse_storage,
3439
Float64[],
35-
linearity[1],
40+
linearity,
3641
)
3742
end
3843
end

test/Nonlinear/ReverseAD.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -630,6 +630,20 @@ function test_linearity()
630630
return
631631
end
632632

633+
function test_linearity_no_hess()
634+
x = MOI.VariableIndex(1)
635+
model = Nonlinear.Model()
636+
ex = Nonlinear.add_expression(model, :($x + 1))
637+
Nonlinear.set_objective(model, ex)
638+
evaluator = Nonlinear.Evaluator(model, Nonlinear.SparseReverseMode(), [x])
639+
MOI.initialize(evaluator, [:Grad, :Jac])
640+
# We initialized without the need for the hessian so
641+
# the linearity shouldn't be computed.
642+
@test only(evaluator.backend.subexpressions).linearity ==
643+
ReverseAD.NONLINEAR
644+
return
645+
end
646+
633647
function test_dual_forward()
634648
x = MOI.VariableIndex(1)
635649
y = MOI.VariableIndex(2)

0 commit comments

Comments
 (0)