Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Commit 5e0f188

Browse files
authored
Merge pull request #133 from SciML/ap/reversediff
Automatic Differentiation
2 parents fd7d216 + 939c93b commit 5e0f188

8 files changed

+132
-25
lines changed

Project.toml

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SimpleNonlinearSolve"
22
uuid = "727e6d20-b764-4bd8-a329-72de5adea6c7"
33
authors = ["SciML"]
4-
version = "1.5.0"
4+
version = "1.6.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -22,20 +22,24 @@ StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
2222
[weakdeps]
2323
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
2424
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
25+
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
2526
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
27+
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
2628
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2729

2830
[extensions]
2931
SimpleNonlinearSolveChainRulesCoreExt = "ChainRulesCore"
3032
SimpleNonlinearSolvePolyesterForwardDiffExt = "PolyesterForwardDiff"
33+
SimpleNonlinearSolveReverseDiffExt = "ReverseDiff"
3134
SimpleNonlinearSolveStaticArraysExt = "StaticArrays"
35+
SimpleNonlinearSolveTrackerExt = "Tracker"
3236
SimpleNonlinearSolveZygoteExt = "Zygote"
3337

3438
[compat]
3539
ADTypes = "0.2.6"
3640
AllocCheck = "0.1.1"
37-
ArrayInterface = "7.7"
3841
Aqua = "0.8"
42+
ArrayInterface = "7.7"
3943
CUDA = "5.2"
4044
ChainRulesCore = "1.22"
4145
ConcreteStructs = "0.2.3"
@@ -54,11 +58,13 @@ PrecompileTools = "1.2"
5458
Random = "1.10"
5559
ReTestItems = "1.23"
5660
Reexport = "1.2"
61+
ReverseDiff = "1.15"
5762
SciMLBase = "2.26.3"
5863
SciMLSensitivity = "7.56"
5964
StaticArrays = "1.9"
6065
StaticArraysCore = "1.4.2"
6166
Test = "1.10"
67+
Tracker = "0.2.32"
6268
Zygote = "0.6.69"
6369
julia = "1.10"
6470

@@ -77,10 +83,12 @@ PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
7783
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
7884
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
7985
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
86+
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
8087
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
8188
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
8289
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
90+
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
8391
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
8492

8593
[targets]
86-
test = ["Aqua", "AllocCheck", "DiffEqBase", "ForwardDiff", "LinearAlgebra", "LinearSolve", "NonlinearProblemLibrary", "Pkg", "Random", "ReTestItems", "SciMLSensitivity", "StaticArrays", "Zygote", "CUDA", "PolyesterForwardDiff", "Reexport", "Test", "FiniteDiff"]
94+
test = ["Aqua", "AllocCheck", "DiffEqBase", "ForwardDiff", "LinearAlgebra", "LinearSolve", "NonlinearProblemLibrary", "Pkg", "Random", "ReTestItems", "SciMLSensitivity", "StaticArrays", "Zygote", "CUDA", "PolyesterForwardDiff", "Reexport", "Test", "FiniteDiff", "ReverseDiff", "Tracker"]

