Skip to content

Commit 1409152

Browse files
blegat and odow authored
[Nonlinear.ReverseAD] remove dynamic dispatch in Hessian evaluation (#2740)
Co-authored-by: odow <o.dowson@gmail.com>
1 parent bd0bf71 commit 1409152

File tree

3 files changed

+55
-44
lines changed

3 files changed

+55
-44
lines changed

src/Nonlinear/ReverseAD/forward_over_reverse.jl

Lines changed: 38 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,15 @@
66

77
const TAG = :ReverseAD
88

9+
"""
10+
const MAX_CHUNK::Int = 10
11+
12+
An upper bound on the chunk size for forward-over-reverse. Increasing this could
13+
improve performance at the cost of extra memory allocation. It has been 10 for a
14+
long time, and nobody seems to have complained.
15+
"""
16+
const MAX_CHUNK = 10
17+
918
"""
1019
_eval_hessian(
1120
d::NLPEvaluator,
@@ -23,46 +32,30 @@ Returns the number of non-zeros in the computed Hessian, which will be used to
2332
update the offset for the next call.
2433
"""
2534
function _eval_hessian(
26-
d::NLPEvaluator,
27-
f::_FunctionStorage,
28-
H::AbstractVector{Float64},
29-
λ::Float64,
30-
offset::Int,
31-
)::Int
32-
chunk = min(size(f.seed_matrix, 2), d.max_chunk)
33-
# As a performance optimization, skip dynamic dispatch if the chunk is 1.
34-
if chunk == 1
35-
return _eval_hessian_inner(d, f, H, λ, offset, Val(1))
36-
else
37-
return _eval_hessian_inner(d, f, H, λ, offset, Val(chunk))
38-
end
39-
end
40-
41-
function _eval_hessian_inner(
4235
d::NLPEvaluator,
4336
ex::_FunctionStorage,
4437
H::AbstractVector{Float64},
4538
scale::Float64,
4639
nzcount::Int,
47-
::Val{CHUNK},
48-
) where {CHUNK}
40+
)::Int
4941
if ex.linearity == LINEAR
5042
@assert length(ex.hess_I) == 0
5143
return 0
5244
end
45+
chunk = min(size(ex.seed_matrix, 2), d.max_chunk)
5346
Coloring.prepare_seed_matrix!(ex.seed_matrix, ex.rinfo)
5447
# Compute hessian-vector products
5548
num_products = size(ex.seed_matrix, 2) # number of hessian-vector products
56-
num_chunks = div(num_products, CHUNK)
49+
num_chunks = div(num_products, chunk)
5750
@assert size(ex.seed_matrix, 1) == length(ex.rinfo.local_indices)
58-
for offset in 1:CHUNK:(CHUNK*num_chunks)
59-
_eval_hessian_chunk(d, ex, offset, CHUNK, Val(CHUNK))
51+
for offset in 1:chunk:(chunk*num_chunks)
52+
_eval_hessian_chunk(d, ex, offset, chunk, chunk)
6053
end
6154
# leftover chunk
62-
remaining = num_products - CHUNK * num_chunks
55+
remaining = num_products - chunk * num_chunks
6356
if remaining > 0
64-
offset = CHUNK * num_chunks + 1
65-
_eval_hessian_chunk(d, ex, offset, remaining, Val(CHUNK))
57+
offset = chunk * num_chunks + 1
58+
_eval_hessian_chunk(d, ex, offset, remaining, chunk)
6659
end
6760
want, got = nzcount + length(ex.hess_I), length(H)
6861
if want > got
@@ -90,32 +83,45 @@ function _eval_hessian_chunk(
9083
ex::_FunctionStorage,
9184
offset::Int,
9285
chunk::Int,
93-
::Val{CHUNK},
94-
) where {CHUNK}
86+
chunk_size::Int,
87+
)
9588
for r in eachindex(ex.rinfo.local_indices)
9689
# set up directional derivatives
9790
@inbounds idx = ex.rinfo.local_indices[r]
9891
# load up ex.seed_matrix[r,k,k+1,...,k+remaining-1] into input_ϵ
9992
for s in 1:chunk
100-
# If `chunk < CHUNK`, leaves junk in the unused components
101-
d.input_ϵ[(idx-1)*CHUNK+s] = ex.seed_matrix[r, offset+s-1]
93+
# If `chunk < chunk_size`, leaves junk in the unused components
94+
d.input_ϵ[(idx-1)*chunk_size+s] = ex.seed_matrix[r, offset+s-1]
10295
end
10396
end
104-
_hessian_slice_inner(d, ex, Val(CHUNK))
97+
_hessian_slice_inner(d, ex, chunk_size)
10598
fill!(d.input_ϵ, 0.0)
10699
# collect directional derivatives
107100
for r in eachindex(ex.rinfo.local_indices)
108101
@inbounds idx = ex.rinfo.local_indices[r]
109102
# load output_ϵ into ex.seed_matrix[r,k,k+1,...,k+remaining-1]
110103
for s in 1:chunk
111-
ex.seed_matrix[r, offset+s-1] = d.output_ϵ[(idx-1)*CHUNK+s]
104+
ex.seed_matrix[r, offset+s-1] = d.output_ϵ[(idx-1)*chunk_size+s]
112105
end
113106
end
114107
return
115108
end
116109

