Skip to content

Commit bd0bf71

Browse files
authored
[Nonlinear.ReverseAD] simplify arguments of _forward_eval_ϵ (#2736)
1 parent be6eaba commit bd0bf71

File tree

2 files changed

+13
-46
lines changed

2 files changed

+13
-46
lines changed

src/Nonlinear/ReverseAD/forward_over_reverse.jl

Lines changed: 12 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,6 @@ end
116116

117117
function _hessian_slice_inner(d, ex, ::Val{CHUNK}) where {CHUNK}
118118
T = ForwardDiff.Partials{CHUNK,Float64} # This is our element type.
119-
input_ϵ = _reinterpret_unsafe(T, d.input_ϵ)
120119
fill!(d.output_ϵ, 0.0)
121120
output_ϵ = _reinterpret_unsafe(T, d.output_ϵ)
122121
subexpr_forward_values_ϵ =
@@ -126,22 +125,10 @@ function _hessian_slice_inner(d, ex, ::Val{CHUNK}) where {CHUNK}
126125
subexpr_forward_values_ϵ[i] = _forward_eval_ϵ(
127126
d,
128127
subexpr,
129-
_reinterpret_unsafe(T, d.storage_ϵ),
130128
_reinterpret_unsafe(T, subexpr.partials_storage_ϵ),
131-
input_ϵ,
132-
subexpr_forward_values_ϵ,
133-
d.data.operators,
134129
)
135130
end
136-
_forward_eval_ϵ(
137-
d,
138-
ex,
139-
_reinterpret_unsafe(T, d.storage_ϵ),
140-
_reinterpret_unsafe(T, d.partials_storage_ϵ),
141-
input_ϵ,
142-
subexpr_forward_values_ϵ,
143-
d.data.operators,
144-
)
131+
_forward_eval_ϵ(d, ex, _reinterpret_unsafe(T, d.partials_storage_ϵ))
145132
# do a reverse pass
146133
subexpr_reverse_values_ϵ =
147134
_reinterpret_unsafe(T, d.subexpression_reverse_values_ϵ)
@@ -180,11 +167,7 @@ end
180167
_forward_eval_ϵ(
181168
d::NLPEvaluator,
182169
ex::Union{_FunctionStorage,_SubexpressionStorage},
183-
storage_ϵ::AbstractVector{ForwardDiff.Partials{N,T}},
184170
partials_storage_ϵ::AbstractVector{ForwardDiff.Partials{N,T}},
185-
x_values_ϵ,
186-
subexpression_values_ϵ,
187-
user_operators::Nonlinear.OperatorRegistry,
188171
) where {N,T}
189172
190173
Evaluate the directional derivatives of the expression tree in `ex`.
@@ -198,15 +181,15 @@ This assumes that `_reverse_model(d, x)` has already been called.
198181
function _forward_eval_ϵ(
199182
d::NLPEvaluator,
200183
ex::Union{_FunctionStorage,_SubexpressionStorage},
201-
storage_ϵ::AbstractVector{ForwardDiff.Partials{N,T}},
202-
partials_storage_ϵ::AbstractVector{ForwardDiff.Partials{N,T}},
203-
x_values_ϵ,
204-
subexpression_values_ϵ,
205-
user_operators::Nonlinear.OperatorRegistry,
206-
) where {N,T}
184+
partials_storage_ϵ::AbstractVector{P},
185+
) where {N,T,P<:ForwardDiff.Partials{N,T}}
186+
storage_ϵ = _reinterpret_unsafe(P, d.storage_ϵ)
187+
x_values_ϵ = reinterpret(P, d.input_ϵ)
188+
subexpression_values_ϵ =
189+
_reinterpret_unsafe(P, d.subexpression_forward_values_ϵ)
207190
@assert length(storage_ϵ) >= length(ex.nodes)
208191
@assert length(partials_storage_ϵ) >= length(ex.nodes)
209-
zero_ϵ = zero(ForwardDiff.Partials{N,T})
192+
zero_ϵ = zero(P)
210193
# ex.nodes is already in order such that parents always appear before children
211194
# so a backwards pass through ex.nodes is a forward pass through the tree
212195
children_arr = SparseArrays.rowvals(ex.adj)
@@ -339,16 +322,16 @@ function _forward_eval_ϵ(
339322
n_children,
340323
)
341324
has_hessian = Nonlinear.eval_multivariate_hessian(
342-
user_operators,
343-
user_operators.multivariate_operators[node.index],
325+
d.data.operators,
326+
d.data.operators.multivariate_operators[node.index],
344327
H,
345328
f_input,
346329
)
347330
# This might be `false` if we extend this code to all
348331
# multivariate functions.
349332
@assert has_hessian
350333
for col in 1:n_children
351-
dual = zero(ForwardDiff.Partials{N,T})
334+
dual = zero(P)
352335
for row in 1:n_children
353336
# Make sure we get the lower-triangular component.
354337
h = row >= col ? H[row, col] : H[col, row]
@@ -366,7 +349,7 @@ function _forward_eval_ϵ(
366349
elseif node.type == Nonlinear.NODE_CALL_UNIVARIATE
367350
@inbounds child_idx = children_arr[ex.adj.colptr[k]]
368351
f′′ = Nonlinear.eval_univariate_hessian(
369-
user_operators,
352+
d.data.operators,
370353
node.index,
371354
ex.forward_storage[child_idx],
372355
)

src/Nonlinear/ReverseAD/mathoptinterface_api.jl

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -349,11 +349,7 @@ function MOI.eval_hessian_lagrangian_product(d::NLPEvaluator, h, x, v, σ, μ)
349349
subexpr_forward_values_ϵ[i] = _forward_eval_ϵ(
350350
d,
351351
subexpr,
352-
reinterpret(T, d.storage_ϵ),
353352
reinterpret(T, subexpr.partials_storage_ϵ),
354-
input_ϵ,
355-
subexpr_forward_values_ϵ,
356-
d.data.operators,
357353
)
358354
end
359355
# we only need to do one reverse pass through the subexpressions as well
@@ -366,11 +362,7 @@ function MOI.eval_hessian_lagrangian_product(d::NLPEvaluator, h, x, v, σ, μ)
366362
_forward_eval_ϵ(
367363
d,
368364
something(d.objective),
369-
reinterpret(T, d.storage_ϵ),
370365
reinterpret(T, d.partials_storage_ϵ),
371-
input_ϵ,
372-
subexpr_forward_values_ϵ,
373-
d.data.operators,
374366
)
375367
_reverse_eval_ϵ(
376368
output_ϵ,
@@ -384,15 +376,7 @@ function MOI.eval_hessian_lagrangian_product(d::NLPEvaluator, h, x, v, σ, μ)
384376
)
385377
end
386378
for (i, con) in enumerate(d.constraints)
387-
_forward_eval_ϵ(
388-
d,
389-
con,
390-
reinterpret(T, d.storage_ϵ),
391-
reinterpret(T, d.partials_storage_ϵ),
392-
input_ϵ,
393-
subexpr_forward_values_ϵ,
394-
d.data.operators,
395-
)
379+
_forward_eval_ϵ(d, con, reinterpret(T, d.partials_storage_ϵ))
396380
_reverse_eval_ϵ(
397381
output_ϵ,
398382
con,

0 commit comments

Comments
 (0)