Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Commit da36df6

Browse files
Merge pull request #113 from avik-pal/ap/polyester_mode
Add Polyester ForwardDiff support
2 parents 45f8d73 + b8cc83d commit da36df6

File tree

9 files changed

+102
-33
lines changed

9 files changed

+102
-33
lines changed

Project.toml

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SimpleNonlinearSolve"
22
uuid = "727e6d20-b764-4bd8-a329-72de5adea6c7"
33
authors = ["SciML"]
4-
version = "1.0.4"
4+
version = "1.1.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -17,8 +17,14 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1717
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
1818
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
1919

20+
[extensions]
21+
SimpleNonlinearSolvePolyesterForwardDiffExt = "PolyesterForwardDiff"
22+
23+
[weakdeps]
24+
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
25+
2026
[compat]
21-
ADTypes = "0.2"
27+
ADTypes = "0.2.6"
2228
ArrayInterface = "7"
2329
ConcreteStructs = "0.2"
2430
DiffEqBase = "6.126"
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
module SimpleNonlinearSolvePolyesterForwardDiffExt
2+
3+
using SimpleNonlinearSolve, PolyesterForwardDiff
4+
5+
@inline SimpleNonlinearSolve.__is_extension_loaded(::Val{:PolyesterForwardDiff}) = true
6+
7+
@inline function SimpleNonlinearSolve.__polyester_forwarddiff_jacobian!(f!::F, y, J, x,
8+
chunksize) where {F}
9+
PolyesterForwardDiff.threaded_jacobian!(f!, y, J, x, chunksize)
10+
return J
11+
end
12+
13+
@inline function SimpleNonlinearSolve.__polyester_forwarddiff_jacobian!(f::F, J, x,
14+
chunksize) where {F}
15+
PolyesterForwardDiff.threaded_jacobian!(f, J, x, chunksize)
16+
return J
17+
end
18+
19+
end

src/SimpleNonlinearSolve.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import PrecompileTools: @compile_workload, @setup_workload, @recompile_invalidat
99
import DiffEqBase: AbstractNonlinearTerminationMode,
1010
AbstractSafeNonlinearTerminationMode, AbstractSafeBestNonlinearTerminationMode,
1111
NonlinearSafeTerminationReturnCode, get_termination_mode,
12-
NONLINEARSOLVE_DEFAULT_NORM
12+
NONLINEARSOLVE_DEFAULT_NORM, _get_tolerance
1313
using FiniteDiff, ForwardDiff
1414
import ForwardDiff: Dual
1515
import MaybeInplace: @bb, setindex_trait, CanSetindex, CannotSetindex
@@ -23,6 +23,8 @@ abstract type AbstractSimpleNonlinearSolveAlgorithm <: AbstractNonlinearAlgorith
2323
abstract type AbstractBracketingAlgorithm <: AbstractSimpleNonlinearSolveAlgorithm end
2424
abstract type AbstractNewtonAlgorithm <: AbstractSimpleNonlinearSolveAlgorithm end
2525

26+
@inline __is_extension_loaded(::Val) = false
27+
2628
include("utils.jl")
2729

2830
## Nonlinear Solvers