117-
function _hessian_slice_inner(d, ex, ::Val{CHUNK}) where {CHUNK}
118-
T = ForwardDiff.Partials{CHUNK,Float64} # This is our element type.
110+
# A wrapper function to avoid dynamic dispatch.
111+
function _generate_hessian_slice_inner()
112+
exprs = map(1:MAX_CHUNK) do id
113+
T = ForwardDiff.Partials{id,Float64}
114+
return :(return _hessian_slice_inner(d, ex, $T))
115+
end
116+
return MOI.Nonlinear._create_binary_switch(1:MAX_CHUNK, exprs)
117+
end
118+
119+
@eval function _hessian_slice_inner(d, ex, id::Int)
120+
$(_generate_hessian_slice_inner())
121+
return error("Invalid chunk size: $id")
122+
end
123+
124+
function _hessian_slice_inner(d, ex, ::Type{T}) where {T}
119125
fill!(d.output_ϵ, 0.0)
120126
output_ϵ = _reinterpret_unsafe(T, d.output_ϵ)
121127
subexpr_forward_values_ϵ =

src/Nonlinear/ReverseAD/mathoptinterface_api.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,8 +140,7 @@ function MOI.initialize(d::NLPEvaluator, requested_features::Vector{Symbol})
140140
max_expr_length = max(max_expr_length, length(d.constraints[end].nodes))
141141
max_chunk = max(max_chunk, size(d.constraints[end].seed_matrix, 2))
142142
end
143-
# 10 is hardcoded upper bound to avoid excess memory allocation
144-
max_chunk = min(max_chunk, 10)
143+
max_chunk = min(max_chunk, MAX_CHUNK)
145144
max_expr_with_sub_length = max(max_expr_with_sub_length, max_expr_length)
146145
if d.want_hess || want_hess_storage
147146
d.input_ϵ = zeros(max_chunk * N)

test/Nonlinear/ReverseAD.jl

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -150,21 +150,13 @@ function test_objective_quadratic_multivariate_subexpressions()
150150
MOI.eval_hessian_objective(evaluator, H, val)
151151
@test H == [2.0, 2.0, 1.0]
152152
@test evaluator.backend.max_chunk == 2
153-
# The call of `_eval_hessian_inner` from `_eval_hessian` needs dynamic dispatch for `Val(chunk)` so it allocates.
154-
# We call directly `_eval_hessian_inner` to check that the rest does not allocates.
155-
@test 0 == @allocated MOI.Nonlinear.ReverseAD._eval_hessian_inner(
156-
evaluator.backend,
157-
evaluator.backend.objective,
158-
H,
159-
1.0,
160-
0,
161-
Val(2),
162-
)
153+
@test 0 == @allocated MOI.eval_hessian_objective(evaluator, H, val)
163154
@test MOI.hessian_lagrangian_structure(evaluator) ==
164155
[(1, 1), (2, 2), (2, 1)]
165156
H = [NaN, NaN, NaN]
166157
μ = Float64[]
167158
MOI.eval_hessian_lagrangian(evaluator, H, val, 1.5, μ)
159+
@test 0 == @allocated MOI.eval_hessian_lagrangian(evaluator, H, val, 1.5, μ)
168160
@test H == 1.5 .* [2.0, 2.0, 1.0]
169161
v = [0.3, 0.4]
170162
hv = [NaN, NaN]
@@ -1393,6 +1385,20 @@ function test_eval_user_defined_operator_type_mismatch()
13931385
return
13941386
end
13951387

1388+
function test_generate_hessian_slice_inner()
1389+
# Test that it evaluates without error. The code contents are tested
1390+
# elsewhere.
1391+
MOI.Nonlinear.ReverseAD._generate_hessian_slice_inner()
1392+
d = ex = nothing # These arguments are untyped and not needed for this test
1393+
for id in [0, MOI.Nonlinear.ReverseAD.MAX_CHUNK + 1]
1394+
@test_throws(
1395+
ErrorException("Invalid chunk size: $id"),
1396+
MOI.Nonlinear.ReverseAD._hessian_slice_inner(d, ex, id),
1397+
)
1398+
end
1399+
return
1400+
end
1401+
13961402
end # module
13971403

13981404
TestReverseAD.runtests()

0 commit comments

Comments
 (0)