Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Commit b7424c8

Browse files
committed
Add implementation for NLLS Forward Mode
1 parent 709012b commit b7424c8

File tree

8 files changed

+148
-30
lines changed

8 files changed

+148
-30
lines changed

Project.toml

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
88
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
99
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
1010
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
11+
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
1112
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
1213
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
1314
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
@@ -32,26 +33,41 @@ SimpleNonlinearSolveZygoteExt = "Zygote"
3233

3334
[compat]
3435
ADTypes = "0.2.6"
36+
AllocCheck = "0.1.1"
3537
ArrayInterface = "7.7"
38+
Aqua = "0.8"
39+
CUDA = "5.2"
3640
ChainRulesCore = "1.21"
3741
ConcreteStructs = "0.2.3"
3842
DiffEqBase = "6.146"
43+
DiffResults = "1.1"
3944
FastClosures = "0.3"
4045
FiniteDiff = "2.22"
4146
ForwardDiff = "0.10.36"
4247
LinearAlgebra = "1.10"
48+
LinearSolve = "2.25"
4349
MaybeInplace = "0.1.1"
50+
NonlinearProblemLibrary = "0.1.2"
51+
Pkg = "1.10"
52+
PolyesterForwardDiff = "0.1.1"
4453
PrecompileTools = "1.2"
54+
Random = "1.10"
55+
ReTestItems = "1.23"
4556
Reexport = "1.2"
4657
SciMLBase = "2.23"
58+
SciMLSensitivity = "7.56"
4759
StaticArrays = "1.9"
4860
StaticArraysCore = "1.4.2"
61+
Test = "1.10"
62+
Zygote = "0.6.69"
4963
julia = "1.10"
5064

5165
[extras]
5266
AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a"
67+
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
5368
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
5469
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
70+
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
5571
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
5672
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
5773
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
@@ -67,4 +83,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
6783
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
6884

6985
[targets]
70-
test = ["AllocCheck", "DiffEqBase", "ForwardDiff", "LinearAlgebra", "LinearSolve", "NonlinearProblemLibrary", "Pkg", "Random", "ReTestItems", "SciMLSensitivity", "StaticArrays", "Zygote", "CUDA", "PolyesterForwardDiff", "Reexport", "Test"]
86+
test = ["Aqua", "AllocCheck", "DiffEqBase", "ForwardDiff", "LinearAlgebra", "LinearSolve", "NonlinearProblemLibrary", "Pkg", "Random", "ReTestItems", "SciMLSensitivity", "StaticArrays", "Zygote", "CUDA", "PolyesterForwardDiff", "Reexport", "Test", "FiniteDiff"]

ext/SimpleNonlinearSolveZygoteExt.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
11
module SimpleNonlinearSolveZygoteExt
22

3-
import SimpleNonlinearSolve
3+
import SimpleNonlinearSolve, Zygote
44

55
SimpleNonlinearSolve.__is_extension_loaded(::Val{:Zygote}) = true
66

7+
function SimpleNonlinearSolve.__zygote_compute_nlls_vjp(f::F, u, p) where {F}
8+
y, pb = Zygote.pullback(Base.Fix2(f, p), u)
9+
return 2 .* only(pb(y))
10+
end
11+
712
end

src/SimpleNonlinearSolve.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import PrecompileTools: @compile_workload, @setup_workload, @recompile_invalidat
1111
AbstractSafeBestNonlinearTerminationMode,
1212
NonlinearSafeTerminationReturnCode, get_termination_mode,
1313
NONLINEARSOLVE_DEFAULT_NORM
14+
import DiffResults
1415
import ForwardDiff: Dual
1516
import MaybeInplace: @bb, setindex_trait, CanSetindex, CannotSetindex
1617
import SciMLBase: AbstractNonlinearAlgorithm, build_solution, isinplace, _unwrap_val

src/ad.jl