ext/SimpleNonlinearSolveChainRulesCoreExt.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,13 @@ using ChainRulesCore, DiffEqBase, SciMLBase, SimpleNonlinearSolve
55
# The expectation here is that no-one is using this directly inside a GPU kernel. We can
66
# eventually lift this requirement using a custom adjoint
77
function ChainRulesCore.rrule(::typeof(SimpleNonlinearSolve.__internal_solve_up),
8-
prob::NonlinearProblem,
9-
sensealg::Union{Nothing, DiffEqBase.AbstractSensitivityAlgorithm}, u0, u0_changed,
10-
p, p_changed, alg, args...; kwargs...)
8+
prob::NonlinearProblem, sensealg, u0, u0_changed, p, p_changed, alg, args...;
9+
kwargs...)
1110
out, ∇internal = DiffEqBase._solve_adjoint(prob, sensealg, u0, p,
1211
SciMLBase.ChainRulesOriginator(), alg, args...; kwargs...)
1312
function ∇__internal_solve_up(Δ)
1413
∂f, ∂prob, ∂sensealg, ∂u0, ∂p, ∂originator, ∂args... = ∇internal(Δ)
15-
return (∂f, ∂prob, ∂sensealg, ∂u0, NoTangent(), ∂p, NoTangent(), ∂originator,
14+
return (∂f, ∂prob, ∂sensealg, ∂u0, NoTangent(), ∂p, NoTangent(), NoTangent(),
1615
∂args...)
1716
end
1817
return out, ∇__internal_solve_up
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
module SimpleNonlinearSolveReverseDiffExt
2+
3+
using ArrayInterface, DiffEqBase, ReverseDiff, SciMLBase, SimpleNonlinearSolve
4+
import ReverseDiff: TrackedArray, TrackedReal
5+
import SimpleNonlinearSolve: __internal_solve_up
6+
7+
function __internal_solve_up(
8+
prob::NonlinearProblem, sensealg, u0::TrackedArray, u0_changed,
9+
p::TrackedArray, p_changed, alg, args...; kwargs...)
10+
return ReverseDiff.track(__internal_solve_up, prob, sensealg, u0,
11+
u0_changed, p, p_changed, alg, args...; kwargs...)
12+
end
13+
14+
function __internal_solve_up(
15+
prob::NonlinearProblem, sensealg, u0, u0_changed,
16+
p::TrackedArray, p_changed, alg, args...; kwargs...)
17+
return ReverseDiff.track(__internal_solve_up, prob, sensealg, u0,
18+
u0_changed, p, p_changed, alg, args...; kwargs...)
19+
end
20+
21+
function __internal_solve_up(
22+
prob::NonlinearProblem, sensealg, u0::TrackedArray, u0_changed,
23+
p, p_changed, alg, args...; kwargs...)
24+
return ReverseDiff.track(__internal_solve_up, prob, sensealg, u0,
25+
u0_changed, p, p_changed, alg, args...; kwargs...)
26+
end
27+
28+
function __internal_solve_up(prob::NonlinearProblem, sensealg,
29+
u0::AbstractArray{<:TrackedReal}, u0_changed, p::AbstractArray{<:TrackedReal},
30+
p_changed, alg, args...; kwargs...)
31+
return __internal_solve_up(
32+
prob, sensealg, ArrayInterface.aos_to_soa(u0), true,
33+
ArrayInterface.aos_to_soa(p), true, alg, args...; kwargs...)
34+
end
35+
36+
function __internal_solve_up(prob::NonlinearProblem, sensealg, u0, u0_changed,
37+
p::AbstractArray{<:TrackedReal}, p_changed, alg, args...; kwargs...)
38+
return __internal_solve_up(
39+
prob, sensealg, u0, true, ArrayInterface.aos_to_soa(p), true, alg, args...; kwargs...)
40+
end
41+
42+
function __internal_solve_up(prob::NonlinearProblem, sensealg,
43+
u0::AbstractArray{<:TrackedReal}, u0_changed, p, p_changed, alg, args...; kwargs...)
44+
return __internal_solve_up(
45+
prob, sensealg, u0, true, ArrayInterface.aos_to_soa(p), true, alg, args...; kwargs...)
46+
end
47+
48+
ReverseDiff.@grad function __internal_solve_up(
49+
prob::NonlinearProblem, sensealg, u0, u0_changed, p, p_changed, alg, args...; kwargs...)
50+
out, ∇internal = DiffEqBase._solve_adjoint(
51+
prob, sensealg, ReverseDiff.value(u0), ReverseDiff.value(p),
52+
SciMLBase.ReverseDiffOriginator(), alg, args...; kwargs...)
53+
function ∇__internal_solve_up(_args...)
54+
∂prob, ∂sensealg, ∂u0, ∂p, ∂originator, ∂args... = ∇internal(_args...)
55+
return (∂prob, ∂sensealg, ∂u0, nothing, ∂p, nothing, nothing, ∂args...)
56+
end
57+
return Array(out), ∇__internal_solve_up
58+
end
59+
60+
end

ext/SimpleNonlinearSolveTrackerExt.jl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
module SimpleNonlinearSolveTrackerExt
2+
3+
using DiffEqBase, SciMLBase, SimpleNonlinearSolve, Tracker
4+
5+
function SimpleNonlinearSolve.__internal_solve_up(prob::NonlinearProblem,
6+
sensealg, u0::TrackedArray, u0_changed, p, p_changed, alg, args...; kwargs...)
7+
return Tracker.track(
8+
SimpleNonlinearSolve.__internal_solve_up, prob, sensealg, u0, u0_changed,
9+
p, p_changed, alg, args...; kwargs...)
10+
end
11+
12+
function SimpleNonlinearSolve.__internal_solve_up(
13+
prob::NonlinearProblem, sensealg, u0::TrackedArray, u0_changed,
14+
p::TrackedArray, p_changed, alg, args...; kwargs...)
15+
return Tracker.track(
16+
SimpleNonlinearSolve.__internal_solve_up, prob, sensealg, u0, u0_changed,
17+
p, p_changed, alg, args...; kwargs...)
18+
end
19+
20+
function SimpleNonlinearSolve.__internal_solve_up(prob::NonlinearProblem,
21+
sensealg, u0, u0_changed, p::TrackedArray, p_changed, alg, args...; kwargs...)
22+
return Tracker.track(
23+
SimpleNonlinearSolve.__internal_solve_up, prob, sensealg, u0, u0_changed,
24+
p, p_changed, alg, args...; kwargs...)
25+
end
26+
27+
Tracker.@grad function SimpleNonlinearSolve.__internal_solve_up(_prob::NonlinearProblem,
28+
sensealg, u0_, u0_changed, p_, p_changed, alg, args...; kwargs...)
29+
u0, p = Tracker.data(u0_), Tracker.data(p_)
30+
prob = remake(_prob; u0, p)
31+
out, ∇internal = DiffEqBase._solve_adjoint(prob, sensealg, u0, p,
32+
SciMLBase.TrackerOriginator(), alg, args...; kwargs...)
33+
34+
function ∇__internal_solve_up(Δ)
35+
∂prob, ∂sensealg, ∂u0, ∂p, ∂originator, ∂args... = ∇internal(Δ)
36+
return (∂prob, ∂sensealg, ∂u0, nothing, ∂p, nothing, nothing, ∂args...)
37+
end
38+
39+
return out, ∇__internal_solve_up
40+
end
41+
42+
end

src/SimpleNonlinearSolve.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,8 @@ function SciMLBase.solve(
7171
alg, args...; prob.kwargs..., kwargs...)
7272
end
7373

74-
function __internal_solve_up(_prob::NonlinearProblem, sensealg, u0, u0_changed, p,
75-
p_changed, alg::AbstractSimpleNonlinearSolveAlgorithm, args...; kwargs...)
74+
function __internal_solve_up(_prob::NonlinearProblem, sensealg, u0, u0_changed,
75+
p, p_changed, alg, args...; kwargs...)
7676
prob = u0_changed || p_changed ? remake(_prob; u0, p) : _prob
7777
return SciMLBase.__solve(prob, alg, args...; kwargs...)
7878
end

src/nlsolve/halley.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleHalley, args...;
3333
fx = _get_fx(prob, x)
3434
T = eltype(x)
3535

36-
autodiff = __get_concrete_autodiff(prob, alg.autodiff; polyester = Val(false))
36+
autodiff = __get_concrete_autodiff(prob, alg.autodiff)
3737
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fx, x,
3838
termination_condition)
3939

