Skip to content

Commit dadc879

Browse files
authored
[Nonlinear.ReverseAD] simplify _eval_hessian_inner (#2730)
1 parent 355a039 commit dadc879

File tree

1 file changed

+39
-48
lines changed

1 file changed

+39
-48
lines changed

src/Nonlinear/ReverseAD/forward_over_reverse.jl

Lines changed: 39 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -50,61 +50,19 @@ function _eval_hessian_inner(
5050
@assert length(ex.hess_I) == 0
5151
return 0
5252
end
53-
T = ForwardDiff.Partials{CHUNK,Float64} # This is our element type.
5453
Coloring.prepare_seed_matrix!(ex.seed_matrix, ex.rinfo)
55-
local_to_global_idx = ex.rinfo.local_indices
56-
input_ϵ_raw, output_ϵ_raw = d.input_ϵ, d.output_ϵ
57-
input_ϵ = _reinterpret_unsafe(T, input_ϵ_raw)
58-
output_ϵ = _reinterpret_unsafe(T, output_ϵ_raw)
5954
# Compute hessian-vector products
6055
num_products = size(ex.seed_matrix, 2) # number of hessian-vector products
6156
num_chunks = div(num_products, CHUNK)
62-
@assert size(ex.seed_matrix, 1) == length(local_to_global_idx)
63-
for k in 1:CHUNK:(CHUNK*num_chunks)
64-
for r in 1:length(local_to_global_idx)
65-
# set up directional derivatives
66-
@inbounds idx = local_to_global_idx[r]
67-
# load up ex.seed_matrix[r,k,k+1,...,k+CHUNK-1] into input_ϵ
68-
for s in 1:CHUNK
69-
input_ϵ_raw[(idx-1)*CHUNK+s] = ex.seed_matrix[r, k+s-1]
70-
end
71-
@inbounds output_ϵ[idx] = zero(T)
72-
end
73-
_hessian_slice_inner(d, ex, input_ϵ, output_ϵ, T)
74-
# collect directional derivatives
75-
for r in 1:length(local_to_global_idx)
76-
idx = local_to_global_idx[r]
77-
# load output_ϵ into ex.seed_matrix[r,k,k+1,...,k+CHUNK-1]
78-
for s in 1:CHUNK
79-
ex.seed_matrix[r, k+s-1] = output_ϵ_raw[(idx-1)*CHUNK+s]
80-
end
81-
@inbounds input_ϵ[idx] = zero(T)
82-
end
57+
@assert size(ex.seed_matrix, 1) == length(ex.rinfo.local_indices)
58+
for offset in 1:CHUNK:(CHUNK*num_chunks)
59+
_eval_hessian_chunk(d, ex, offset, CHUNK, Val(CHUNK))
8360
end
8461
# leftover chunk
8562
remaining = num_products - CHUNK * num_chunks
8663
if remaining > 0
87-
k = CHUNK * num_chunks + 1
88-
for r in 1:length(local_to_global_idx)
89-
# set up directional derivatives
90-
@inbounds idx = local_to_global_idx[r]
91-
# load up ex.seed_matrix[r,k,k+1,...,k+remaining-1] into input_ϵ
92-
for s in 1:remaining
93-
# leave junk in the unused components
94-
input_ϵ_raw[(idx-1)*CHUNK+s] = ex.seed_matrix[r, k+s-1]
95-
end
96-
@inbounds output_ϵ[idx] = zero(T)
97-
end
98-
_hessian_slice_inner(d, ex, input_ϵ, output_ϵ, T)
99-
# collect directional derivatives
100-
for r in 1:length(local_to_global_idx)
101-
idx = local_to_global_idx[r]
102-
# load output_ϵ into ex.seed_matrix[r,k,k+1,...,k+remaining-1]
103-
for s in 1:remaining
104-
ex.seed_matrix[r, k+s-1] = output_ϵ_raw[(idx-1)*CHUNK+s]
105-
end
106-
@inbounds input_ϵ[idx] = zero(T)
107-
end
64+
offset = CHUNK * num_chunks + 1
65+
_eval_hessian_chunk(d, ex, offset, remaining, Val(CHUNK))
10866
end
10967
want, got = nzcount + length(ex.hess_I), length(H)
11068
if want > got
@@ -127,7 +85,40 @@ function _eval_hessian_inner(
12785
return length(ex.hess_I)
12886
end
12987

130-
function _hessian_slice_inner(d, ex, input_ϵ, output_ϵ, ::Type{T}) where {T}
88+
# Evaluate one group of `chunk` Hessian-vector products for `ex`.
#
# Columns `offset:offset+chunk-1` of `ex.seed_matrix` are copied into
# `d.input_ϵ`, `_hessian_slice_inner` is run, and the same columns of
# `ex.seed_matrix` are then overwritten with the values read back from
# `d.output_ϵ`. Each entry `idx` of `ex.rinfo.local_indices` owns the slots
# `(idx-1)*CHUNK+1:(idx-1)*CHUNK+CHUNK` of both buffers; only the first
# `chunk` of those slots are read or written here. The callers visible in
# this file pass `chunk == CHUNK` for full chunks and `chunk < CHUNK` for the
# leftover chunk. Returns `nothing`.
function _eval_hessian_chunk(
89+
d::NLPEvaluator,
90+
ex::_FunctionStorage,
91+
offset::Int,
92+
chunk::Int,
93+
::Val{CHUNK},
94+
) where {CHUNK}
95+
for r in eachindex(ex.rinfo.local_indices)
96+
# set up directional derivatives
97+
@inbounds idx = ex.rinfo.local_indices[r]
98+
# load ex.seed_matrix[r,offset,offset+1,...,offset+chunk-1] into d.input_ϵ
99+
for s in 1:chunk
100+
# If `chunk < CHUNK`, slots chunk+1:CHUNK are not written here and keep
# whatever `d.input_ϵ` already held (presumably zeros, given the `fill!`
# after each slice evaluation below; confirm the buffer starts zeroed).
101+
d.input_ϵ[(idx-1)*CHUNK+s] = ex.seed_matrix[r, offset+s-1]
102+
end
103+
end
104+
# `_hessian_slice_inner` zeroes `d.output_ϵ` before use; presumably it then
# fills it with the directional derivatives that are collected below.
_hessian_slice_inner(d, ex, Val(CHUNK))
105+
# Reset the whole input buffer, not only the slots written above.
fill!(d.input_ϵ, 0.0)
106+
# collect directional derivatives
107+
for r in eachindex(ex.rinfo.local_indices)
108+
@inbounds idx = ex.rinfo.local_indices[r]
109+
# store d.output_ϵ into ex.seed_matrix[r,offset,offset+1,...,offset+chunk-1]
110+
for s in 1:chunk
111+
ex.seed_matrix[r, offset+s-1] = d.output_ϵ[(idx-1)*CHUNK+s]
112+
end
113+
end
114+
return
115+
end
116+
117+
function _hessian_slice_inner(d, ex, ::Val{CHUNK}) where {CHUNK}
118+
T = ForwardDiff.Partials{CHUNK,Float64} # This is our element type.
119+
input_ϵ = _reinterpret_unsafe(T, d.input_ϵ)
120+
fill!(d.output_ϵ, 0.0)
121+
output_ϵ = _reinterpret_unsafe(T, d.output_ϵ)
131122
subexpr_forward_values_ϵ =
132123
_reinterpret_unsafe(T, d.subexpression_forward_values_ϵ)
133124
for i in ex.dependent_subexpressions

0 commit comments

Comments
 (0)