src/nlsolve/halley.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,15 @@ A low-overhead implementation of Halley's Method.
1212
1313
### Keyword Arguments
1414
15-
- `autodiff`: determines the backend used for the Hessian. Defaults to
16-
`AutoForwardDiff()`. Valid choices are `AutoForwardDiff()` or `AutoFiniteDiff()`.
15+
- `autodiff`: determines the backend used for the Hessian. Defaults to `nothing`. Valid
16+
choices are `AutoForwardDiff()` or `AutoFiniteDiff()`.
1717
1818
!!! warning
1919
2020
Inplace Problems are currently not supported by this method.
2121
"""
2222
@kwdef @concrete struct SimpleHalley <: AbstractNewtonAlgorithm
23-
autodiff = AutoForwardDiff()
23+
autodiff = nothing
2424
end
2525

2626
function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleHalley, args...;
@@ -33,6 +33,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleHalley, args...;
3333
fx = _get_fx(prob, x)
3434
T = eltype(x)
3535

36+
autodiff = __get_concrete_autodiff(prob, alg.autodiff; polyester = Val(false))
3637
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fx, x,
3738
termination_condition)
3839

@@ -50,17 +51,20 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleHalley, args...;
5051

5152
for i in 1:maxiters
5253
# Hessian Computation is unfortunately type unstable
53-
fx, dfx, d2fx = compute_jacobian_and_hessian(alg.autodiff, prob, fx, x)
54+
fx, dfx, d2fx = compute_jacobian_and_hessian(autodiff, prob, fx, x)
5455
setindex_trait(x) === CannotSetindex() && (A = dfx)
5556

56-
aᵢ = dfx \ _vec(fx)
57+
# Factorize Once and Reuse
58+
dfx_fact = factorize(dfx)
59+
60+
aᵢ = dfx_fact \ _vec(fx)
5761
A_ = _vec(A)
5862
@bb A_ = d2fx × aᵢ
5963
A = _restructure(A, A_)
6064

6165
@bb Aaᵢ = A × aᵢ
6266
@bb A .*= -1
63-
bᵢ = dfx \ Aaᵢ
67+
bᵢ = dfx_fact \ Aaᵢ
6468

6569
cᵢ_ = _vec(cᵢ)
6670
@bb @. cᵢ_ = (aᵢ * aᵢ) / (-aᵢ + (T(0.5) * bᵢ))

src/nlsolve/raphson.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""
22
SimpleNewtonRaphson(autodiff)
3-
SimpleNewtonRaphson(; autodiff = AutoForwardDiff())
3+
SimpleNewtonRaphson(; autodiff = nothing)
44
55
A low-overhead implementation of Newton-Raphson. This method is non-allocating on scalar
66
and static array problems.
@@ -14,10 +14,11 @@ and static array problems.
1414
### Keyword Arguments
1515
1616
- `autodiff`: determines the backend used for the Jacobian. Defaults to
17-
`AutoForwardDiff()`. Valid choices are `AutoForwardDiff()` or `AutoFiniteDiff()`.
17+
`nothing`. Valid choices are `AutoPolyesterForwardDiff()`, `AutoForwardDiff()` or
18+
`AutoFiniteDiff()`.
1819
"""
1920
@kwdef @concrete struct SimpleNewtonRaphson <: AbstractNewtonAlgorithm
20-
autodiff = AutoForwardDiff()
21+
autodiff = nothing
2122
end
2223

2324
const SimpleGaussNewton = SimpleNewtonRaphson
@@ -27,14 +28,15 @@ function SciMLBase.__solve(prob::Union{NonlinearProblem, NonlinearLeastSquaresPr
2728
maxiters = 1000, termination_condition = nothing, alias_u0 = false, kwargs...)
2829
x = __maybe_unaliased(prob.u0, alias_u0)
2930
fx = _get_fx(prob, x)
31+
autodiff = __get_concrete_autodiff(prob, alg.autodiff)
3032
@bb xo = copy(x)
31-
J, jac_cache = jacobian_cache(alg.autodiff, prob.f, fx, x, prob.p)
33+
J, jac_cache = jacobian_cache(autodiff, prob.f, fx, x, prob.p)
3234

3335
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fx, x,
3436
termination_condition)
3537

3638
for i in 1:maxiters
37-
fx, dfx = value_and_jacobian(alg.autodiff, prob.f, fx, x, prob.p, jac_cache; J)
39+
fx, dfx = value_and_jacobian(autodiff, prob.f, fx, x, prob.p, jac_cache; J)
3840

3941
if i == 1
4042
iszero(fx) && build_solution(prob, alg, x, fx; retcode = ReturnCode.Success)

