Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Commit ba3b5a4

Browse files
Merge pull request #123 from avik-pal/ap/fix_adjoint
Patch Adjoint Sensitivity for Simple Nonlinear Solve Algorithms
2 parents f8408d4 + 17220bb commit ba3b5a4

File tree

6 files changed

+58
-9
lines changed

6 files changed

+58
-9
lines changed

Project.toml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SimpleNonlinearSolve"
22
uuid = "727e6d20-b764-4bd8-a329-72de5adea6c7"
33
authors = ["SciML"]
4-
version = "1.3.1"
4+
version = "1.3.2"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -19,16 +19,19 @@ SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
1919
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
2020

2121
[weakdeps]
22+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
2223
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
2324
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2425

2526
[extensions]
27+
SimpleNonlinearSolveChainRulesCoreExt = "ChainRulesCore"
2628
SimpleNonlinearSolvePolyesterForwardDiffExt = "PolyesterForwardDiff"
2729
SimpleNonlinearSolveStaticArraysExt = "StaticArrays"
2830

2931
[compat]
3032
ADTypes = "0.2.6"
3133
ArrayInterface = "7"
34+
ChainRulesCore = "1"
3235
ConcreteStructs = "0.2"
3336
DiffEqBase = "6.126"
3437
FastClosures = "0.3"
@@ -39,6 +42,6 @@ MaybeInplace = "0.1"
3942
PrecompileTools = "1"
4043
Reexport = "1"
4144
SciMLBase = "2.7"
42-
StaticArraysCore = "1.4"
4345
StaticArrays = "1"
46+
StaticArraysCore = "1.4"
4447
julia = "1.9"
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
module SimpleNonlinearSolveChainRulesCoreExt
2+
3+
using ChainRulesCore, DiffEqBase, SciMLBase, SimpleNonlinearSolve
4+
5+
# The expectation here is that no-one is using this directly inside a GPU kernel. We can
6+
# eventually lift this requirement using a custom adjoint
7+
function ChainRulesCore.rrule(::typeof(SimpleNonlinearSolve.__internal_solve_up),
8+
prob::NonlinearProblem,
9+
sensealg::Union{Nothing, DiffEqBase.AbstractSensitivityAlgorithm}, u0, u0_changed,
10+
p, p_changed, alg, args...; kwargs...)
11+
out, ∇internal = DiffEqBase._solve_adjoint(prob, sensealg, u0, p,
12+
SciMLBase.ChainRulesOriginator(), alg, args...; kwargs...)
13+
function ∇__internal_solve_up(Δ)
14+
∂f, ∂prob, ∂sensealg, ∂u0, ∂p, ∂originator, ∂args... = ∇internal(Δ)
15+
return (∂f, ∂prob, ∂sensealg, ∂u0, NoTangent(), ∂p, NoTangent(), ∂originator,
16+
∂args...)
17+
end
18+
return out, ∇__internal_solve_up
19+
end
20+
21+
end

src/SimpleNonlinearSolve.jl

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -50,18 +50,26 @@ include("ad.jl")
5050
## Default algorithm
5151

5252
# Set the default bracketing method to ITP
53-
function SciMLBase.solve(prob::IntervalNonlinearProblem; kwargs...)
54-
return solve(prob, ITP(); kwargs...)
55-
end
56-
57-
function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Nothing,
58-
args...; kwargs...)
53+
SciMLBase.solve(prob::IntervalNonlinearProblem; kwargs...) = solve(prob, ITP(); kwargs...)
54+
function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Nothing, args...; kwargs...)
5955
return solve(prob, ITP(), args...; kwargs...)
6056
end
6157

6258
# By Pass the highlevel checks for NonlinearProblem for Simple Algorithms
6359
function SciMLBase.solve(prob::NonlinearProblem, alg::AbstractSimpleNonlinearSolveAlgorithm,
64-
args...; kwargs...)
60+
args...; sensealg = nothing, u0 = nothing, p = nothing, kwargs...)
61+
if sensealg === nothing && haskey(prob.kwargs, :sensealg)
62+
sensealg = prob.kwargs[:sensealg]
63+
end
64+
new_u0 = u0 !== nothing ? u0 : prob.u0
65+
new_p = p !== nothing ? p : prob.p
66+
return __internal_solve_up(prob, sensealg, new_u0, u0 === nothing, new_p, p === nothing,
67+
alg, args...; kwargs...)
68+
end
69+
70+
function __internal_solve_up(_prob::NonlinearProblem, sensealg, u0, u0_changed, p,
71+
p_changed, alg::AbstractSimpleNonlinearSolveAlgorithm, args...; kwargs...)
72+
prob = u0_changed || p_changed ? remake(_prob; u0, p) : _prob
6573
return SciMLBase.__solve(prob, alg, args...; kwargs...)
6674
end
6775

test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@ Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
99
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
1010
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1111
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
12+
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
1213
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1314
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
15+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1416

1517
[compat]
1618
NonlinearProblemLibrary = "0.1.2"

test/adjoint.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
using ForwardDiff, SciMLSensitivity, SimpleNonlinearSolve, Test, Zygote
2+
3+
@testset "Simple Adjoint Test" begin
4+
ff(u, p) = u .^ 2 .- p
5+
6+
function solve_nlprob(p)
7+
prob = NonlinearProblem{false}(ff, [1.0, 2.0], p)
8+
return sum(abs2, solve(prob, SimpleNewtonRaphson()).u)
9+
end
10+
11+
p = [3.0, 2.0]
12+
13+
@test only(Zygote.gradient(solve_nlprob, p)) ForwardDiff.gradient(solve_nlprob, p)
14+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ end
1515
@time @safetestset "Matrix Resizing Tests" include("matrix_resizing_tests.jl")
1616
@time @safetestset "Least Squares Tests" include("least_squares.jl")
1717
@time @safetestset "23 Test Problems" include("23_test_problems.jl")
18+
@time @safetestset "Simple Adjoint Tests" include("adjoint.jl")
1819
end
1920

2021
if GROUP == "CUDA"

0 commit comments

Comments
 (0)