Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Commit 0b25978

Browse files
Merge pull request #116 from SciML/ap/ls
Add Line Search to (L)Broyden
2 parents ade88a2 + 469afbf commit 0b25978

18 files changed

+301
-104
lines changed

Project.toml

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
name = "SimpleNonlinearSolve"
22
uuid = "727e6d20-b764-4bd8-a329-72de5adea6c7"
33
authors = ["SciML"]
4-
version = "1.2.1"
4+
version = "1.3.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
88
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
99
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
1010
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
11+
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
1112
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
1213
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1314
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -17,17 +18,20 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1718
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
1819
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
1920

20-
[extensions]
21-
SimpleNonlinearSolvePolyesterForwardDiffExt = "PolyesterForwardDiff"
22-
2321
[weakdeps]
2422
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
23+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
24+
25+
[extensions]
26+
SimpleNonlinearSolvePolyesterForwardDiffExt = "PolyesterForwardDiff"
27+
SimpleNonlinearSolveStaticArraysExt = "StaticArrays"
2528

2629
[compat]
2730
ADTypes = "0.2.6"
2831
ArrayInterface = "7"
2932
ConcreteStructs = "0.2"
3033
DiffEqBase = "6.126"
34+
FastClosures = "0.3"
3135
FiniteDiff = "2"
3236
ForwardDiff = "0.10.3"
3337
LinearAlgebra = "1.9"
@@ -36,4 +40,5 @@ PrecompileTools = "1"
3640
Reexport = "1"
3741
SciMLBase = "2.7"
3842
StaticArraysCore = "1.4"
43+
StaticArrays = "1"
3944
julia = "1.9"
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
module SimpleNonlinearSolveStaticArraysExt
2+
3+
using SimpleNonlinearSolve
4+
5+
@inline SimpleNonlinearSolve.__is_extension_loaded(::Val{:StaticArrays}) = true
6+
7+
end

src/SimpleNonlinearSolve.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,16 @@ module SimpleNonlinearSolve
33
import PrecompileTools: @compile_workload, @setup_workload, @recompile_invalidations
44

55
@recompile_invalidations begin
6-
using ADTypes,
7-
ArrayInterface, ConcreteStructs, DiffEqBase, Reexport, LinearAlgebra, SciMLBase
6+
using ADTypes, ArrayInterface, ConcreteStructs, DiffEqBase, FastClosures, FiniteDiff,
7+
ForwardDiff, Reexport, LinearAlgebra, SciMLBase
88

99
import DiffEqBase: AbstractNonlinearTerminationMode,
1010
AbstractSafeNonlinearTerminationMode, AbstractSafeBestNonlinearTerminationMode,
1111
NonlinearSafeTerminationReturnCode, get_termination_mode,
12-
NONLINEARSOLVE_DEFAULT_NORM, _get_tolerance
13-
using FiniteDiff, ForwardDiff
12+
NONLINEARSOLVE_DEFAULT_NORM
1413
import ForwardDiff: Dual
1514
import MaybeInplace: @bb, setindex_trait, CanSetindex, CannotSetindex
16-
import SciMLBase: AbstractNonlinearAlgorithm, build_solution, isinplace
15+
import SciMLBase: AbstractNonlinearAlgorithm, build_solution, isinplace, _unwrap_val
1716
import StaticArraysCore: StaticArray, SVector, SMatrix, SArray, MArray, MMatrix, Size
1817
end
1918

@@ -26,6 +25,7 @@ abstract type AbstractNewtonAlgorithm <: AbstractSimpleNonlinearSolveAlgorithm e
2625
@inline __is_extension_loaded(::Val) = false
2726

2827
include("utils.jl")
28+
include("linesearch.jl")
2929

3030
## Nonlinear Solvers
3131
include("nlsolve/raphson.jl")

src/ad.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, <:AbstractArray}
77
sol.original)
88
end
99

