Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Commit 9984298

Browse files
committed
Resolve ambiguity
1 parent 8407682 commit 9984298

File tree

3 files changed

+88
-83
lines changed

3 files changed

+88
-83
lines changed

ext/SimpleNonlinearSolveReverseDiffExt.jl

Lines changed: 48 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -7,59 +7,61 @@ using SciMLBase: ReverseDiffOriginator, NonlinearProblem, NonlinearLeastSquaresP
77
using SimpleNonlinearSolve: SimpleNonlinearSolve
88
import SimpleNonlinearSolve: __internal_solve_up
99

10-
function __internal_solve_up(
11-
prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, sensealg,
12-
u0::TrackedArray, u0_changed, p::TrackedArray, p_changed, alg, args...; kwargs...)
13-
return ReverseDiff.track(__internal_solve_up, prob, sensealg, u0,
14-
u0_changed, p, p_changed, alg, args...; kwargs...)
15-
end
10+
for pType in (NonlinearProblem, NonlinearLeastSquaresProblem)
11+
@eval begin
12+
function __internal_solve_up(prob::$(pType), sensealg, u0::TrackedArray, u0_changed,
13+
p::TrackedArray, p_changed, alg, args...; kwargs...)
14+
return ReverseDiff.track(__internal_solve_up, prob, sensealg, u0,
15+
u0_changed, p, p_changed, alg, args...; kwargs...)
16+
end
1617

17-
function __internal_solve_up(
18-
prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, sensealg,
19-
u0, u0_changed, p::TrackedArray, p_changed, alg, args...; kwargs...)
20-
return ReverseDiff.track(__internal_solve_up, prob, sensealg, u0,
21-
u0_changed, p, p_changed, alg, args...; kwargs...)
22-
end
18+
function __internal_solve_up(prob::$(pType), sensealg, u0, u0_changed,
19+
p::TrackedArray, p_changed, alg, args...; kwargs...)
20+
return ReverseDiff.track(__internal_solve_up, prob, sensealg, u0,
21+
u0_changed, p, p_changed, alg, args...; kwargs...)
22+
end
2323

24-
function __internal_solve_up(
25-
prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, sensealg,
26-
u0::TrackedArray, u0_changed, p, p_changed, alg, args...; kwargs...)
27-
return ReverseDiff.track(__internal_solve_up, prob, sensealg, u0,
28-
u0_changed, p, p_changed, alg, args...; kwargs...)
29-
end
24+
function __internal_solve_up(prob::$(pType), sensealg, u0::TrackedArray,
25+
u0_changed, p, p_changed, alg, args...; kwargs...)
26+
return ReverseDiff.track(__internal_solve_up, prob, sensealg, u0,
27+
u0_changed, p, p_changed, alg, args...; kwargs...)
28+
end
3029

31-
function __internal_solve_up(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem},
32-
sensealg, u0::AbstractArray{<:TrackedReal}, u0_changed,
33-
p::AbstractArray{<:TrackedReal}, p_changed, alg, args...; kwargs...)
34-
return __internal_solve_up(prob, sensealg, ArrayInterface.aos_to_soa(u0), true,
35-
ArrayInterface.aos_to_soa(p), true, alg, args...; kwargs...)
36-
end
30+
function __internal_solve_up(
31+
prob::$(pType), sensealg, u0::AbstractArray{<:TrackedReal}, u0_changed,
32+
p::AbstractArray{<:TrackedReal}, p_changed, alg, args...; kwargs...)
33+
return __internal_solve_up(prob, sensealg, ArrayInterface.aos_to_soa(u0), true,
34+
ArrayInterface.aos_to_soa(p), true, alg, args...; kwargs...)
35+
end
3736

38-
function __internal_solve_up(
39-
prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, sensealg, u0,
40-
u0_changed, p::AbstractArray{<:TrackedReal}, p_changed, alg, args...; kwargs...)
41-
return __internal_solve_up(prob, sensealg, u0, true, ArrayInterface.aos_to_soa(p),
42-
true, alg, args...; kwargs...)
43-
end
37+
function __internal_solve_up(prob::$(pType), sensealg, u0, u0_changed,
38+
p::AbstractArray{<:TrackedReal}, p_changed, alg, args...; kwargs...)
39+
return __internal_solve_up(
40+
prob, sensealg, u0, true, ArrayInterface.aos_to_soa(p),
41+
true, alg, args...; kwargs...)
42+
end
4443