src/nlsolve/trustRegion.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@ scalar and static array problems.
1010
### Keyword Arguments
1111
1212
- `autodiff`: determines the backend used for the Jacobian. Defaults to
13-
`AutoForwardDiff()`. Valid choices are `AutoForwardDiff()` or `AutoFiniteDiff()`.
13+
`nothing`. Valid choices are `AutoPolyesterForwardDiff()`, `AutoForwardDiff()` or
14+
`AutoFiniteDiff()`.
1415
- `max_trust_radius`: the maximum radius of the trust region. Defaults to
1516
`max(norm(f(u0)), maximum(u0) - minimum(u0))`.
1617
- `initial_trust_radius`: the initial trust region radius. Defaults to
@@ -37,7 +38,7 @@ scalar and static array problems.
3738
row, `max_shrink_times` is exceeded, the algorithm returns. Defaults to `32`.
3839
"""
3940
@kwdef @concrete struct SimpleTrustRegion <: AbstractNewtonAlgorithm
40-
autodiff = AutoForwardDiff()
41+
autodiff = nothing
4142
max_trust_radius = 0.0
4243
initial_trust_radius = 0.0
4344
step_threshold = 0.0001
@@ -61,11 +62,12 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleTrustRegion, args.
6162
t₁ = T(alg.shrink_factor)
6263
t₂ = T(alg.expand_factor)
6364
max_shrink_times = alg.max_shrink_times
65+
autodiff = __get_concrete_autodiff(prob, alg.autodiff)
6466

6567
fx = _get_fx(prob, x)
6668
@bb xo = copy(x)
67-
J, jac_cache = jacobian_cache(alg.autodiff, prob.f, fx, x, prob.p)
68-
fx, ∇f = value_and_jacobian(alg.autodiff, prob.f, fx, x, prob.p, jac_cache; J)
69+
J, jac_cache = jacobian_cache(autodiff, prob.f, fx, x, prob.p)
70+
fx, ∇f = value_and_jacobian(autodiff, prob.f, fx, x, prob.p, jac_cache; J)
6971

7072
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fx, x,
7173
termination_condition)
@@ -116,7 +118,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleTrustRegion, args.
116118
# Take the step.
117119
@bb @. xo = x
118120

119-
fx, ∇f = value_and_jacobian(alg.autodiff, prob.f, fx, x, prob.p, jac_cache; J)
121+
fx, ∇f = value_and_jacobian(autodiff, prob.f, fx, x, prob.p, jac_cache; J)
120122

121123
# Update the trust region radius.
122124
(r > η₃) && (norm(δ) Δ) &&= min(t₂ * Δ, Δₘₐₓ))

src/utils.jl

Lines changed: 41 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,6 @@ Return the maximum of `a` and `b` if `x1 > x0`, otherwise return the minimum.
2626
"""
2727
__max_tdir(a, b, x0, x1) = ifelse(x1 > x0, max(a, b), min(a, b))
2828

