Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Commit ade88a2

Browse files
authored
Merge pull request #118 from SciML/ap/static_kernels
Using SimpleNonlinearSolve in GPU Kernels
2 parents 5e71f5c + 25a51ba commit ade88a2

File tree

11 files changed

+297
-43
lines changed

11 files changed

+297
-43
lines changed

.buildkite/pipeline.yml

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
steps:
2+
- label: "Julia 1"
3+
plugins:
4+
- JuliaCI/julia#v1:
5+
version: "1"
6+
- JuliaCI/julia-test#v1:
7+
agents:
8+
queue: "juliagpu"
9+
cuda: "*"
10+
timeout_in_minutes: 30
11+
# Don't run Buildkite if the commit message includes the text [skip tests]
12+
if: build.message !~ /\[skip tests\]/
13+
14+
env:
15+
GROUP: CUDA
16+
JULIA_PKG_SERVER: ""

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SimpleNonlinearSolve"
22
uuid = "727e6d20-b764-4bd8-a329-72de5adea6c7"
33
authors = ["SciML"]
4-
version = "1.2.0"
4+
version = "1.2.1"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
[![codecov](https://codecov.io/gh/SciML/SimpleNonlinearSolve.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/SciML/SimpleNonlinearSolve.jl)
77
[![Build Status](https://github.com/SciML/SimpleNonlinearSolve.jl/workflows/CI/badge.svg)](https://github.com/SciML/SimpleNonlinearSolve.jl/actions?query=workflow%3ACI)
8+
[![Build status](https://badge.buildkite.com/c5f7db4f1b5e8a592514378b6fc807d934546cc7d5aa79d645.svg?branch=main)](https://buildkite.com/julialang/simplenonlinearsolve-dot-jl)
89

910
[![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor%27s%20Guide-blueviolet)](https://github.com/SciML/ColPrac)
1011
[![SciML Code Style](https://img.shields.io/static/v1?label=code%20style&message=SciML&color=9558b2&labelColor=389826)](https://github.com/SciML/SciMLStyle)

src/SimpleNonlinearSolve.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,12 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Nothing,
5959
return solve(prob, ITP(), args...; kwargs...)
6060
end
6161

62+
# Bypass the high-level checks for NonlinearProblem for Simple Algorithms
63+
function SciMLBase.solve(prob::NonlinearProblem, alg::AbstractSimpleNonlinearSolveAlgorithm,
64+
args...; kwargs...)
65+
return SciMLBase.__solve(prob, alg, args...; kwargs...)
66+
end
67+
6268
@setup_workload begin
6369
for T in (Float32, Float64)
6470
prob_no_brack_scalar = NonlinearProblem{false}((u, p) -> u .* u .- p, T(0.1), T(2))

src/nlsolve/dfsane.jl

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""
22
SimpleDFSane(; σ_min::Real = 1e-10, σ_max::Real = 1e10, σ_1::Real = 1.0,
3-
M::Int = 10, γ::Real = 1e-4, τ_min::Real = 0.1, τ_max::Real = 0.5,
3+
M::Union{Int, Val} = Val(10), γ::Real = 1e-4, τ_min::Real = 0.1, τ_max::Real = 0.5,
44
nexp::Int = 2, η_strategy::Function = (f_1, k, x, F) -> f_1 ./ k^2)
55
66
A low-overhead implementation of the df-sane method for solving large-scale nonlinear
@@ -42,21 +42,27 @@ see the paper [1].
4242
information for solving large-scale nonlinear systems of equations, Mathematics of
4343
Computation, 75, 1429-1448.
4444
"""
45-
@kwdef @concrete struct SimpleDFSane <: AbstractSimpleNonlinearSolveAlgorithm
46-
σ_min = 1e-10
47-
σ_max = 1e10
48-
σ_1 = 1.0
49-
M::Int = 10
50-
γ = 1e-4
51-
τ_min = 0.1
52-
τ_max = 0.5
53-
nexp::Int = 2
54-
η_strategy = (f_1, k, x, F) -> f_1 ./ k^2
45+
@concrete struct SimpleDFSane{M} <: AbstractSimpleNonlinearSolveAlgorithm
46+
σ_min
47+
σ_max
48+
σ_1
49+
γ
50+
τ_min
51+
τ_max
52+
nexp::Int
53+
η_strategy
5554
end
5655

57-
function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane, args...;
56+
function SimpleDFSane(; σ_min::Real = 1e-10, σ_max::Real = 1e10, σ_1::Real = 1.0,
57+
M::Union{Int, Val} = Val(10), γ::Real = 1e-4, τ_min::Real = 0.1, τ_max::Real = 0.5,
58+
nexp::Int = 2, η_strategy::F = (f_1, k, x, F) -> f_1 ./ k^2) where {F}
59+
return SimpleDFSane{SciMLBase._unwrap_val(M)}(σ_min, σ_max, σ_1, γ, τ_min, τ_max, nexp,
60+
η_strategy)
61+
end
62+
63+
function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane{M}, args...;
5864
abstol = nothing, reltol = nothing, maxiters = 1000, alias_u0 = false,
59-
termination_condition = nothing, kwargs...)
65+
termination_condition = nothing, kwargs...) where {M}
6066
x = __maybe_unaliased(prob.u0, alias_u0)
6167
fx = _get_fx(prob, x)
6268
T = eltype(x)
@@ -65,7 +71,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane, args...;
6571
σ_max = T(alg.σ_max)
6672
σ_k = T(alg.σ_1)
6773

68-
(; M, nexp, η_strategy) = alg
74+
(; nexp, η_strategy) = alg
6975
γ = T(alg.γ)
7076
τ_min = T(alg.τ_min)
7177
τ_max = T(alg.τ_max)
@@ -77,7 +83,11 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane, args...;
7783
α_1 = one(T)
7884
f_1 = fx_norm
7985

80-
history_f_k = fill(fx_norm, M)
86+
history_f_k = if x isa SArray
87+
ones(SVector{M, T}) * fx_norm
88+
else
89+
fill(fx_norm, M)
90+
end
8191

8292
# Generate the cache
8393
@bb x_cache = similar(x)
@@ -143,7 +153,11 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane, args...;
143153
fx_norm = fx_norm_new
144154

145155
# Store function value
146-
history_f_k[mod1(k, M)] = fx_norm_new
156+
if history_f_k isa SVector
157+
history_f_k = Base.setindex(history_f_k, fx_norm_new, mod1(k, M))
158+
else
159+
history_f_k[mod1(k, M)] = fx_norm_new
160+
end
147161
k += 1
148162
end
149163

src/nlsolve/lbroyden.jl

Lines changed: 163 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,22 @@ function SimpleLimitedMemoryBroyden(; threshold::Union{Val, Int} = Val(27))
2121
return SimpleLimitedMemoryBroyden{SciMLBase._unwrap_val(threshold)}()
2222
end
2323

24-
@views function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleLimitedMemoryBroyden,
24+
function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleLimitedMemoryBroyden,
25+
args...; termination_condition = nothing, kwargs...)
26+
if prob.u0 isa SArray
27+
if termination_condition === nothing ||
28+
termination_condition isa AbsNormTerminationMode
29+
return __static_solve(prob, alg, args...; termination_condition, kwargs...)
30+
end
31+
@warn "Specifying `termination_condition = $(termination_condition)` for \
32+
`SimpleLimitedMemoryBroyden` with `SArray` is not non-allocating. Use \
33+
either `termination_condition = AbsNormTerminationMode()` or \
34+
`termination_condition = nothing`." maxlog=1
35+
end
36+
return __generic_solve(prob, alg, args...; termination_condition, kwargs...)
37+
end
38+
39+
@views function __generic_solve(prob::NonlinearProblem, alg::SimpleLimitedMemoryBroyden,
2540
args...; abstol = nothing, reltol = nothing, maxiters = 1000, alias_u0 = false,
2641
termination_condition = nothing, kwargs...)
2742
x = __maybe_unaliased(prob.u0, alias_u0)
@@ -36,7 +51,7 @@ end
3651

3752
fx = _get_fx(prob, x)
3853

39-
U, Vᵀ = __init_low_rank_jacobian(x, fx, threshold)
54+
U, Vᵀ = __init_low_rank_jacobian(x, fx, x isa StaticArray ? threshold : Val(η))
4055

4156
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fx, x,
4257
termination_condition)
@@ -48,7 +63,7 @@ end
4863
@bb δf = copy(fx)
4964

5065
@bb vᵀ_cache = copy(x)
51-
Tcache = __lbroyden_threshold_cache(x, threshold)
66+
Tcache = __lbroyden_threshold_cache(x, x isa StaticArray ? threshold : Val(η))
5267
@bb mat_cache = copy(x)
5368

5469
for i in 1:maxiters
@@ -83,6 +98,97 @@ end
8398
return build_solution(prob, alg, x, fx; retcode = ReturnCode.MaxIters)
8499
end
85100

101+
# Non-allocating StaticArrays version of SimpleLimitedMemoryBroyden is actually quite
102+
# finicky, so we'll implement it separately from the generic version
103+
# Ignore termination_condition. Don't pass things into internal functions
104+
function __static_solve(prob::NonlinearProblem{<:SArray}, alg::SimpleLimitedMemoryBroyden,
105+
args...; abstol = nothing, maxiters = 1000, kwargs...)
106+
x = prob.u0
107+
fx = _get_fx(prob, x)
108+
threshold = __get_threshold(alg)
109+
110+
U, Vᵀ = __init_low_rank_jacobian(vec(x), vec(fx), threshold)
111+
112+
abstol = DiffEqBase._get_tolerance(abstol, eltype(x))
113+
114+
xo, δx, fo, δf = x, -fx, fx, fx
115+
116+
converged, res = __unrolled_lbroyden_initial_iterations(prob, xo, fo, δx, abstol, U, Vᵀ,
117+
threshold)
118+
119+
converged &&
120+
return build_solution(prob, alg, res.x, res.fx; retcode = ReturnCode.Success)
121+
122+
xo, fo, δx = res.x, res.fx, res.δx
123+
124+
for i in 1:(maxiters - SciMLBase._unwrap_val(threshold))
125+
x = xo .+ δx
126+
fx = prob.f(x, prob.p)
127+
δf = fx - fo
128+
129+
maximum(abs, fx) ≤ abstol &&
130+
return build_solution(prob, alg, x, fx; retcode = ReturnCode.Success)
131+
132+
vᵀ = _restructure(x, _rmatvec!!(U, Vᵀ, vec(δx)))
133+
mvec = _restructure(x, _matvec!!(U, Vᵀ, vec(δf)))
134+
135+
d = dot(vᵀ, δf)
136+
δx = @. (δx - mvec) / d
137+
138+
U = Base.setindex(U, vec(δx), mod1(i, SciMLBase._unwrap_val(threshold)))
139+
Vᵀ = Base.setindex(Vᵀ, vec(vᵀ), mod1(i, SciMLBase._unwrap_val(threshold)))
140+
141+
δx = -_restructure(fx, _matvec!!(U, Vᵀ, vec(fx)))
142+
143+
xo = x
144+
fo = fx
145+
end
146+
147+
return build_solution(prob, alg, xo, fo; retcode = ReturnCode.MaxIters)
148+
end
149+
150+
@generated function __unrolled_lbroyden_initial_iterations(prob, xo, fo, δx, abstol, U,
151+
Vᵀ, ::Val{threshold}) where {threshold}
152+
calls = []
153+
for i in 1:threshold
154+
static_idx, static_idx_p1 = Val(i - 1), Val(i)
155+
push!(calls,
156+
quote
157+
x = xo .+ δx
158+
fx = prob.f(x, prob.p)
159+
δf = fx - fo
160+
161+
maximum(abs, fx) ≤ abstol && return true, (; x, fx, δx)
162+
163+
_U = __first_n_getindex(U, $(static_idx))
164+
_Vᵀ = __first_n_getindex(Vᵀ, $(static_idx))
165+
166+
vᵀ = _restructure(x, _rmatvec!!(_U, _Vᵀ, vec(δx)))
167+
mvec = _restructure(x, _matvec!!(_U, _Vᵀ, vec(δf)))
168+
169+
d = dot(vᵀ, δf)
170+
δx = @. (δx - mvec) / d
171+
172+
U = Base.setindex(U, vec(δx), $(i))
173+
Vᵀ = Base.setindex(Vᵀ, vec(vᵀ), $(i))
174+
175+
_U = __first_n_getindex(U, $(static_idx_p1))
176+
_Vᵀ = __first_n_getindex(Vᵀ, $(static_idx_p1))
177+
δx = -_restructure(fx, _matvec!!(_U, _Vᵀ, vec(fx)))
178+
179+
xo = x
180+
fo = fx
181+
end)
182+
end
183+
push!(calls, quote
184+
# Termination Check
185+
maximum(abs, fx) ≤ abstol && return true, (; x, fx, δx)
186+
187+
return false, (; x, fx, δx)
188+
end)
189+
return Expr(:block, calls...)
190+
end
191+
86192
function _rmatvec!!(y, xᵀU, U, Vᵀ, x)
87193
# xᵀ × (-I + UVᵀ)
88194
η = size(U, 2)
@@ -98,6 +204,9 @@ function _rmatvec!!(y, xᵀU, U, Vᵀ, x)
98204
return y
99205
end
100206

207+
@inline _rmatvec!!(::Nothing, Vᵀ, x) = -x
208+
@inline _rmatvec!!(U, Vᵀ, x) = __mapTdot(__mapdot(x, U), Vᵀ) .- x
209+
101210
function _matvec!!(y, Vᵀx, U, Vᵀ, x)
102211
# (-I + UVᵀ) × x
103212
η = size(U, 2)
@@ -113,7 +222,56 @@ function _matvec!!(y, Vᵀx, U, Vᵀ, x)
113222
return y
114223
end
115224

225+
@inline _matvec!!(::Nothing, Vᵀ, x) = -x
226+
@inline _matvec!!(U, Vᵀ, x) = __mapTdot(__mapdot(x, Vᵀ), U) .- x
227+
228+
function __mapdot(x::SVector{S1}, Y::SVector{S2, <:SVector{S1}}) where {S1, S2}
229+
return map(Base.Fix1(dot, x), Y)
230+
end
231+
@generated function __mapTdot(x::SVector{S1}, Y::SVector{S1, <:SVector{S2}}) where {S1, S2}
232+
calls = []
233+
syms = [gensym("m$(i)") for i in 1:S1]
234+
for i in 1:S1
235+
push!(calls, :($(syms[i]) = x[$(i)] .* Y[$i]))
236+
end
237+
push!(calls, :(return .+($(syms...))))
238+
return Expr(:block, calls...)
239+
end
240+
241+
@generated function __first_n_getindex(x::SVector{L, T}, ::Val{N}) where {L, T, N}
242+
@assert N ≤ L
243+
getcalls = ntuple(i -> :(x[$i]), N)
244+
N == 0 && return :(return nothing)
245+
return :(return SVector{$N, $T}(($(getcalls...))))
246+
end
247+
116248
__lbroyden_threshold_cache(x, ::Val{threshold}) where {threshold} = similar(x, threshold)
117-
function __lbroyden_threshold_cache(x::SArray, ::Val{threshold}) where {threshold}
118-
return SArray{Tuple{threshold}, eltype(x)}(ntuple(_ -> zero(eltype(x)), threshold))
249+
function __lbroyden_threshold_cache(x::StaticArray, ::Val{threshold}) where {threshold}
250+
return zeros(MArray{Tuple{threshold}, eltype(x)})
251+
end
252+
__lbroyden_threshold_cache(x::SArray, ::Val{threshold}) where {threshold} = nothing
253+
254+
function __init_low_rank_jacobian(u::StaticArray{S1, T1}, fu::StaticArray{S2, T2},
255+
::Val{threshold}) where {S1, S2, T1, T2, threshold}
256+
T = promote_type(T1, T2)
257+
fuSize, uSize = Size(fu), Size(u)
258+
Vᵀ = MArray{Tuple{threshold, prod(uSize)}, T}(undef)
259+
U = MArray{Tuple{prod(fuSize), threshold}, T}(undef)
260+
return U, Vᵀ
261+
end
262+
@generated function __init_low_rank_jacobian(u::SVector{Lu, T1}, fu::SVector{Lfu, T2},
263+
::Val{threshold}) where {Lu, Lfu, T1, T2, threshold}
264+
T = promote_type(T1, T2)
265+
inner_inits_Vᵀ = [:(zeros(SVector{$Lu, $T})) for i in 1:threshold]
266+
inner_inits_U = [:(zeros(SVector{$Lfu, $T})) for i in 1:threshold]
267+
return quote
268+
Vᵀ = SVector($(inner_inits_Vᵀ...))
269+
U = SVector($(inner_inits_U...))
270+
return U, Vᵀ
271+
end
272+
end
273+
function __init_low_rank_jacobian(u, fu, ::Val{threshold}) where {threshold}
274+
Vᵀ = similar(u, threshold, length(u))
275+
U = similar(u, length(fu), threshold)
276+
return U, Vᵀ
119277
end

src/utils.jl

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -243,20 +243,6 @@ function __init_identity_jacobian!!(J::SVector{S1}) where {S1}
243243
return ones(SVector{S1, eltype(J)})
244244
end
245245

246-
function __init_low_rank_jacobian(u::StaticArray{S1, T1}, fu::StaticArray{S2, T2},
247-
::Val{threshold}) where {S1, S2, T1, T2, threshold}
248-
T = promote_type(T1, T2)
249-
fuSize, uSize = Size(fu), Size(u)
250-
Vᵀ = MArray{Tuple{threshold, prod(uSize)}, T}(undef)
251-
U = MArray{Tuple{prod(fuSize), threshold}, T}(undef)
252-
return U, Vᵀ
253-
end
254-
function __init_low_rank_jacobian(u, fu, ::Val{threshold}) where {threshold}
255-
Vᵀ = similar(u, threshold, length(u))
256-
U = similar(u, length(fu), threshold)
257-
return U, Vᵀ
258-
end
259-
260246
@inline _vec(v) = vec(v)
261247
@inline _vec(v::Number) = v
262248
@inline _vec(v::AbstractVector) = v

test/basictests.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -164,8 +164,8 @@ end
164164
## SimpleDFSane needs to allocate a history vector
165165
@testset "Allocation Checks: $(_nameof(alg))" for alg in (SimpleNewtonRaphson(),
166166
SimpleHalley(), SimpleBroyden(), SimpleKlement(), SimpleLimitedMemoryBroyden(),
167-
SimpleTrustRegion())
168-
@check_allocs nlsolve(prob, alg) = DiffEqBase.__solve(prob, alg; abstol = 1e-9)
167+
SimpleTrustRegion(), SimpleDFSane())
168+
@check_allocs nlsolve(prob, alg) = SciMLBase.solve(prob, alg; abstol = 1e-9)
169169

170170
nlprob_scalar = NonlinearProblem{false}(quadratic_f, 1.0, 2.0)
171171
nlprob_sa = NonlinearProblem{false}(quadratic_f, @SVector[1.0, 1.0], 2.0)
@@ -175,18 +175,17 @@ end
175175
@test true
176176
catch e
177177
@error e
178-
@test false
178+
# History Vector Allocates
179+
@test false broken=(alg isa SimpleDFSane)
179180
end
180181

181182
# ForwardDiff allocates for hessian since we don't propagate the chunksize
182-
# SimpleLimitedMemoryBroyden needs to do views on the low rank matrices so the sizes
183-
# are dynamic. This can be fixed but no without maintaining the simplicity of the code
184183
try
185184
nlsolve(nlprob_sa, alg)
186185
@test true
187186
catch e
188187
@error e
189-
@test false broken=(alg isa SimpleHalley || alg isa SimpleLimitedMemoryBroyden)
188+
@test false broken=(alg isa SimpleHalley)
190189
end
191190
end
192191

0 commit comments

Comments
 (0)