10-
# Handle Ambiguities
1110
for algType in (Bisection, Brent, Alefeld, Falsi, ITP, Ridder)
1211
@eval begin
1312
function SciMLBase.solve(prob::IntervalNonlinearProblem{uType, iip,

src/bracketing/bisection.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Bisection, args...
2626
left, right = prob.tspan
2727
fl, fr = f(left), f(right)
2828

29-
abstol = _get_tolerance(abstol,
29+
abstol = __get_tolerance(nothing, abstol,
3030
promote_type(eltype(first(prob.tspan)), eltype(last(prob.tspan))))
3131

3232
if iszero(fl)

src/bracketing/brent.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Brent, args...;
1313
fl, fr = f(left), f(right)
1414
ϵ = eps(convert(typeof(fl), 1))
1515

16-
abstol = _get_tolerance(abstol,
16+
abstol = __get_tolerance(nothing, abstol,
1717
promote_type(eltype(first(prob.tspan)), eltype(last(prob.tspan))))
1818

1919
if iszero(fl)

src/bracketing/falsi.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Falsi, args...;
1212
left, right = prob.tspan
1313
fl, fr = f(left), f(right)
1414

15-
abstol = _get_tolerance(abstol,
15+
abstol = __get_tolerance(nothing, abstol,
1616
promote_type(eltype(first(prob.tspan)), eltype(last(prob.tspan))))
1717

1818
if iszero(fl)

src/bracketing/itp.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::ITP, args...;
5858
left, right = prob.tspan
5959
fl, fr = f(left), f(right)
6060

61-
abstol = _get_tolerance(abstol,
61+
abstol = __get_tolerance(nothing, abstol,
6262
promote_type(eltype(first(prob.tspan)), eltype(last(prob.tspan))))
6363

6464
if iszero(fl)

src/bracketing/ridder.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Ridder, args...;
1212
left, right = prob.tspan
1313
fl, fr = f(left), f(right)
1414

15-
abstol = _get_tolerance(abstol,
15+
abstol = __get_tolerance(nothing, abstol,
1616
promote_type(eltype(first(prob.tspan)), eltype(last(prob.tspan))))
1717

1818
if iszero(fl)

src/linesearch.jl

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
# This is a copy of the version in NonlinearSolve.jl. Temporarily kept here till we move
2+
# line searches into a dedicated package.
3+
@kwdef @concrete struct LiFukushimaLineSearch
4+
lambda_0 = 1
5+
beta = 0.5
6+
sigma_1 = 0.001
7+
sigma_2 = 0.001
8+
eta = 0.1
9+
rho = 0.1
10+
nan_maxiters = missing
11+
maxiters::Int = 100
12+
end
13+
14+
@concrete mutable struct LiFukushimaLineSearchCache{T <: Union{Nothing, Int}}
15+
ϕ
16+
λ₀
17+
β
18+
σ₁
19+
σ₂
20+
η
21+
ρ
22+
α
23+
nan_maxiters::T
24+
maxiters::Int
25+
end
26+
27+
@concrete struct StaticLiFukushimaLineSearchCache
28+
f
29+
p
30+
λ₀
31+
β
32+
σ₁
33+
σ₂
34+
η
35+
ρ
36+
maxiters::Int
37+
end
38+
39+
(alg::LiFukushimaLineSearch)(prob, fu, u) = __generic_init(alg, prob, fu, u)
40+
function (alg::LiFukushimaLineSearch)(prob, fu::Union{Number, SArray},
41+
u::Union{Number, SArray})
42+
(alg.nan_maxiters === missing || alg.nan_maxiters === nothing) &&
43+
return __static_init(alg, prob, fu, u)
44+
@warn "`LiFukushimaLineSearch` with NaN checking is not non-allocating" maxlog=1
45+
return __generic_init(alg, prob, fu, u)
46+
end
47+
48+
function __generic_init(alg::LiFukushimaLineSearch, prob, fu, u)
49+
@bb u_cache = similar(u)
50+
@bb fu_cache = similar(fu)
51+
T = promote_type(eltype(fu), eltype(u))
52+
53+
ϕ = @closure (u, δu, α) -> begin
54+
@bb @. u_cache = u + α * δu
55+
return NONLINEARSOLVE_DEFAULT_NORM(__eval_f(prob, fu_cache, u_cache))
56+
end
57+
58+
nan_maxiters = ifelse(alg.nan_maxiters === missing, 5, alg.nan_maxiters)
59+
60+
return LiFukushimaLineSearchCache(ϕ, T(alg.lambda_0), T(alg.beta), T(alg.sigma_1),
61+
T(alg.sigma_2), T(alg.eta), T(alg.rho), T(true), nan_maxiters, alg.maxiters)
62+
end
63+
64+
function __static_init(alg::LiFukushimaLineSearch, prob, fu, u)
65+
T = promote_type(eltype(fu), eltype(u))
66+
return StaticLiFukushimaLineSearchCache(prob.f, prob.p, T(alg.lambda_0), T(alg.beta),
67+
T(alg.sigma_1), T(alg.sigma_2), T(alg.eta), T(alg.rho), alg.maxiters)
68+
end
69+
70+
function (cache::LiFukushimaLineSearchCache)(u, δu)
71+
T = promote_type(eltype(u), eltype(δu))
72+
ϕ = @closure α -> cache.ϕ(u, δu, α)
73+
fx_norm = ϕ(T(0))
74+
75+
# Non-Blocking exit if the norm is NaN or Inf
76+
DiffEqBase.NAN_CHECK(fx_norm) && return cache.α
77+
78+
# Early Terminate based on Eq. 2.7
79+
du_norm = NONLINEARSOLVE_DEFAULT_NORM(δu)
80+
fxλ_norm = ϕ(cache.α)
81+
fxλ_norm cache.ρ * fx_norm - cache.σ₂ * du_norm^2 && return cache.α
82+
83+
λ₂, λ₁ = cache.λ₀, cache.λ₀
84+
fxλp_norm = ϕ(λ₂)
85+
86+
if cache.nan_maxiters !== nothing
87+
if DiffEqBase.NAN_CHECK(fxλp_norm)
88+
nan_converged = false
89+
for _ in 1:(cache.nan_maxiters)
90+
λ₁, λ₂ = λ₂, cache.β * λ₂
91+
fxλp_norm = ϕ(λ₂)
92+
nan_converged = DiffEqBase.NAN_CHECK(fxλp_norm)::Bool
93+
nan_converged && break
94+
end
95+
nan_converged || return cache.α
96+
end
97+
end
98+
99+
for i in 1:(cache.maxiters)
100+
fxλp_norm = ϕ(λ₂)
101+
converged = fxλp_norm (1 + cache.η) * fx_norm - cache.σ₁ * λ₂^2 * du_norm^2
102+
converged && return λ₂
103+
λ₁, λ₂ = λ₂, cache.β * λ₂
104+
end
105+
106+
return cache.α
107+
end
108+
109+
function (cache::StaticLiFukushimaLineSearchCache)(u, δu)
110+
T = promote_type(eltype(u), eltype(δu))
111+
112+
# Early Terminate based on Eq. 2.7
113+
fx_norm = NONLINEARSOLVE_DEFAULT_NORM(cache.f(u, cache.p))
114+
du_norm = NONLINEARSOLVE_DEFAULT_NORM(δu)
115+
fxλ_norm = NONLINEARSOLVE_DEFAULT_NORM(cache.f(u .+ δu, cache.p))
116+
fxλ_norm cache.ρ * fx_norm - cache.σ₂ * du_norm^2 && return T(true)
117+
118+
λ₂, λ₁ = cache.λ₀, cache.λ₀
119+
120+
for i in 1:(cache.maxiters)
121+
fxλp_norm = NONLINEARSOLVE_DEFAULT_NORM(cache.f(u .+ λ₂ .* δu, cache.p))
122+
converged = fxλp_norm (1 + cache.η) * fx_norm - cache.σ₁ * λ₂^2 * du_norm^2
123+
converged && return λ₂
124+
λ₁, λ₂ = λ₂, cache.β * λ₂
125+
end
126+
127+
return T(true)
128+
end

src/nlsolve/broyden.jl

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,54 @@
11
"""
2-
SimpleBroyden()
2+
SimpleBroyden(; linesearch = Val(false), alpha = nothing)
33
44
A low-overhead implementation of Broyden. This method is non-allocating on scalar
55
and static array problems.
6+
7+
### Keyword Arguments
8+
9+
- `linesearch`: If `linesearch` is `Val(true)`, then we use the `LiFukushimaLineSearch`
10+
[1] line search else no line search is used. For advanced customization of the line
11+
search, use the [`Broyden`](@ref) algorithm in `NonlinearSolve.jl`.
12+
- `alpha`: Scale the initial jacobian initialization with `alpha`. If it is `nothing`, we
13+
will compute the scaling using `2 * norm(fu) / max(norm(u), true)`.
14+
15+
### References
16+
17+
[1] Li, Dong-Hui, and Masao Fukushima. "A derivative-free line search and global convergence
18+
of Broyden-like method for nonlinear equations." Optimization methods and software 13.3
19+
(2000): 181-201.
620
"""
7-
struct SimpleBroyden <: AbstractSimpleNonlinearSolveAlgorithm end
21+
@concrete struct SimpleBroyden{linesearch} <: AbstractSimpleNonlinearSolveAlgorithm
22+
alpha
23+
end
24+
25+
function SimpleBroyden(; linesearch = Val(false), alpha = nothing)
26+
return SimpleBroyden{_unwrap_val(linesearch)}(alpha)
27+
end
28+
29+
__get_linesearch(::SimpleBroyden{LS}) where {LS} = Val(LS)
830

931
function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleBroyden, args...;
1032
abstol = nothing, reltol = nothing, maxiters = 1000, alias_u0 = false,
1133
termination_condition = nothing, kwargs...)
1234
x = __maybe_unaliased(prob.u0, alias_u0)
1335
fx = _get_fx(prob, x)
36+
T = promote_type(eltype(x), eltype(fx))
1437

1538
@bb xo = copy(x)
1639
@bb δx = copy(x)
1740
@bb δf = copy(fx)
1841
@bb fprev = copy(fx)
1942

20-
J⁻¹ = __init_identity_jacobian(fx, x)
43+
if alg.alpha === nothing
44+
fx_norm = NONLINEARSOLVE_DEFAULT_NORM(fx)
45+
x_norm = NONLINEARSOLVE_DEFAULT_NORM(x)
46+
init_α = ifelse(fx_norm 1e-5, max(x_norm, T(true)) / (2 * fx_norm), T(true))
47+
else
48+
init_α = inv(alg.alpha)
49+
end
50+
51+
J⁻¹ = __init_identity_jacobian(fx, x, init_α)
2152
@bb J⁻¹δf = copy(x)
2253
@bb xᵀJ⁻¹ = copy(x)
2354
@bb δJ⁻¹n = copy(x)
@@ -26,9 +57,15 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleBroyden, args...;
2657
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fx, x,
2758
termination_condition)
2859

60+
ls_cache = __get_linesearch(alg) === Val(true) ?
61+
LiFukushimaLineSearch()(prob, fx, x) : nothing
62+
2963
for _ in 1:maxiters
3064
@bb δx = J⁻¹ × vec(fprev)
31-
@bb @. x = xo - δx
65+
@bb δx .*= -1
66+
67+
α = ls_cache === nothing ? true : ls_cache(xo, δx)
68+
@bb @. x = xo + α * δx
3269
fx = __eval_f(prob, fx, x)
3370
@bb @. δf = fx - fprev
3471

@@ -37,7 +74,6 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleBroyden, args...;
3774
tc_sol !== nothing && return tc_sol
3875

3976
@bb J⁻¹δf = J⁻¹ × vec(δf)
40-
@bb δx .*= -1
4177
d = dot(δx, J⁻¹δf)
4278
@bb xᵀJ⁻¹ = transpose(J⁻¹) × vec(δx)
4379

0 commit comments

Comments
 (0)