Skip to content

Commit 3259cd2

Browse files
Add LineSearchTestCase (#177)
* Add LineSearchTestCase Also includes the failing case in PR#174. Co-authored-by: Mateusz Baran <mateuszbaran89@gmail.com> * Add caching to all line search algorithms * Add to docs * Test caching for all algs --------- Co-authored-by: Mateusz Baran <mateuszbaran89@gmail.com>
1 parent ded667a commit 3259cd2

15 files changed

+238
-25
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@
44
docs/build
55
docs/src/examples/generated
66
/docs/Manifest.toml
7+
Manifest.toml

Project.toml

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "LineSearches"
22
uuid = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
3-
version = "7.2.0"
3+
version = "7.3.0"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -13,6 +13,8 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
1313
DoubleFloats = "1"
1414
NLSolversBase = "7"
1515
NaNMath = "1"
16+
Optim = "1"
17+
OptimTestProblems = "2"
1618
Parameters = "0.10, 0.11, 0.12"
1719
julia = "1.6"
1820

docs/src/index.md

+8
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,14 @@ using LineSearches
4646
```
4747
to load the package.
4848

49+
## Debugging
50+
51+
If you suspect a method of suboptimal performance or find that your code errors,
52+
create a [`LineSearchCache`](@ref) to record intermediate values for later
53+
inspection and analysis. If you're using this via Optim.jl, configure it inside
54+
the method, e.g., `Newton(linesearch=LineSearches.MoreThuente(; cache))`. The
55+
values stored in the cache will reflect only the final line search performed during
56+
optimization.
4957

5058
## References
5159

docs/src/reference/linesearch.md

+6
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,9 @@ MoreThuente
1111
Static
1212
StrongWolfe
1313
```
14+
15+
## Debugging
16+
17+
```@docs
18+
LineSearchCache
19+
```

src/LineSearches.jl

+23-4
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
__precompile__()
2-
31
module LineSearches
42

53
using Printf
@@ -9,13 +7,14 @@ using Parameters, NaNMath
97
import NLSolversBase
108
import NLSolversBase: AbstractObjective
119

12-
export LineSearchException
10+
export LineSearchException, LineSearchCache
1311

14-
export BackTracking, HagerZhang, Static, MoreThuente, StrongWolfe
12+
export AbstractLineSearch, BackTracking, HagerZhang, Static, MoreThuente, StrongWolfe
1513

1614
export InitialHagerZhang, InitialStatic, InitialPrevious,
1715
InitialQuadratic, InitialConstantChange
1816

17+
1918
function make_ϕ(df, x_new, x, s)
2019
function ϕ(α)
2120
# Move a distance of alpha in the direction of s
@@ -91,6 +90,26 @@ end
9190

9291
include("types.jl")
9392

93+
# The following don't extend `empty!` and `push!` because we want implementations for `nothing`
94+
# and that would be piracy
95+
emptycache!(cache::LineSearchCache) = begin
96+
empty!(cache.alphas)
97+
empty!(cache.values)
98+
empty!(cache.slopes)
99+
end
100+
emptycache!(::Nothing) = nothing
101+
pushcache!(cache::LineSearchCache, α, val, slope) = begin
102+
push!(cache.alphas, α)
103+
push!(cache.values, val)
104+
push!(cache.slopes, slope)
105+
end
106+
pushcache!(cache::LineSearchCache, α, val) = begin
107+
push!(cache.alphas, α)
108+
push!(cache.values, val)
109+
end
110+
pushcache!(::Nothing, α, val, slope) = nothing
111+
pushcache!(::Nothing, α, val) = nothing
112+
94113
# Line Search Methods
95114
include("backtracking.jl")
96115
include("strongwolfe.jl")

src/backtracking.jl

+8-2
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,14 @@ there exists a factor ρ = ρ(c₁) such that α' ≦ ρ α.
88
99
This is a modification of the algorithm described in Nocedal Wright (2nd ed), Sec. 3.5.
1010
"""
11-
@with_kw struct BackTracking{TF, TI}
11+
@with_kw struct BackTracking{TF, TI} <: AbstractLineSearch
1212
c_1::TF = 1e-4
1313
ρ_hi::TF = 0.5
1414
ρ_lo::TF = 0.1
1515
iterations::TI = 1_000
1616
order::TI = 3
1717
maxstep::TF = Inf
18+
cache::Union{Nothing,LineSearchCache{TF}} = nothing
1819
end
1920
BackTracking{TF}(args...; kwargs...) where TF = BackTracking{TF,Int}(args...; kwargs...)
2021

@@ -37,7 +38,9 @@ end
3738

3839
# TODO: Should we deprecate the interface that only uses the ϕ argument?
3940
function (ls::BackTracking)(ϕ, αinitial::Tα, ϕ_0, dϕ_0) where Tα
40-
@unpack c_1, ρ_hi, ρ_lo, iterations, order = ls
41+
@unpack c_1, ρ_hi, ρ_lo, iterations, order, cache = ls
42+
emptycache!(cache)
43+
pushcache!(cache, 0, ϕ_0, dϕ_0) # backtracking doesn't use the slope except here
4144

4245
iterfinitemax = -log2(eps(real(Tα)))
4346

@@ -68,6 +71,8 @@ function (ls::BackTracking)(ϕ, αinitial::Tα, ϕ_0, dϕ_0) where Tα
6871

6972
ϕx_1 = ϕ(α_2)
7073
end
74+
pushcache!(cache, αinitial, ϕx_1)
75+
# TODO: check if value is finite (maybe iterfinite > iterfinitemax)
7176

7277
# Backtrack until we satisfy sufficient decrease condition
7378
while ϕx_1 > ϕ_0 + c_1 * α_2 * dϕ_0
@@ -112,6 +117,7 @@ function (ls::BackTracking)(ϕ, αinitial::Tα, ϕ_0, dϕ_0) where Tα
112117

113118
# Evaluate f(x) at proposed position
114119
ϕx_0, ϕx_1 = ϕx_1, ϕ(α_2)
120+
pushcache!(cache, α_2, ϕx_1)
115121
end
116122

117123
return α_2, ϕx_1

src/hagerzhang.jl

+14-7
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ Conjugate gradient line search implementation from:
8080
conjugate gradient method with guaranteed descent. ACM
8181
Transactions on Mathematical Software 32: 113–137.
8282
"""
83-
@with_kw struct HagerZhang{T, Tm}
83+
@with_kw struct HagerZhang{T, Tm} <: AbstractLineSearch
8484
delta::T = DEFAULTDELTA # c_1 Wolfe sufficient decrease condition
8585
sigma::T = DEFAULTSIGMA # c_2 Wolfe curvature condition (Recommend 0.1 for GradientDescent)
8686
alphamax::T = Inf
@@ -91,6 +91,7 @@ Conjugate gradient line search implementation from:
9191
psi3::T = 0.1
9292
display::Int = 0
9393
mayterminate::Tm = Ref{Bool}(false)
94+
cache::Union{Nothing,LineSearchCache{T}} = nothing
9495
end
9596
HagerZhang{T}(args...; kwargs...) where T = HagerZhang{T, Base.RefValue{Bool}}(args...; kwargs...)
9697

@@ -109,9 +110,11 @@ function (ls::HagerZhang)(ϕ, ϕdϕ,
109110
phi_0::Real,
110111
dphi_0::Real) where T # Should c and phi_0 be same type?
111112
@unpack delta, sigma, alphamax, rho, epsilon, gamma,
112-
linesearchmax, psi3, display, mayterminate = ls
113+
linesearchmax, psi3, display, mayterminate, cache = ls
114+
emptycache!(cache)
113115

114116
zeroT = convert(T, 0)
117+
pushcache!(cache, zeroT, phi_0, dphi_0)
115118
if !(isfinite(phi_0) && isfinite(dphi_0))
116119
throw(LineSearchException("Value and slope at step length = 0 must be finite.", T(0)))
117120
end
@@ -124,9 +127,13 @@ function (ls::HagerZhang)(ϕ, ϕdϕ,
124127
# Prevent values of x_new = x+αs that are likely to make
125128
# ϕ(x_new) infinite
126129
iterfinitemax::Int = ceil(Int, -log2(eps(T)))
127-
alphas = [zeroT] # for bisection
128-
values = [phi_0]
129-
slopes = [dphi_0]
130+
if cache !== nothing
131+
@unpack alphas, values, slopes = cache
132+
else
133+
alphas = [zeroT] # for bisection
134+
values = [phi_0]
135+
slopes = [dphi_0]
136+
end
130137
if display & LINESEARCH > 0
131138
println("New linesearch")
132139
end
@@ -203,10 +210,10 @@ function (ls::HagerZhang)(ϕ, ϕdϕ,
203210
else
204211
# We're still going downhill, expand the interval and try again.
205212
# Reaching this branch means that dphi_c < 0 and phi_c <= phi_0 + ϵ_k
206-
# So cold = c has a lower objective than phi_0 up to epsilon.
213+
# So cold = c has a lower objective than phi_0 up to epsilon.
207214
# This makes it a viable step to return if bracketing fails.
208215

209-
# Bracketing can fail if no cold < c <= alphamax can be found with finite phi_c and dphi_c.
216+
# Bracketing can fail if no cold < c <= alphamax can be found with finite phi_c and dphi_c.
210217
# Going back to the loop with c = cold will only result in infinite cycling.
211218
# So returning (cold, phi_cold) and exiting the line search is the best move.
212219
cold = c

src/morethuente.jl

+8-2
Original file line numberDiff line numberDiff line change
@@ -138,13 +138,14 @@ The line search implementation from:
138138
Line search algorithms with guaranteed sufficient decrease.
139139
ACM Transactions on Mathematical Software (TOMS) 20.3 (1994): 286-307.
140140
"""
141-
@with_kw struct MoreThuente{T}
141+
@with_kw struct MoreThuente{T} <: AbstractLineSearch
142142
f_tol::T = 1e-4 # c_1 Wolfe sufficient decrease condition
143143
gtol::T = 0.9 # c_2 Wolfe curvature condition (Recommend 0.1 for GradientDescent)
144144
x_tol::T = 1e-8
145145
alphamin::T = 1e-16
146146
alphamax::T = 65536.0
147147
maxfev::Int = 100
148+
cache::Union{Nothing,LineSearchCache{T}} = nothing
148149
end
149150

150151
function (ls::MoreThuente)(df::AbstractObjective, x::AbstractArray{T},
@@ -161,13 +162,15 @@ function (ls::MoreThuente)(ϕdϕ,
161162
alpha::T,
162163
ϕ_0,
163164
dϕ_0) where T
164-
@unpack f_tol, gtol, x_tol, alphamin, alphamax, maxfev = ls
165+
@unpack f_tol, gtol, x_tol, alphamin, alphamax, maxfev, cache = ls
166+
emptycache!(cache)
165167

166168
iterfinitemax = -log2(eps(T))
167169
info = 0
168170
info_cstep = 1 # Info from step
169171

170172
zeroT = convert(T, 0)
173+
pushcache!(cache, zeroT, ϕ_0, dϕ_0)
171174

172175
#
173176
# Check the input parameters for errors.
@@ -236,7 +239,9 @@ function (ls::MoreThuente)(ϕdϕ,
236239
# Make stmax = (3/2)*alpha < 2alpha in the first iteration below
237240
stx = (convert(T, 7)/8)*alpha
238241
end
242+
pushcache!(cache, alpha, f, dg)
239243
# END: Ensure that the initial step provides finite function values
244+
# TODO: check if value is finite (maybe iterfinite > iterfinitemax)
240245

241246
while true
242247
#
@@ -282,6 +287,7 @@ function (ls::MoreThuente)(ϕdϕ,
282287
# and compute the directional derivative.
283288
#
284289
f, dg = ϕdϕ(alpha)
290+
pushcache!(cache, alpha, f, dg)
285291
nfev += 1 # This includes calls to f() and g!()
286292

287293
if isapprox(dg, 0, atol=eps(T)) # Should add atol value to MoreThuente

src/static.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
44
`Static` is intended for methods with well-scaled updates; i.e. Newton, on well-behaved problems.
55
"""
6-
struct Static end
6+
struct Static <: AbstractLineSearch end
77

88
function (ls::Static)(df::AbstractObjective, x, s, α, x_new = similar(x), ϕ_0 = nothing, dϕ_0 = nothing)
99
ϕ = make_ϕ(df, x_new, x, s)

src/strongwolfe.jl

+18-4
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,11 @@ use `MoreThuente`, `HagerZhang` or `BackTracking`.
1414
* `c_2 = 0.9` : second (strong) Wolfe condition
1515
* `ρ = 2.0` : bracket growth
1616
"""
17-
@with_kw struct StrongWolfe{T}
17+
@with_kw struct StrongWolfe{T} <: AbstractLineSearch
1818
c_1::T = 1e-4
1919
c_2::T = 0.9
2020
ρ::T = 2.0
21+
cache::Union{Nothing,LineSearchCache{T}} = nothing
2122
end
2223

2324
"""
@@ -49,9 +50,11 @@ Both `alpha` and `ϕ(alpha)` are returned.
4950
"""
5051
function (ls::StrongWolfe)(ϕ, dϕ, ϕdϕ,
5152
alpha0::T, ϕ_0, dϕ_0) where T<:Real
52-
@unpack c_1, c_2, ρ = ls
53+
@unpack c_1, c_2, ρ, cache = ls
54+
emptycache!(cache)
5355

5456
zeroT = convert(T, 0)
57+
pushcache!(cache, zeroT, ϕ_0, dϕ_0)
5558

5659
# Step-sizes
5760
a_0 = zeroT
@@ -71,17 +74,21 @@ function (ls::StrongWolfe)(ϕ, dϕ, ϕdϕ,
7174

7275
while a_i < a_max
7376
ϕ_a_i = ϕ(a_i)
77+
pushcache!(cache, a_i, ϕ_a_i)
7478

7579
# Test Wolfe conditions
7680
if (ϕ_a_i > ϕ_0 + c_1 * a_i * dϕ_0) ||
7781
(ϕ_a_i >= ϕ_a_iminus1 && i > 1)
7882
a_star = zoom(a_iminus1, a_i,
7983
dϕ_0, ϕ_0,
80-
ϕ, dϕ, ϕdϕ)
84+
ϕ, dϕ, ϕdϕ, cache)
8185
return a_star, ϕ(a_star)
8286
end
8387

8488
dϕ_a_i = dϕ(a_i)
89+
if cache !== nothing
90+
push!(cache.slopes, dϕ_a_i)
91+
end
8592

8693
# Check condition 2
8794
if abs(dϕ_a_i) <= -c_2 * dϕ_0
@@ -91,7 +98,7 @@ function (ls::StrongWolfe)(ϕ, dϕ, ϕdϕ,
9198
# Check condition 3
9299
if dϕ_a_i >= zeroT # FIXME untested!
93100
a_star = zoom(a_i, a_iminus1,
94-
dϕ_0, ϕ_0, ϕ, dϕ, ϕdϕ)
101+
dϕ_0, ϕ_0, ϕ, dϕ, ϕdϕ, cache)
95102
return a_star, ϕ(a_star)
96103
end
97104

@@ -117,6 +124,7 @@ function zoom(a_lo::T,
117124
ϕ,
118125
dϕ,
119126
ϕdϕ,
127+
cache,
120128
c_1::Real = convert(T, 1)/10^4,
121129
c_2::Real = convert(T, 9)/10) where T
122130

@@ -133,8 +141,10 @@ function zoom(a_lo::T,
133141
iteration += 1
134142

135143
ϕ_a_lo, ϕprime_a_lo = ϕdϕ(a_lo)
144+
pushcache!(cache, a_lo, ϕ_a_lo, ϕprime_a_lo)
136145

137146
ϕ_a_hi, ϕprime_a_hi = ϕdϕ(a_hi)
147+
pushcache!(cache, a_hi, ϕ_a_hi, ϕprime_a_hi)
138148

139149
# Interpolate a_j
140150
if a_lo < a_hi
@@ -150,6 +160,7 @@ function zoom(a_lo::T,
150160

151161
# Evaluate ϕ(a_j)
152162
ϕ_a_j = ϕ(a_j)
163+
pushcache!(cache, a_j, ϕ_a_j)
153164

154165
# Check Armijo
155166
if (ϕ_a_j > ϕ_0 + c_1 * a_j * dϕ_0) ||
@@ -158,6 +169,9 @@ function zoom(a_lo::T,
158169
else
159170
# Evaluate ϕprime(a_j)
160171
ϕprime_a_j = dϕ(a_j)
172+
if cache !== nothing
173+
push!(cache.slopes, ϕprime_a_j)
174+
end
161175

162176
if abs(ϕprime_a_j) <= -c_2 * dϕ_0
163177
return a_j

src/types.jl

+36
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,39 @@ mutable struct LineSearchException{T<:Real} <: Exception
22
message::AbstractString
33
alpha::T
44
end
5+
6+
abstract type AbstractLineSearch end
7+
8+
# For debugging
9+
struct LineSearchCache{T}
10+
alphas::Vector{T}
11+
values::Vector{T}
12+
slopes::Vector{T}
13+
end
14+
"""
15+
cache = LineSearchCache{T}()
16+
17+
Initialize an empty cache for storing intermediate results during line search.
18+
The `α`, `ϕ(α)`, and possibly `dϕ(α)` values computed during line search are
19+
available in `cache.alphas`, `cache.values`, and `cache.slopes`, respectively.
20+
21+
# Example
22+
23+
```jldoctest
24+
julia> ϕ(x) = (x - π)^4; dϕ(x) = 4*(x-π)^3;
25+
26+
julia> cache = LineSearchCache{Float64}();
27+
28+
julia> ls = BackTracking(; cache);
29+
30+
julia> ls(ϕ, 10.0, ϕ(0), dϕ(0))
31+
(1.8481462933284658, 2.7989406670901373)
32+
33+
julia> cache
34+
LineSearchCache{Float64}([0.0, 10.0, 1.8481462933284658], [97.40909103400242, 2212.550050116452, 2.7989406670901373], [-124.02510672119926])
35+
```
36+
37+
Because `BackTracking` doesn't use derivatives except at `α=0`, only the initial slope was stored in the cache.
38+
Other methods may store all three.
39+
"""
40+
LineSearchCache{T}() where T = LineSearchCache{T}(T[], T[], T[])

test/REQUIRE

-3
This file was deleted.

0 commit comments

Comments
 (0)