src/utils.jl

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -365,18 +365,9 @@ end
365365

366366
# Decide which AD backend to use
367367
@inline __get_concrete_autodiff(prob, ad::ADTypes.AbstractADType; kwargs...) = ad
368-
@inline function __get_concrete_autodiff(prob, ::Nothing; polyester::Val{P} = Val(true),
369-
kwargs...) where {P}
370-
if ForwardDiff.can_dual(eltype(prob.u0))
371-
if P && __is_extension_loaded(Val(:PolyesterForwardDiff)) &&
372-
!(prob.u0 isa Number) && ArrayInterface.can_setindex(prob.u0)
373-
return AutoPolyesterForwardDiff()
374-
else
375-
return AutoForwardDiff()
376-
end
377-
else
378-
return AutoFiniteDiff()
379-
end
368+
@inline function __get_concrete_autodiff(prob, ::Nothing; kwargs...)
369+
return ifelse(
370+
ForwardDiff.can_dual(eltype(prob.u0)), AutoForwardDiff(), AutoFiniteDiff())
380371
end
381372

382373
@inline __reshape(x::Number, args...) = x

test/core/adjoint_tests.jl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,21 @@
11
@testitem "Simple Adjoint Test" begin
2-
using ForwardDiff, SciMLSensitivity, Zygote
2+
using ForwardDiff, ReverseDiff, SciMLSensitivity, Tracker, Zygote
33

44
ff(u, p) = u .^ 2 .- p
55

66
function solve_nlprob(p)
77
prob = NonlinearProblem{false}(ff, [1.0, 2.0], p)
8-
return sum(abs2, solve(prob, SimpleNewtonRaphson()).u)
8+
sol = solve(prob, SimpleNewtonRaphson())
9+
res = sol isa AbstractArray ? sol : sol.u
10+
return sum(abs2, res)
911
end
1012

1113
p = [3.0, 2.0]
1214

13-
@test only(Zygote.gradient(solve_nlprob, p)) ForwardDiff.gradient(solve_nlprob, p)
15+
∂p_zygote = only(Zygote.gradient(solve_nlprob, p))
16+
∂p_forwarddiff = ForwardDiff.gradient(solve_nlprob, p)
17+
∂p_tracker = Tracker.data(only(Tracker.gradient(solve_nlprob, p)))
18+
∂p_reversediff = ReverseDiff.gradient(solve_nlprob, p)
19+
20+
@test ∂p_zygote ∂p_forwarddiff ∂p_tracker ∂p_reversediff
1421
end

0 commit comments

Comments
 (0)