@@ -50,61 +50,19 @@ function _eval_hessian_inner(
50
50
@assert length (ex. hess_I) == 0
51
51
return 0
52
52
end
53
- T = ForwardDiff. Partials{CHUNK,Float64} # This is our element type.
54
53
Coloring. prepare_seed_matrix! (ex. seed_matrix, ex. rinfo)
55
- local_to_global_idx = ex. rinfo. local_indices
56
- input_ϵ_raw, output_ϵ_raw = d. input_ϵ, d. output_ϵ
57
- input_ϵ = _reinterpret_unsafe (T, input_ϵ_raw)
58
- output_ϵ = _reinterpret_unsafe (T, output_ϵ_raw)
59
54
# Compute hessian-vector products
60
55
num_products = size (ex. seed_matrix, 2 ) # number of hessian-vector products
61
56
num_chunks = div (num_products, CHUNK)
62
- @assert size (ex. seed_matrix, 1 ) == length (local_to_global_idx)
63
- for k in 1 : CHUNK: (CHUNK* num_chunks)
64
- for r in 1 : length (local_to_global_idx)
65
- # set up directional derivatives
66
- @inbounds idx = local_to_global_idx[r]
67
- # load up ex.seed_matrix[r,k,k+1,...,k+CHUNK-1] into input_ϵ
68
- for s in 1 : CHUNK
69
- input_ϵ_raw[(idx- 1 )* CHUNK+ s] = ex. seed_matrix[r, k+ s- 1 ]
70
- end
71
- @inbounds output_ϵ[idx] = zero (T)
72
- end
73
- _hessian_slice_inner (d, ex, input_ϵ, output_ϵ, T)
74
- # collect directional derivatives
75
- for r in 1 : length (local_to_global_idx)
76
- idx = local_to_global_idx[r]
77
- # load output_ϵ into ex.seed_matrix[r,k,k+1,...,k+CHUNK-1]
78
- for s in 1 : CHUNK
79
- ex. seed_matrix[r, k+ s- 1 ] = output_ϵ_raw[(idx- 1 )* CHUNK+ s]
80
- end
81
- @inbounds input_ϵ[idx] = zero (T)
82
- end
57
+ @assert size (ex. seed_matrix, 1 ) == length (ex. rinfo. local_indices)
58
+ for offset in 1 : CHUNK: (CHUNK* num_chunks)
59
+ _eval_hessian_chunk (d, ex, offset, CHUNK, Val (CHUNK))
83
60
end
84
61
# leftover chunk
85
62
remaining = num_products - CHUNK * num_chunks
86
63
if remaining > 0
87
- k = CHUNK * num_chunks + 1
88
- for r in 1 : length (local_to_global_idx)
89
- # set up directional derivatives
90
- @inbounds idx = local_to_global_idx[r]
91
- # load up ex.seed_matrix[r,k,k+1,...,k+remaining-1] into input_ϵ
92
- for s in 1 : remaining
93
- # leave junk in the unused components
94
- input_ϵ_raw[(idx- 1 )* CHUNK+ s] = ex. seed_matrix[r, k+ s- 1 ]
95
- end
96
- @inbounds output_ϵ[idx] = zero (T)
97
- end
98
- _hessian_slice_inner (d, ex, input_ϵ, output_ϵ, T)
99
- # collect directional derivatives
100
- for r in 1 : length (local_to_global_idx)
101
- idx = local_to_global_idx[r]
102
- # load output_ϵ into ex.seed_matrix[r,k,k+1,...,k+remaining-1]
103
- for s in 1 : remaining
104
- ex. seed_matrix[r, k+ s- 1 ] = output_ϵ_raw[(idx- 1 )* CHUNK+ s]
105
- end
106
- @inbounds input_ϵ[idx] = zero (T)
107
- end
64
+ offset = CHUNK * num_chunks + 1
65
+ _eval_hessian_chunk (d, ex, offset, remaining, Val (CHUNK))
108
66
end
109
67
want, got = nzcount + length (ex. hess_I), length (H)
110
68
if want > got
@@ -127,7 +85,40 @@ function _eval_hessian_inner(
127
85
return length (ex. hess_I)
128
86
end
129
87
130
- function _hessian_slice_inner (d, ex, input_ϵ, output_ϵ, :: Type{T} ) where {T}
88
+ function _eval_hessian_chunk (
89
+ d:: NLPEvaluator ,
90
+ ex:: _FunctionStorage ,
91
+ offset:: Int , # first column of ex.seed_matrix handled by this call
92
+ chunk:: Int , # number of consecutive columns to process, starting at `offset`
93
+ :: Val{CHUNK} , # CHUNK is the stride: each index owns CHUNK consecutive slots in d.input_ϵ / d.output_ϵ
94
+ ) where {CHUNK}
95
+ for r in eachindex (ex. rinfo. local_indices)
96
+ # Set up directional derivatives: row r of the seed matrix belongs to index idx.
97
+ @inbounds idx = ex. rinfo. local_indices[r]
98
+ # Load ex.seed_matrix[r, offset:offset+chunk-1] into slots (idx-1)*CHUNK+1 ... (idx-1)*CHUNK+chunk of d.input_ϵ.
99
+ for s in 1 : chunk
100
+ # If `chunk < CHUNK`, slots chunk+1:CHUNK of this index are not written here and keep whatever d.input_ϵ already held.
101
+ d. input_ϵ[(idx- 1 )* CHUNK+ s] = ex. seed_matrix[r, offset+ s- 1 ]
102
+ end
103
+ end
104
+ _hessian_slice_inner (d, ex, Val (CHUNK))
105
+ fill! (d. input_ϵ, 0.0 ) # zero the entire input buffer once the slice has been evaluated
106
+ # Collect directional derivatives: overwrite the same seed-matrix entries with values read from d.output_ϵ (presumably populated by `_hessian_slice_inner`; confirm against its full body).
107
+ for r in eachindex (ex. rinfo. local_indices)
108
+ @inbounds idx = ex. rinfo. local_indices[r]
109
+ # Store slots (idx-1)*CHUNK+1 ... (idx-1)*CHUNK+chunk of d.output_ϵ into ex.seed_matrix[r, offset:offset+chunk-1].
110
+ for s in 1 : chunk
111
+ ex. seed_matrix[r, offset+ s- 1 ] = d. output_ϵ[(idx- 1 )* CHUNK+ s]
112
+ end
113
+ end
114
+ return
115
+ end
116
+
117
+ function _hessian_slice_inner (d, ex, :: Val{CHUNK} ) where {CHUNK}
118
+ T = ForwardDiff. Partials{CHUNK,Float64} # This is our element type.
119
+ input_ϵ = _reinterpret_unsafe (T, d. input_ϵ)
120
+ fill! (d. output_ϵ, 0.0 )
121
+ output_ϵ = _reinterpret_unsafe (T, d. output_ϵ)
131
122
subexpr_forward_values_ϵ =
132
123
_reinterpret_unsafe (T, d. subexpression_forward_values_ϵ)
133
124
for i in ex. dependent_subexpressions
0 commit comments