45-
function __internal_solve_up(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem},
46-
sensealg, u0::AbstractArray{<:TrackedReal},
47-
u0_changed, p, p_changed, alg, args...; kwargs...)
48-
return __internal_solve_up(prob, sensealg, u0, true, ArrayInterface.aos_to_soa(p),
49-
true, alg, args...; kwargs...)
50-
end
44+
function __internal_solve_up(
45+
prob::$(pType), sensealg, u0::AbstractArray{<:TrackedReal},
46+
u0_changed, p, p_changed, alg, args...; kwargs...)
47+
return __internal_solve_up(
48+
prob, sensealg, u0, true, ArrayInterface.aos_to_soa(p),
49+
true, alg, args...; kwargs...)
50+
end
5151

52-
ReverseDiff.@grad function __internal_solve_up(
53-
prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem},
54-
sensealg, u0, u0_changed, p, p_changed, alg, args...; kwargs...)
55-
out, ∇internal = DiffEqBase._solve_adjoint(
56-
prob, sensealg, ReverseDiff.value(u0), ReverseDiff.value(p),
57-
ReverseDiffOriginator(), alg, args...; kwargs...)
58-
function ∇__internal_solve_up(_args...)
59-
∂prob, ∂sensealg, ∂u0, ∂p, ∂originator, ∂args... = ∇internal(_args...)
60-
return (∂prob, ∂sensealg, ∂u0, nothing, ∂p, nothing, nothing, ∂args...)
52+
ReverseDiff.@grad function __internal_solve_up(
53+
prob::$(pType), sensealg, u0, u0_changed,
54+
p, p_changed, alg, args...; kwargs...)
55+
out, ∇internal = DiffEqBase._solve_adjoint(
56+
prob, sensealg, ReverseDiff.value(u0), ReverseDiff.value(p),
57+
ReverseDiffOriginator(), alg, args...; kwargs...)
58+
function ∇__internal_solve_up(_args...)
59+
∂prob, ∂sensealg, ∂u0, ∂p, ∂originator, ∂args... = ∇internal(_args...)
60+
return (∂prob, ∂sensealg, ∂u0, nothing, ∂p, nothing, nothing, ∂args...)
61+
end
62+
return Array(out), ∇__internal_solve_up
63+
end
6164
end
62-
return Array(out), ∇__internal_solve_up
6365
end
6466

6567
end

ext/SimpleNonlinearSolveTrackerExt.jl

Lines changed: 39 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,49 @@
11
module SimpleNonlinearSolveTrackerExt
22

33
using DiffEqBase: DiffEqBase
4-
using SciMLBase: TrackerOriginator, NonlinearProblem, NonlinearLeastSquaresProblem
4+
using SciMLBase: TrackerOriginator, NonlinearProblem, NonlinearLeastSquaresProblem, remake
55
using SimpleNonlinearSolve: SimpleNonlinearSolve
66
using Tracker: Tracker, TrackedArray
77

