|
"""
    LBroyden(; batched = false,
             termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault;
                                                                 abstol = nothing, reltol = nothing),
             threshold::Int = 27)

A limited memory implementation of Broyden. This method applies the L-BFGS scheme to
Broyden's method.

# Keywords

  - `batched`: stored as the first type parameter of `LBroyden`. When `true`, the solver
    requires the initial guess to be a 2D array.
  - `termination_condition`: the `NLSolveTerminationCondition` used to decide convergence.
    Its `abstol` / `reltol` are used as fallbacks when the corresponding `solve` keywords
    are `nothing`. `SAFE_BEST` termination modes are rejected by the solver.
  - `threshold`: number of low-rank updates kept in memory. The solver caps it at
    `maxiters` and overwrites the oldest slot once it is full.

!!! warning

    This method is not very stable and can diverge even for very simple problems. This has
    mostly been tested for neural networks in DeepEquilibriumNetworks.jl.
"""
struct LBroyden{batched, TC <: NLSolveTerminationCondition} <:
       AbstractSimpleNonlinearSolveAlgorithm
    termination_condition::TC
    threshold::Int

    # Keyword-only constructor: `batched` is lifted into the type domain so that the solver
    # can dispatch on it, while the termination condition keeps its concrete type in `TC`.
    function LBroyden(; batched = false, threshold::Int = 27,
        termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault;
            abstol = nothing,
            reltol = nothing))
        return new{batched, typeof(termination_condition)}(termination_condition, threshold)
    end
