diff --git a/src/Nonlinear/ReverseAD/mathoptinterface_api.jl b/src/Nonlinear/ReverseAD/mathoptinterface_api.jl index 787eb62673..47ee17393d 100644 --- a/src/Nonlinear/ReverseAD/mathoptinterface_api.jl +++ b/src/Nonlinear/ReverseAD/mathoptinterface_api.jl @@ -69,6 +69,7 @@ function MOI.initialize(d::NLPEvaluator, requested_features::Vector{Symbol}) d.data.expressions[k], d.subexpression_linearity, moi_index_to_consecutive_index, + d.want_hess, ) d.subexpressions[k] = subex d.subexpression_linearity[k] = subex.linearity diff --git a/src/Nonlinear/ReverseAD/types.jl b/src/Nonlinear/ReverseAD/types.jl index 6526a2d461..fc599accb0 100644 --- a/src/Nonlinear/ReverseAD/types.jl +++ b/src/Nonlinear/ReverseAD/types.jl @@ -18,12 +18,17 @@ struct _SubexpressionStorage expr::Nonlinear.Expression, subexpression_linearity, moi_index_to_consecutive_index, + want_hess::Bool, ) nodes = _replace_moi_variables(expr.nodes, moi_index_to_consecutive_index) adj = Nonlinear.adjacency_matrix(nodes) N = length(nodes) - linearity = _classify_linearity(nodes, adj, subexpression_linearity) + linearity = if want_hess + _classify_linearity(nodes, adj, subexpression_linearity)[1] + else + NONLINEAR + end return new( nodes, adj, @@ -32,7 +37,7 @@ struct _SubexpressionStorage zeros(N), # partials_storage, zeros(N), # reverse_storage, Float64[], - linearity[1], + linearity, ) end end diff --git a/test/Nonlinear/ReverseAD.jl b/test/Nonlinear/ReverseAD.jl index 43fbb28b4f..75a11aae66 100644 --- a/test/Nonlinear/ReverseAD.jl +++ b/test/Nonlinear/ReverseAD.jl @@ -630,6 +630,20 @@ function test_linearity() return end +function test_linearity_no_hess() + x = MOI.VariableIndex(1) + model = Nonlinear.Model() + ex = Nonlinear.add_expression(model, :($x + 1)) + Nonlinear.set_objective(model, ex) + evaluator = Nonlinear.Evaluator(model, Nonlinear.SparseReverseMode(), [x]) + MOI.initialize(evaluator, [:Grad, :Jac]) + # We initialized without the need for the hessian so + # the linearity shouldn't be computed. + @test only(evaluator.backend.subexpressions).linearity == + ReverseAD.NONLINEAR + return +end + function test_dual_forward() x = MOI.VariableIndex(1) y = MOI.VariableIndex(2)