8-
function SimpleNonlinearSolve.__internal_solve_up(
9-
prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, sensealg,
10-
u0::TrackedArray, u0_changed, p, p_changed, alg, args...; kwargs...)
11-
return Tracker.track(SimpleNonlinearSolve.__internal_solve_up, prob, sensealg,
12-
u0, u0_changed, p, p_changed, alg, args...; kwargs...)
13-
end
14-
15-
function SimpleNonlinearSolve.__internal_solve_up(
16-
prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, sensealg,
17-
u0::TrackedArray, u0_changed, p::TrackedArray, p_changed, alg, args...; kwargs...)
18-
return Tracker.track(SimpleNonlinearSolve.__internal_solve_up, prob, sensealg,
19-
u0, u0_changed, p, p_changed, alg, args...; kwargs...)
20-
end
21-
22-
function SimpleNonlinearSolve.__internal_solve_up(
23-
prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, sensealg,
24-
u0, u0_changed, p::TrackedArray, p_changed, alg, args...; kwargs...)
25-
return Tracker.track(SimpleNonlinearSolve.__internal_solve_up, prob, sensealg,
26-
u0, u0_changed, p, p_changed, alg, args...; kwargs...)
27-
end
28-
29-
Tracker.@grad function SimpleNonlinearSolve.__internal_solve_up(
30-
_prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem},
31-
sensealg, u0_, u0_changed, p_, p_changed, alg, args...; kwargs...)
32-
u0, p = Tracker.data(u0_), Tracker.data(p_)
33-
prob = remake(_prob; u0, p)
34-
out, ∇internal = DiffEqBase._solve_adjoint(
35-
prob, sensealg, u0, p, TrackerOriginator(), alg, args...; kwargs...)
36-
37-
function ∇__internal_solve_up(Δ)
38-
∂prob, ∂sensealg, ∂u0, ∂p, ∂originator, ∂args... = ∇internal(Δ)
39-
return (∂prob, ∂sensealg, ∂u0, nothing, ∂p, nothing, nothing, ∂args...)
8+
for pType in (NonlinearProblem, NonlinearLeastSquaresProblem)
9+
@eval begin
10+
function SimpleNonlinearSolve.__internal_solve_up(
11+
prob::$(pType), sensealg, u0::TrackedArray,
12+
u0_changed, p, p_changed, alg, args...; kwargs...)
13+
return Tracker.track(SimpleNonlinearSolve.__internal_solve_up, prob, sensealg,
14+
u0, u0_changed, p, p_changed, alg, args...; kwargs...)
15+
end
16+
17+
function SimpleNonlinearSolve.__internal_solve_up(
18+
prob::$(pType), sensealg, u0::TrackedArray, u0_changed,
19+
p::TrackedArray, p_changed, alg, args...; kwargs...)
20+
return Tracker.track(SimpleNonlinearSolve.__internal_solve_up, prob, sensealg,
21+
u0, u0_changed, p, p_changed, alg, args...; kwargs...)
22+
end
23+
24+
function SimpleNonlinearSolve.__internal_solve_up(
25+
prob::$(pType), sensealg, u0, u0_changed,
26+
p::TrackedArray, p_changed, alg, args...; kwargs...)
27+
return Tracker.track(SimpleNonlinearSolve.__internal_solve_up, prob, sensealg,
28+
u0, u0_changed, p, p_changed, alg, args...; kwargs...)
29+
end
30+
31+
Tracker.@grad function SimpleNonlinearSolve.__internal_solve_up(
32+
_prob::$(pType), sensealg, u0_, u0_changed,
33+
p_, p_changed, alg, args...; kwargs...)
34+
u0, p = Tracker.data(u0_), Tracker.data(p_)
35+
prob = remake(_prob; u0, p)
36+
out, ∇internal = DiffEqBase._solve_adjoint(
37+
prob, sensealg, u0, p, TrackerOriginator(), alg, args...; kwargs...)
38+
39+
function ∇__internal_solve_up(Δ)
40+
∂prob, ∂sensealg, ∂u0, ∂p, ∂originator, ∂args... = ∇internal(Δ)
41+
return (∂prob, ∂sensealg, ∂u0, nothing, ∂p, nothing, nothing, ∂args...)
42+
end
43+
44+
return out, ∇__internal_solve_up
45+
end
4046
end
41-
42-
return out, ∇__internal_solve_up
4347
end
4448

4549
end

test/core/exotic_type_tests.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,7 @@ end
1616
using SimpleNonlinearSolve, LinearAlgebra
1717

1818
for alg in [SimpleNewtonRaphson(), SimpleBroyden(), SimpleKlement(), SimpleDFSane(),
19-
SimpleTrustRegion(), SimpleLimitedMemoryBroyden(; threshold = 2),
20-
SimpleHalley()]
19+
SimpleTrustRegion(), SimpleLimitedMemoryBroyden(; threshold = 2), SimpleHalley()]
2120
sol = solve(prob_oop_bf, alg)
2221
@test norm(sol.resid, Inf) < 1e-6
2322
@test SciMLBase.successful_retcode(sol.retcode)

0 commit comments

Comments
 (0)