end
|
10 | 27 |
|
11 |
| -@views function SciMLBase.__solve(prob::NonlinearProblem, alg::LBroyden, args...; |
| 28 | +@views function SciMLBase.__solve(prob::NonlinearProblem, alg::LBroyden{batched}, args...; |
12 | 29 | abstol = nothing, reltol = nothing, maxiters = 1000,
|
13 |
| - batch = false, kwargs...) |
| 30 | + kwargs...) where {batched} |
| 31 | + tc = alg.termination_condition |
| 32 | + mode = DiffEqBase.get_termination_mode(tc) |
14 | 33 | threshold = min(maxiters, alg.threshold)
|
15 | 34 | x = float(prob.u0)
|
16 | 35 |
|
| 36 | + batched && @assert ndims(x)==2 "Batched LBroyden only supports 2D arrays" |
| 37 | + |
17 | 38 | if x isa Number
|
18 | 39 | restore_scalar = true
|
19 | 40 | x = [x]
|
|
30 | 51 | error("LBroyden currently only supports out-of-place nonlinear problems")
|
31 | 52 | end
|
32 | 53 |
|
33 |
| - U = fill!(similar(x, (threshold, length(x))), zero(T)) |
34 |
| - Vᵀ = fill!(similar(x, (length(x), threshold)), zero(T)) |
| 54 | + U, Vᵀ = _init_lbroyden_state(batched, x, threshold) |
35 | 55 |
|
36 | 56 | atol = abstol !== nothing ? abstol :
|
37 |
| - real(oneunit(eltype(T))) * (eps(real(one(eltype(T)))))^(4 // 5) |
38 |
| - rtol = reltol !== nothing ? reltol : eps(real(one(eltype(T))))^(4 // 5) |
| 57 | + (tc.abstol !== nothing ? tc.abstol : |
| 58 | + real(oneunit(eltype(T))) * (eps(real(one(eltype(T)))))^(4 // 5)) |
| 59 | + rtol = reltol !== nothing ? reltol : |
| 60 | + (tc.reltol !== nothing ? tc.reltol : eps(real(one(eltype(T))))^(4 // 5)) |
| 61 | + |
| 62 | + if mode ∈ DiffEqBase.SAFE_BEST_TERMINATION_MODES |
| 63 | + error("LBroyden currently doesn't support SAFE_BEST termination modes") |
| 64 | + end |
| 65 | + |
| 66 | + storage = mode ∈ DiffEqBase.SAFE_TERMINATION_MODES ? Dict() : nothing |
| 67 | + termination_condition = tc(storage) |
39 | 68 |
|
40 | 69 | xₙ = x
|
41 | 70 | xₙ₋₁ = x
|
|
47 | 76 | Δxₙ = xₙ .- xₙ₋₁
|
48 | 77 | Δfₙ = fₙ .- fₙ₋₁
|
49 | 78 |
|
50 |
| - if iszero(fₙ) |
51 |
| - xₙ = restore_scalar ? xₙ[] : xₙ |
52 |
| - return SciMLBase.build_solution(prob, alg, xₙ, fₙ; retcode = ReturnCode.Success) |
53 |
| - end |
54 |
| - |
55 |
| - if isapprox(xₙ, xₙ₋₁; atol, rtol) |
| 79 | + if termination_condition(restore_scalar ? [fₙ] : fₙ, xₙ, xₙ₋₁, atol, rtol) |
56 | 80 | xₙ = restore_scalar ? xₙ[] : xₙ
|
57 | 81 | return SciMLBase.build_solution(prob, alg, xₙ, fₙ; retcode = ReturnCode.Success)
|
58 | 82 | end
|
59 | 83 |
|
60 |
| - _U = U[1:min(threshold, i), :] |
61 |
| - _Vᵀ = Vᵀ[:, 1:min(threshold, i)] |
| 84 | + _U = selectdim(U, 1, 1:min(threshold, i)) |
| 85 | + _Vᵀ = selectdim(Vᵀ, 2, 1:min(threshold, i)) |
62 | 86 |
|
63 | 87 | vᵀ = _rmatvec(_U, _Vᵀ, Δxₙ)
|
64 | 88 | mvec = _matvec(_U, _Vᵀ, Δfₙ)
|
65 |
| - Δxₙ = (Δxₙ .- mvec) ./ (sum(vᵀ .* Δfₙ) .+ convert(T, 1e-5)) |
| 89 | + u = (Δxₙ .- mvec) ./ (sum(vᵀ .* Δfₙ) .+ convert(T, 1e-5)) |
66 | 90 |
|
67 |
| - Vᵀ[:, mod1(i, threshold)] .= vᵀ |
68 |
| - U[mod1(i, threshold), :] .= Δxₙ |
| 91 | + selectdim(Vᵀ, 2, mod1(i, threshold)) .= vᵀ |
| 92 | + selectdim(U, 1, mod1(i, threshold)) .= u |
69 | 93 |
|
70 |
| - update = -_matvec(U[1:min(threshold, i + 1), :], Vᵀ[:, 1:min(threshold, i + 1)], fₙ) |
| 94 | + update = -_matvec(selectdim(U, 1, 1:min(threshold, i + 1)), |
| 95 | + selectdim(Vᵀ, 2, 1:min(threshold, i + 1)), fₙ) |
71 | 96 |
|
72 | 97 | xₙ₋₁ = xₙ
|
73 | 98 | fₙ₋₁ = fₙ
|
|
77 | 102 | return SciMLBase.build_solution(prob, alg, xₙ, fₙ; retcode = ReturnCode.MaxIters)
|
78 | 103 | end
|
79 | 104 |
|
"""
    _init_lbroyden_state(batched::Bool, x, threshold)

Allocate the zero-filled low-rank factors `U` and `Vᵀ` used by `LBroyden`.

Both arrays are created with `similar(x, dims)`, so they share `x`'s array type and
element type.

  - `batched == false`: `U` has size `(threshold, length(x))` and `Vᵀ` has size
    `(length(x), threshold)`.
  - `batched == true`: `U` has size `(threshold, size(x, 1), size(x, 2))` and `Vᵀ` has size
    `(size(x, 1), threshold, size(x, 2))`.

Returns the tuple `(U, Vᵀ)`.
"""
function _init_lbroyden_state(batched::Bool, x, threshold)
    T = eltype(x)
    if batched
        U = fill!(similar(x, (threshold, size(x, 1), size(x, 2))), zero(T))
        Vᵀ = fill!(similar(x, (size(x, 1), threshold, size(x, 2))), zero(T))
    else
        U = fill!(similar(x, (threshold, length(x))), zero(T))
        Vᵀ = fill!(similar(x, (length(x), threshold)), zero(T))
    end
    return U, Vᵀ
end
| 116 | + |
"""
    _rmatvec(U, Vᵀ, x)

Compute `-x .+ (xᵀ Vᵀ) U`, i.e. apply the low-rank operator `-I + Vᵀ U` to `x` from the
left.

  - Matrix method: `U` is `(k, n)`, `Vᵀ` is `(n, k)` and `x` is a length-`n` vector (or a
    number); the result is a vector.
  - Batched method: `U` is `(k, n, b)`, `Vᵀ` is `(n, k, b)` and `x` is `(n, b)`; the same
    product is evaluated independently for every slice along the last dimension.

When no updates are stored (`length(U) == 0`) the low-rank term vanishes and `-x` is
returned, which is what the general expression evaluates to for empty factors.
"""
function _rmatvec(U::AbstractMatrix, Vᵀ::AbstractMatrix,
    x::Union{<:AbstractVector, <:Number})
    # Empty history: only the `-I` part of the operator remains.
    length(U) == 0 && return -x
    return -x .+ vec((x' * Vᵀ) * U)
end

function _rmatvec(U::AbstractArray{T1, 3}, Vᵀ::AbstractArray{T2, 3},
    x::AbstractMatrix) where {T1, T2}
    # Empty history: only the `-I` part of the operator remains.
    length(U) == 0 && return -x
    # (1, k, b): per-batch xᵀ Vᵀ, contracting over the state dimension.
    Vᵀx = sum(Vᵀ .* reshape(x, size(x, 1), 1, size(x, 2)); dims = 1)
    # Contract over the history dimension `k` to get back an (n, b) array.
    return -x .+ _drdims_sum(U .* permutedims(Vᵀx, (2, 1, 3)); dims = 1)
end
|
84 | 129 |
|
"""
    _matvec(U, Vᵀ, x)

Compute `-x .+ Vᵀ (U x)`, i.e. apply the low-rank operator `-I + Vᵀ U` to `x` from the
right.

  - Matrix method: `U` is `(k, n)`, `Vᵀ` is `(n, k)` and `x` is a length-`n` vector (or a
    number); the result is a vector.
  - Batched method: `U` is `(k, n, b)`, `Vᵀ` is `(n, k, b)` and `x` is `(n, b)`; the same
    product is evaluated independently for every slice along the last dimension.

When no updates are stored (`length(U) == 0`) the low-rank term vanishes and `-x` is
returned, which is what the general expression evaluates to for empty factors.
"""
function _matvec(U::AbstractMatrix, Vᵀ::AbstractMatrix,
    x::Union{<:AbstractVector, <:Number})
    # Empty history: only the `-I` part of the operator remains.
    length(U) == 0 && return -x
    return -x .+ vec(Vᵀ * (U * x))
end

function _matvec(U::AbstractArray{T1, 3}, Vᵀ::AbstractArray{T2, 3},
    x::AbstractMatrix) where {T1, T2}
    # Empty history: only the `-I` part of the operator remains.
    length(U) == 0 && return -x
    # (1, k, b): per-batch U x, contracting over the state dimension.
    xUᵀ = sum(reshape(x, size(x, 1), 1, size(x, 2)) .* permutedims(U, (2, 1, 3)); dims = 1)
    # Contract over the history dimension `k` to get back an (n, b) array.
    return -x .+ _drdims_sum(xUᵀ .* Vᵀ; dims = 2)
end
| 142 | + |
"""
    _drdims_sum(args...; dims = :)

Reduce with `sum(args...; dims)` and drop the reduced dimensions from the result.

With the default `dims = :` the reduction covers every dimension, so the scalar returned by
`sum` is passed through unchanged.
"""
function _drdims_sum(args...; dims = :)
    # `sum(...; dims = :)` yields a scalar and `dropdims` has no method for scalars, so the
    # full reduction is returned directly instead of being routed through `dropdims`.
    dims isa Colon && return sum(args...)
    return dropdims(sum(args...; dims); dims)
end
0 commit comments