Skip to content

Commit 1c0630e

Browse files
authored
Use _SubexpressionStorage inside _FunctionStorage (#3)
1 parent 7ff519d commit 1c0630e

File tree

4 files changed

+105
-116
lines changed

4 files changed

+105
-116
lines changed

src/forward_over_reverse.jl

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ function _eval_hessian(
3838
scale::Float64,
3939
nzcount::Int,
4040
)::Int
41-
if ex.linearity == LINEAR
41+
if ex.expr.linearity == LINEAR
4242
@assert length(ex.hess_I) == 0
4343
return 0
4444
end
@@ -128,13 +128,9 @@ function _hessian_slice_inner(d, ex, ::Type{T}) where {T}
128128
_reinterpret_unsafe(T, d.subexpression_forward_values_ϵ)
129129
for i in ex.dependent_subexpressions
130130
subexpr = d.subexpressions[i]
131-
subexpr_forward_values_ϵ[i] = _forward_eval_ϵ(
132-
d,
133-
subexpr,
134-
_reinterpret_unsafe(T, subexpr.partials_storage_ϵ),
135-
)
131+
subexpr_forward_values_ϵ[i] = _forward_eval_ϵ(d, subexpr, T)
136132
end
137-
_forward_eval_ϵ(d, ex, _reinterpret_unsafe(T, d.partials_storage_ϵ))
133+
_forward_eval_ϵ(d, ex.expr, T)
138134
# do a reverse pass
139135
subexpr_reverse_values_ϵ =
140136
_reinterpret_unsafe(T, d.subexpression_reverse_values_ϵ)
@@ -144,9 +140,8 @@ function _hessian_slice_inner(d, ex, ::Type{T}) where {T}
144140
end
145141
_reverse_eval_ϵ(
146142
output_ϵ,
147-
ex,
143+
ex.expr,
148144
_reinterpret_unsafe(T, d.storage_ϵ),
149-
_reinterpret_unsafe(T, d.partials_storage_ϵ),
150145
d.subexpression_reverse_values,
151146
subexpr_reverse_values_ϵ,
152147
1.0,
@@ -159,7 +154,6 @@ function _hessian_slice_inner(d, ex, ::Type{T}) where {T}
159154
output_ϵ,
160155
subexpr,
161156
_reinterpret_unsafe(T, d.storage_ϵ),
162-
_reinterpret_unsafe(T, subexpr.partials_storage_ϵ),
163157
d.subexpression_reverse_values,
164158
subexpr_reverse_values_ϵ,
165159
d.subexpression_reverse_values[j],
@@ -173,8 +167,8 @@ end
173167
_forward_eval_ϵ(
174168
d::NLPEvaluator,
175169
ex::Union{_FunctionStorage,_SubexpressionStorage},
176-
partials_storage_ϵ::AbstractVector{ForwardDiff.Partials{N,T}},
177-
) where {N,T}
170+
::Type{P},
171+
) where {N,T,P<:ForwardDiff.Partials{N,T}}
178172
179173
Evaluate the directional derivatives of the expression tree in `ex`.
180174
@@ -186,10 +180,11 @@ This assumes that `_reverse_model(d, x)` has already been called.
186180
"""
187181
function _forward_eval_ϵ(
188182
d::NLPEvaluator,
189-
ex::Union{_FunctionStorage,_SubexpressionStorage},
190-
partials_storage_ϵ::AbstractVector{P},
183+
ex::_SubexpressionStorage,
184+
::Type{P},
191185
) where {N,T,P<:ForwardDiff.Partials{N,T}}
192186
storage_ϵ = _reinterpret_unsafe(P, d.storage_ϵ)
187+
partials_storage_ϵ = _reinterpret_unsafe(P, ex.partials_storage_ϵ)
193188
x_values_ϵ = _reinterpret_unsafe(P, d.input_ϵ)
194189
subexpression_values_ϵ =
195190
_reinterpret_unsafe(P, d.subexpression_forward_values_ϵ)
@@ -370,14 +365,15 @@ end
370365
# to compute hessian-vector products.
371366
function _reverse_eval_ϵ(
372367
output_ϵ::AbstractVector{ForwardDiff.Partials{N,T}},
373-
ex::Union{_FunctionStorage,_SubexpressionStorage},
368+
ex::_SubexpressionStorage,
374369
reverse_storage_ϵ,
375-
partials_storage_ϵ,
376370
subexpression_output,
377371
subexpression_output_ϵ,
378372
scale::T,
379373
scale_ϵ::ForwardDiff.Partials{N,T},
380374
) where {N,T}
375+
partials_storage_ϵ =
376+
_reinterpret_unsafe(ForwardDiff.Partials{N,T}, ex.partials_storage_ϵ)
381377
@assert length(reverse_storage_ϵ) >= length(ex.nodes)
382378
@assert length(partials_storage_ϵ) >= length(ex.nodes)
383379
if ex.nodes[1].type == Nonlinear.NODE_VARIABLE

src/mathoptinterface_api.jl

Lines changed: 36 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,12 @@ function MOI.initialize(d::NLPEvaluator, requested_features::Vector{Symbol})
6565
for k in d.subexpression_order
6666
# Only load expressions which actually are used
6767
d.subexpression_forward_values[k] = NaN
68-
subex = _SubexpressionStorage(
69-
d.data.expressions[k],
70-
d.subexpression_linearity,
68+
expr = d.data.expressions[k]
69+
subex, _ = _subexpression_and_linearity(
70+
expr,
7171
moi_index_to_consecutive_index,
72-
d.want_hess,
72+
Float64[],
73+
d,
7374
)
7475
d.subexpressions[k] = subex
7576
d.subexpression_linearity[k] = subex.linearity
@@ -101,43 +102,54 @@ function MOI.initialize(d::NLPEvaluator, requested_features::Vector{Symbol})
101102
end
102103
end
103104
max_chunk = 1
105+
shared_partials_storage_ϵ = Float64[]
104106
if d.data.objective !== nothing
107+
expr = something(d.data.objective)
108+
subexpr, linearity = _subexpression_and_linearity(
109+
expr,
110+
moi_index_to_consecutive_index,
111+
shared_partials_storage_ϵ,
112+
d,
113+
)
105114
objective = _FunctionStorage(
106-
main_expressions[1],
107-
something(d.data.objective).values,
115+
subexpr,
108116
N,
109117
coloring_storage,
110118
d.want_hess,
111119
d.subexpressions,
112120
individual_order[1],
113-
d.subexpression_linearity,
114121
subexpression_edgelist,
115122
subexpression_variables,
116-
moi_index_to_consecutive_index,
123+
linearity,
117124
)
118-
max_expr_length = max(max_expr_length, length(objective.nodes))
125+
max_expr_length = max(max_expr_length, length(expr.nodes))
119126
max_chunk = max(max_chunk, size(objective.seed_matrix, 2))
120127
d.objective = objective
121128
end
122129
for (k, (_, constraint)) in enumerate(d.data.constraints)
123130
idx = d.data.objective !== nothing ? k + 1 : k
131+
expr = constraint.expression
132+
subexpr, linearity = _subexpression_and_linearity(
133+
expr,
134+
moi_index_to_consecutive_index,
135+
shared_partials_storage_ϵ,
136+
d,
137+
)
124138
push!(
125139
d.constraints,
126140
_FunctionStorage(
127-
main_expressions[idx],
128-
constraint.expression.values,
141+
subexpr,
129142
N,
130143
coloring_storage,
131144
d.want_hess,
132145
d.subexpressions,
133146
individual_order[idx],
134-
d.subexpression_linearity,
135147
subexpression_edgelist,
136148
subexpression_variables,
137-
moi_index_to_consecutive_index,
149+
linearity,
138150
),
139151
)
140-
max_expr_length = max(max_expr_length, length(d.constraints[end].nodes))
152+
max_expr_length = max(max_expr_length, length(expr.nodes))
141153
max_chunk = max(max_chunk, size(d.constraints[end].seed_matrix, 2))
142154
end
143155
max_chunk = min(max_chunk, MAX_CHUNK)
@@ -146,7 +158,8 @@ function MOI.initialize(d::NLPEvaluator, requested_features::Vector{Symbol})
146158
d.input_ϵ = zeros(max_chunk * N)
147159
d.output_ϵ = zeros(max_chunk * N)
148160
#
149-
d.partials_storage_ϵ = zeros(max_chunk * max_expr_length)
161+
resize!(shared_partials_storage_ϵ, max_chunk * max_expr_length)
162+
fill!(shared_partials_storage_ϵ, 0.0)
150163
d.storage_ϵ = zeros(max_chunk * max_expr_with_sub_length)
151164
#
152165
len = max_chunk * length(d.subexpressions)
@@ -178,7 +191,7 @@ function MOI.eval_objective(d::NLPEvaluator, x)
178191
error("No nonlinear objective.")
179192
end
180193
_reverse_mode(d, x)
181-
return something(d.objective).forward_storage[1]
194+
return something(d.objective).expr.forward_storage[1]
182195
end
183196

184197
function MOI.eval_objective_gradient(d::NLPEvaluator, g, x)
@@ -194,7 +207,7 @@ end
194207
function MOI.eval_constraint(d::NLPEvaluator, g, x)
195208
_reverse_mode(d, x)
196209
for i in 1:length(d.constraints)
197-
g[i] = d.constraints[i].forward_storage[1]
210+
g[i] = d.constraints[i].expr.forward_storage[1]
198211
end
199212
return
200213
end
@@ -345,11 +358,7 @@ function MOI.eval_hessian_lagrangian_product(d::NLPEvaluator, h, x, v, σ, μ)
345358
subexpr_forward_values_ϵ = reinterpret(T, d.subexpression_forward_values_ϵ)
346359
for i in d.subexpression_order
347360
subexpr = d.subexpressions[i]
348-
subexpr_forward_values_ϵ[i] = _forward_eval_ϵ(
349-
d,
350-
subexpr,
351-
reinterpret(T, subexpr.partials_storage_ϵ),
352-
)
361+
subexpr_forward_values_ϵ[i] = _forward_eval_ϵ(d, subexpr, T)
353362
end
354363
# we only need to do one reverse pass through the subexpressions as well
355364
subexpr_reverse_values_ϵ = reinterpret(T, d.subexpression_reverse_values_ϵ)
@@ -358,29 +367,23 @@ function MOI.eval_hessian_lagrangian_product(d::NLPEvaluator, h, x, v, σ, μ)
358367
fill!(d.storage_ϵ, 0.0)
359368
fill!(output_ϵ, zero(T))
360369
if d.objective !== nothing
361-
_forward_eval_ϵ(
362-
d,
363-
something(d.objective),
364-
reinterpret(T, d.partials_storage_ϵ),
365-
)
370+
_forward_eval_ϵ(d, something(d.objective).expr, T)
366371
_reverse_eval_ϵ(
367372
output_ϵ,
368-
something(d.objective),
369-
reinterpret(T, d.storage_ϵ),
370-
reinterpret(T, d.partials_storage_ϵ),
373+
something(d.objective).expr,
374+
_reinterpret_unsafe(T, d.storage_ϵ),
371375
d.subexpression_reverse_values,
372376
subexpr_reverse_values_ϵ,
373377
σ,
374378
zero(T),
375379
)
376380
end
377381
for (i, con) in enumerate(d.constraints)
378-
_forward_eval_ϵ(d, con, reinterpret(T, d.partials_storage_ϵ))
382+
_forward_eval_ϵ(d, con.expr, T)
379383
_reverse_eval_ϵ(
380384
output_ϵ,
381-
con,
385+
con.expr,
382386
reinterpret(T, d.storage_ϵ),
383-
reinterpret(T, d.partials_storage_ϵ),
384387
d.subexpression_reverse_values,
385388
subexpr_reverse_values_ϵ,
386389
μ[i],
@@ -394,7 +397,6 @@ function MOI.eval_hessian_lagrangian_product(d::NLPEvaluator, h, x, v, σ, μ)
394397
output_ϵ,
395398
subexpr,
396399
reinterpret(T, d.storage_ϵ),
397-
reinterpret(T, subexpr.partials_storage_ϵ),
398400
d.subexpression_reverse_values,
399401
subexpr_reverse_values_ϵ,
400402
d.subexpression_reverse_values[j],

src/reverse_mode.jl

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -39,20 +39,20 @@ function _reverse_mode(d::NLPEvaluator, x)
3939
_forward_eval(d.subexpressions[k], d, x)
4040
end
4141
if d.objective !== nothing
42-
_forward_eval(d.objective::_FunctionStorage, d, x)
42+
_forward_eval(something(d.objective).expr, d, x)
4343
end
4444
for con in d.constraints
45-
_forward_eval(con, d, x)
45+
_forward_eval(con.expr, d, x)
4646
end
4747
# Phase II
4848
for k in d.subexpression_order
4949
_reverse_eval(d.subexpressions[k])
5050
end
5151
if d.objective !== nothing
52-
_reverse_eval(d.objective::_FunctionStorage)
52+
_reverse_eval(something(d.objective).expr)
5353
end
5454
for con in d.constraints
55-
_reverse_eval(con)
55+
_reverse_eval(con.expr)
5656
end
5757
# If a JuMP model uses the legacy nonlinear interface, then JuMP constructs
5858
# a NLPEvaluator at the start of a call to `JuMP.optimize!` and it passes in
@@ -81,7 +81,7 @@ end
8181

8282
"""
8383
_forward_eval(
84-
f::Union{_FunctionStorage,_SubexpressionStorage},
84+
f::_SubexpressionStorage,
8585
d::NLPEvaluator,
8686
x::AbstractVector{T},
8787
) where {T}
@@ -98,10 +98,7 @@ Forward-mode evaluation of an expression tree given in `f`.
9898
associate storage with each edge of the DAG.
9999
"""
100100
function _forward_eval(
101-
# !!! warning
102-
# This Union depends upon _FunctionStorage and _SubexpressionStorage
103-
# having similarly named fields.
104-
f::Union{_FunctionStorage,_SubexpressionStorage},
101+
f::_SubexpressionStorage,
105102
d::NLPEvaluator,
106103
x::AbstractVector{T},
107104
)::T where {T}
@@ -290,19 +287,14 @@ function _forward_eval(
290287
end
291288

292289
"""
293-
_reverse_eval(f::Union{_FunctionStorage,_SubexpressionStorage})
290+
_reverse_eval(f::_SubexpressionStorage)
294291
295292
Reverse-mode evaluation of an expression tree given in `f`.
296293
297294
* This function assumes `f.partials_storage` is already updated.
298295
* This function assumes that `f.reverse_storage` has been initialized with 0.0.
299296
"""
300-
function _reverse_eval(
301-
# !!! warning
302-
# This Union depends upon _FunctionStorage and _SubexpressionStorage
303-
# having similarly named fields.
304-
f::Union{_FunctionStorage,_SubexpressionStorage},
305-
)
297+
function _reverse_eval(f::_SubexpressionStorage)
306298
@assert length(f.reverse_storage) >= length(f.nodes)
307299
@assert length(f.partials_storage) >= length(f.nodes)
308300
# f.nodes is already in order such that parents always appear before
@@ -361,9 +353,15 @@ end
361353

362354
function _extract_reverse_pass_inner(
363355
output::AbstractVector{T},
364-
# !!! warning
365-
# This Union depends upon _FunctionStorage and _SubexpressionStorage
366-
# having similarly named fields.
356+
f::_FunctionStorage,
357+
subexpressions::AbstractVector{T},
358+
scale::T,
359+
) where {T}
360+
return _extract_reverse_pass_inner(output, f.expr, subexpressions, scale)
361+
end
362+
363+
function _extract_reverse_pass_inner(
364+
output::AbstractVector{T},
367365
f::Union{_FunctionStorage,_SubexpressionStorage},
368366
subexpressions::AbstractVector{T},
369367
scale::T,

0 commit comments

Comments
 (0)