Lines changed: 84 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,17 @@ function SciMLBase.solve(
55
sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...)
66
dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p)
77
return SciMLBase.build_solution(
8-
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats,
9-
sol.original)
8+
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
9+
end
10+
11+
function SciMLBase.solve(
12+
prob::NonlinearLeastSquaresProblem{<:AbstractArray,
13+
iip, <:Union{<:AbstractArray{<:Dual{T, V, P}}}},
14+
alg::AbstractSimpleNonlinearSolveAlgorithm, args...; kwargs...) where {T, V, P, iip}
15+
sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...)
16+
dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p)
17+
return SciMLBase.build_solution(
18+
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
1019
end
1120

1221
for algType in (Bisection, Brent, Alefeld, Falsi, ITP, Ridder)
@@ -56,6 +65,79 @@ function __nlsolve_ad(
5665
return sol, partials
5766
end
5867

68+
function __nlsolve_ad(prob::NonlinearLeastSquaresProblem, alg, args...; kwargs...)
69+
p = value(prob.p)
70+
u0 = value(prob.u0)
71+
newprob = NonlinearLeastSquaresProblem(prob.f, u0, p; prob.kwargs...)
72+
73+
sol = solve(newprob, alg, args...; kwargs...)
74+
75+
uu = sol.u
76+
77+
if !SciMLBase.has_jac(prob.f)
78+
if isinplace(prob)
79+
_F = @closure (du, u, p) -> begin
80+
resid = similar(du, length(sol.resid))
81+
res = DiffResults.DiffResult(
82+
resid, similar(du, length(sol.resid), length(u)))
83+
_f = @closure (du, u) -> prob.f(du, u, p)
84+
ForwardDiff.jacobian!(res, _f, resid, u)
85+
mul!(reshape(du, 1, :), vec(DiffResults.value(res))',
86+
DiffResults.jacobian(res), 2, false)
87+
return nothing
88+
end
89+
else
90+
# For small problems, nesting ForwardDiff is actually quite fast
91+
if __is_extension_loaded(Val(:Zygote)) && (length(uu) + length(sol.resid) ≥ 50)
92+
_F = @closure (u, p) -> __zygote_compute_nlls_vjp(prob.f, u, p)
93+
else
94+
_F = @closure (u, p) -> begin
95+
T = promote_type(eltype(u), eltype(p))
96+
res = DiffResults.DiffResult(
97+
similar(u, T, size(sol.resid)), similar(
98+
u, T, length(sol.resid), length(u)))
99+
ForwardDiff.jacobian!(res, Base.Fix2(prob.f, p), u)
100+
return reshape(
101+
2 .* vec(DiffResults.value(res))' * DiffResults.jacobian(res),
102+
size(u))
103+
end
104+
end
105+
end
106+
else
107+
if isinplace(prob)
108+
_F = @closure (du, u, p) -> begin
109+
J = similar(du, length(sol.resid), length(u))
110+
prob.jac(J, u, p)
111+
resid = similar(du, length(sol.resid))
112+
prob.f(resid, u, p)
113+
mul!(reshape(du, 1, :), vec(resid)', J, 2, false)
114+
return nothing
115+
end
116+
else
117+
_F = @closure (u, p) -> begin
118+
return reshape(2 .* vec(prob.f(u, p))' * prob.jac(u, p), size(u))
119+
end
120+
end
121+
end
122+
123+
f_p = __nlsolve_∂f_∂p(prob, _F, uu, p)
124+
f_x = __nlsolve_∂f_∂u(prob, _F, uu, p)
125+
126+
z_arr = -f_x \ f_p
127+
128+
pp = prob.p
129+
sumfun = ((z, p),) -> map(zᵢ -> zᵢ * ForwardDiff.partials(p), z)
130+
if uu isa Number
131+
partials = sum(sumfun, zip(z_arr, pp))
132+
elseif p isa Number
133+
partials = sumfun((z_arr, pp))
134+
else
135+
partials = sum(sumfun, zip(eachcol(z_arr), pp))
136+
end
137+
138+
return sol, partials
139+
end
140+
59141
@inline function __nlsolve_∂f_∂p(prob, f::F, u, p) where {F}
60142
if isinplace(prob)
61143
__f = p -> begin

src/utils.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -388,3 +388,6 @@ function __get_tolerance(x::Union{SArray, Number}, ::Nothing, ::Type{T}) where {
388388
η = real(oneunit(T)) * (eps(real(one(T))))^(real(T)(0.8))
389389
return T(η)
390390
end
391+
392+
# Extension
393+
function __zygote_compute_nlls_vjp end

test/core/aqua_tests.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
@testitem "Aqua" begin
2+
using Aqua
3+
4+
Aqua.test_all(SimpleNonlinearSolve; piracies = false, ambiguities = false)
5+
Aqua.test_piracies(SimpleNonlinearSolve;
6+
treat_as_own = [
7+
NonlinearProblem, NonlinearLeastSquaresProblem, IntervalNonlinearProblem])
8+
Aqua.test_ambiguities(SimpleNonlinearSolve; recursive = false)
9+
end

test/core/forward_ad_tests.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
@testsetup module ForwardADTesting
1+
@testsetup module ForwardADRootfindingTesting
22
using Reexport
33
@reexport using ForwardDiff, SimpleNonlinearSolve, StaticArrays, LinearAlgebra
44
import SimpleNonlinearSolve: AbstractSimpleNonlinearSolveAlgorithm
@@ -40,7 +40,7 @@ __compatible(::SimpleHalley, ::Val{:iip}) = false
4040
export test_f, test_f!, jacobian_f, solve_with, __compatible
4141
end
4242

43-
@testitem "ForwardDiff.jl Integration" setup=[ForwardADTesting] begin
43+
@testitem "ForwardDiff.jl Integration: Rootfinding" setup=[ForwardADRootfindingTesting] begin
4444
@testset "$(nameof(typeof(alg)))" for alg in (SimpleNewtonRaphson(),
4545
SimpleTrustRegion(), SimpleTrustRegion(; nlsolve_update_rule = Val(true)),
4646
SimpleHalley(), SimpleBroyden(), SimpleKlement(), SimpleDFSane())

test/core/rootfind_tests.jl

Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -134,31 +134,33 @@ end
134134
end
135135

136136
@testitem "Allocation Checks" setup=[RootfindingTesting] begin
137-
@testset "$(nameof(typeof(alg)))" for alg in (SimpleNewtonRaphson(),
138-
SimpleHalley(), SimpleBroyden(), SimpleKlement(), SimpleLimitedMemoryBroyden(),
139-
SimpleTrustRegion(), SimpleTrustRegion(; nlsolve_update_rule = Val(true)),
140-
SimpleDFSane(), SimpleBroyden(; linesearch = Val(true)),
141-
SimpleLimitedMemoryBroyden(; linesearch = Val(true)))
142-
@check_allocs nlsolve(prob, alg) = SciMLBase.solve(prob, alg; abstol = 1e-9)
143-
144-
nlprob_scalar = NonlinearProblem{false}(quadratic_f, 1.0, 2.0)
145-
nlprob_sa = NonlinearProblem{false}(quadratic_f, @SVector[1.0, 1.0], 2.0)
146-
147-
try
148-
nlsolve(nlprob_scalar, alg)
149-
@test true
150-
catch e
151-
@error e
152-
@test false
153-
end
137+
if Sys.islinux() # Very slow on other OS
138+
@testset "$(nameof(typeof(alg)))" for alg in (SimpleNewtonRaphson(),
139+
SimpleHalley(), SimpleBroyden(), SimpleKlement(), SimpleLimitedMemoryBroyden(),
140+
SimpleTrustRegion(), SimpleTrustRegion(; nlsolve_update_rule = Val(true)),
141+
SimpleDFSane(), SimpleBroyden(; linesearch = Val(true)),
142+
SimpleLimitedMemoryBroyden(; linesearch = Val(true)))
143+
@check_allocs nlsolve(prob, alg) = SciMLBase.solve(prob, alg; abstol = 1e-9)
144+
145+
nlprob_scalar = NonlinearProblem{false}(quadratic_f, 1.0, 2.0)
146+
nlprob_sa = NonlinearProblem{false}(quadratic_f, @SVector[1.0, 1.0], 2.0)
147+
148+
try
149+
nlsolve(nlprob_scalar, alg)
150+
@test true
151+
catch e
152+
@error e
153+
@test false
154+
end
154155

155-
# ForwardDiff allocates for hessian since we don't propagate the chunksize
156-
try
157-
nlsolve(nlprob_sa, alg)
158-
@test true
159-
catch e
160-
@error e
161-
@test false broken=(alg isa SimpleHalley)
156+
# ForwardDiff allocates for hessian since we don't propagate the chunksize
157+
try
158+
nlsolve(nlprob_sa, alg)
159+
@test true
160+
catch e
161+
@error e
162+
@test false broken=(alg isa SimpleHalley)
163+
end
162164
end
163165
end
164166
end

0 commit comments

Comments
 (0)