29-
__cvt_real(::Type{T}, ::Nothing) where {T} = nothing
30-
__cvt_real(::Type{T}, x) where {T} = real(T(x))
31-
32-
_get_tolerance(η, ::Type{T}) where {T} = __cvt_real(T, η)
33-
function _get_tolerance(::Nothing, ::Type{T}) where {T}
34-
η = real(oneunit(T)) * (eps(real(one(T))))^(4 // 5)
35-
return _get_tolerance(η, T)
36-
end
37-
3829
__standard_tag(::Nothing, x) = ForwardDiff.Tag(SimpleNonlinearSolveTag(), eltype(x))
3930
__standard_tag(tag::ForwardDiff.Tag, _) = tag
4031
__standard_tag(tag, x) = ForwardDiff.Tag(tag, eltype(x))
@@ -60,6 +51,12 @@ function __get_jacobian_config(ad::AutoForwardDiff{CS}, f!, y, x) where {CS}
6051
return ForwardDiff.JacobianConfig(f!, y, x, ck, tag)
6152
end
6253

54+
function __get_jacobian_config(ad::AutoPolyesterForwardDiff{CS}, args...) where {CS}
55+
x = last(args)
56+
return (CS === nothing || CS 0) ? __pick_forwarddiff_chunk(x) :
57+
ForwardDiff.Chunk{CS}()
58+
end
59+
6360
"""
6461
value_and_jacobian(ad, f, y, x, p, cache; J = nothing)
6562
@@ -81,6 +78,9 @@ function value_and_jacobian(ad, f::F, y, x::X, p, cache; J = nothing) where {F,
8178
FiniteDiff.finite_difference_jacobian!(J, _f, x, cache)
8279
_f(y, x)
8380
return y, J
81+
elseif ad isa AutoPolyesterForwardDiff
82+
__polyester_forwarddiff_jacobian!(_f, y, J, x, cache)
83+
return y, J
8484
else
8585
throw(ArgumentError("Unsupported AD method: $(ad)"))
8686
end
@@ -100,19 +100,30 @@ function value_and_jacobian(ad, f::F, y, x::X, p, cache; J = nothing) where {F,
100100
elseif ad isa AutoFiniteDiff
101101
J_fd = FiniteDiff.finite_difference_jacobian(_f, x, cache)
102102
return _f(x), J_fd
103+
elseif ad isa AutoPolyesterForwardDiff
104+
__polyester_forwarddiff_jacobian!(_f, J, x, cache)
105+
return _f(x), J
103106
else
104107
throw(ArgumentError("Unsupported AD method: $(ad)"))
105108
end
106109
end
107110
end
108111

112+
# Declare functions
113+
function __polyester_forwarddiff_jacobian! end
114+
109115
function value_and_jacobian(ad, f::F, y, x::Number, p, cache; J = nothing) where {F}
110116
if DiffEqBase.has_jac(f)
111117
return f(x, p), f.jac(x, p)
112118
elseif ad isa AutoForwardDiff
113119
T = typeof(__standard_tag(ad.tag, x))
114120
out = f(ForwardDiff.Dual{T}(x, one(x)), p)
115121
return ForwardDiff.value(out), ForwardDiff.extract_derivative(T, out)
122+
elseif ad isa AutoPolyesterForwardDiff
123+
# Just use ForwardDiff
124+
T = typeof(__standard_tag(nothing, x))
125+
out = f(ForwardDiff.Dual{T}(x, one(x)), p)
126+
return ForwardDiff.value(out), ForwardDiff.extract_derivative(T, out)
116127
elseif ad isa AutoFiniteDiff
117128
_f = Base.Fix2(f, p)
118129
return _f(x), FiniteDiff.finite_difference_derivative(_f, x, ad.fdtype)
@@ -132,7 +143,7 @@ function jacobian_cache(ad, f::F, y, x::X, p) where {F, X <: AbstractArray}
132143
J = similar(y, length(y), length(x))
133144
if DiffEqBase.has_jac(f)
134145
return J, nothing
135-
elseif ad isa AutoForwardDiff
146+
elseif ad isa AutoForwardDiff || ad isa AutoPolyesterForwardDiff
136147
return J, __get_jacobian_config(ad, _f, y, x)
137148
elseif ad isa AutoFiniteDiff
138149
return J, FiniteDiff.JacobianCache(copy(x), copy(y), copy(y), ad.fdtype)
@@ -146,6 +157,10 @@ function jacobian_cache(ad, f::F, y, x::X, p) where {F, X <: AbstractArray}
146157
elseif ad isa AutoForwardDiff
147158
J = ArrayInterface.can_setindex(x) ? similar(y, length(y), length(x)) : nothing
148159
return J, __get_jacobian_config(ad, _f, x)
160+
elseif ad isa AutoPolyesterForwardDiff
161+
@assert ArrayInterface.can_setindex(x) "PolyesterForwardDiff requires mutable inputs. Use AutoForwardDiff instead."
162+
J = similar(y, length(y), length(x))
163+
return J, __get_jacobian_config(ad, _f, x)
149164
elseif ad isa AutoFiniteDiff
150165
return nothing, FiniteDiff.JacobianCache(copy(x), copy(y), copy(y), ad.fdtype)
151166
else
@@ -350,3 +365,19 @@ end
350365
(alias || !ArrayInterface.can_setindex(typeof(x))) && return x
351366
return deepcopy(x)
352367
end
368+
369+
# Decide which AD backend to use
370+
@inline __get_concrete_autodiff(prob, ad::ADTypes.AbstractADType; kwargs...) = ad
371+
@inline function __get_concrete_autodiff(prob, ::Nothing; polyester::Val{P} = Val(true),
372+
kwargs...) where {P}
373+
if ForwardDiff.can_dual(eltype(prob.u0))
374+
if P && __is_extension_loaded(Val(:PolyesterForwardDiff)) &&
375+
!(prob.u0 isa Number) && ArrayInterface.can_setindex(prob.u0)
376+
return AutoPolyesterForwardDiff()
377+
else
378+
return AutoForwardDiff()
379+
end
380+
else
381+
return AutoFiniteDiff()
382+
end
383+
end

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
77
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
88
NonlinearProblemLibrary = "b7050fa9-e91f-4b37-bcee-a89a063da141"
99
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
10+
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
1011
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1112
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
1213
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

test/basictests.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using AllocCheck, BenchmarkTools, LinearSolve, SimpleNonlinearSolve, StaticArrays, Random,
22
LinearAlgebra, Test, ForwardDiff, DiffEqBase
3+
import PolyesterForwardDiff
34

45
_nameof(x) = applicable(nameof, x) ? nameof(x) : _nameof(typeof(x))
56

@@ -29,20 +30,21 @@ const TERMINATION_CONDITIONS = [
2930
@testset "$(alg)" for alg in (SimpleNewtonRaphson, SimpleTrustRegion)
3031
# Eval else the alg is type unstable
3132
@eval begin
32-
function benchmark_nlsolve_oop(f, u0, p = 2.0; autodiff = AutoForwardDiff())
33+
function benchmark_nlsolve_oop(f, u0, p = 2.0; autodiff = nothing)
3334
prob = NonlinearProblem{false}(f, u0, p)
3435
return solve(prob, $(alg)(; autodiff), abstol = 1e-9)
3536
end
3637

37-
function benchmark_nlsolve_iip(f, u0, p = 2.0; autodiff = AutoForwardDiff())
38+
function benchmark_nlsolve_iip(f, u0, p = 2.0; autodiff = nothing)
3839
prob = NonlinearProblem{true}(f, u0, p)
3940
return solve(prob, $(alg)(; autodiff), abstol = 1e-9)
4041
end
4142
end
4243

4344
@testset "AutoDiff: $(_nameof(autodiff))" for autodiff in (AutoFiniteDiff(),
44-
AutoForwardDiff())
45+
AutoForwardDiff(), AutoPolyesterForwardDiff())
4546
@testset "[OOP] u0: $(typeof(u0))" for u0 in ([1.0, 1.0], @SVector[1.0, 1.0], 1.0)
47+
u0 isa SVector && autodiff isa AutoPolyesterForwardDiff && continue
4648
sol = benchmark_nlsolve_oop(quadratic_f, u0; autodiff)
4749
@test SciMLBase.successful_retcode(sol)
4850
@test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)
@@ -103,7 +105,7 @@ end
103105
# --- SimpleHalley tests ---
104106

105107
@testset "SimpleHalley" begin
106-
function benchmark_nlsolve_oop(f, u0, p = 2.0; autodiff = AutoForwardDiff())
108+
function benchmark_nlsolve_oop(f, u0, p = 2.0; autodiff = nothing)
107109
prob = NonlinearProblem{false}(f, u0, p)
108110
return solve(prob, SimpleHalley(; autodiff), abstol = 1e-9)
109111
end

0 commit comments

Comments
 (0)