Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Commit 7497aa4

Browse files
committed
Make explicit imports for extensions
1 parent 26b0ec5 commit 7497aa4

8 files changed

+66
-49
lines changed

ext/SimpleNonlinearSolveChainRulesCoreExt.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
11
module SimpleNonlinearSolveChainRulesCoreExt
22

3-
using ChainRulesCore, DiffEqBase, SciMLBase, SimpleNonlinearSolve
3+
using ChainRulesCore: ChainRulesCore, NoTangent
4+
using DiffEqBase: DiffEqBase
5+
using SciMLBase: ChainRulesOriginator, NonlinearProblem, NonlinearLeastSquaresProblem
6+
using SimpleNonlinearSolve: SimpleNonlinearSolve
47

58
# The expectation here is that no-one is using this directly inside a GPU kernel. We can
69
# eventually lift this requirement using a custom adjoint
710
function ChainRulesCore.rrule(::typeof(SimpleNonlinearSolve.__internal_solve_up),
8-
prob::NonlinearProblem, sensealg, u0, u0_changed, p, p_changed, alg, args...;
9-
kwargs...)
11+
prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, sensealg, u0,
12+
u0_changed, p, p_changed, alg, args...; kwargs...)
1013
out, ∇internal = DiffEqBase._solve_adjoint(prob, sensealg, u0, p,
11-
SciMLBase.ChainRulesOriginator(), alg, args...; kwargs...)
14+
ChainRulesOriginator(), alg, args...; kwargs...)
1215
function ∇__internal_solve_up(Δ)
1316
∂f, ∂prob, ∂sensealg, ∂u0, ∂p, ∂originator, ∂args... = ∇internal(Δ)
1417
return (∂f, ∂prob, ∂sensealg, ∂u0, NoTangent(), ∂p, NoTangent(), NoTangent(),

ext/SimpleNonlinearSolvePolyesterForwardDiffExt.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module SimpleNonlinearSolvePolyesterForwardDiffExt
22

3-
using SimpleNonlinearSolve, PolyesterForwardDiff
3+
using PolyesterForwardDiff: PolyesterForwardDiff
4+
using SimpleNonlinearSolve: SimpleNonlinearSolve
45

56
@inline SimpleNonlinearSolve.__is_extension_loaded(::Val{:PolyesterForwardDiff}) = true
67

Lines changed: 35 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,60 +1,68 @@
11
module SimpleNonlinearSolveReverseDiffExt
22

3-
using ArrayInterface, DiffEqBase, ReverseDiff, SciMLBase, SimpleNonlinearSolve
4-
import ReverseDiff: TrackedArray, TrackedReal
5-
import SimpleNonlinearSolve: __internal_solve_up
3+
using ArrayInterface: ArrayInterface
4+
using DiffEqBase: DiffEqBase
5+
using ReverseDiff: ReverseDiff, TrackedArray, TrackedReal
6+
using SciMLBase: ReverseDiffOriginator, NonlinearProblem, NonlinearLeastSquaresProblem
7+
using SimpleNonlinearSolve: SimpleNonlinearSolve
68

7-
function __internal_solve_up(
8-
prob::NonlinearProblem, sensealg, u0::TrackedArray, u0_changed,
9-
p::TrackedArray, p_changed, alg, args...; kwargs...)
10-
return ReverseDiff.track(__internal_solve_up, prob, sensealg, u0,
9+
function SimpleNonlinearSolve.__internal_solve_up(
10+
prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, sensealg,
11+
u0::TrackedArray, u0_changed, p::TrackedArray, p_changed, alg, args...; kwargs...)
12+
return ReverseDiff.track(SimpleNonlinearSolve.__internal_solve_up, prob, sensealg, u0,
1113
u0_changed, p, p_changed, alg, args...; kwargs...)
1214
end
1315

14-
function __internal_solve_up(
15-
prob::NonlinearProblem, sensealg, u0, u0_changed,
16+
function SimpleNonlinearSolve.__internal_solve_up(
17+
prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, sensealg, u0, u0_changed,
1618
p::TrackedArray, p_changed, alg, args...; kwargs...)
17-
return ReverseDiff.track(__internal_solve_up, prob, sensealg, u0,
19+
return ReverseDiff.track(SimpleNonlinearSolve.__internal_solve_up, prob, sensealg, u0,
1820
u0_changed, p, p_changed, alg, args...; kwargs...)
1921
end
2022

21-
function __internal_solve_up(
22-
prob::NonlinearProblem, sensealg, u0::TrackedArray, u0_changed,
23-
p, p_changed, alg, args...; kwargs...)
24-
return ReverseDiff.track(__internal_solve_up, prob, sensealg, u0,
23+
function SimpleNonlinearSolve.__internal_solve_up(
24+
prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, sensealg,
25+
u0::TrackedArray, u0_changed, p, p_changed, alg, args...; kwargs...)
26+
return ReverseDiff.track(SimpleNonlinearSolve.__internal_solve_up, prob, sensealg, u0,
2527
u0_changed, p, p_changed, alg, args...; kwargs...)
2628
end
2729

28-
function __internal_solve_up(prob::NonlinearProblem, sensealg,
30+
function SimpleNonlinearSolve.__internal_solve_up(
31+
prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, sensealg,
2932
u0::AbstractArray{<:TrackedReal}, u0_changed, p::AbstractArray{<:TrackedReal},
3033
p_changed, alg, args...; kwargs...)
31-
return __internal_solve_up(
34+
return SimpleNonlinearSolve.__internal_solve_up(
3235
prob, sensealg, ArrayInterface.aos_to_soa(u0), true,
3336
ArrayInterface.aos_to_soa(p), true, alg, args...; kwargs...)
3437
end
3538

36-
function __internal_solve_up(prob::NonlinearProblem, sensealg, u0, u0_changed,
39+
function SimpleNonlinearSolve.__internal_solve_up(
40+
prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, sensealg, u0, u0_changed,
3741
p::AbstractArray{<:TrackedReal}, p_changed, alg, args...; kwargs...)
38-
return __internal_solve_up(
39-
prob, sensealg, u0, true, ArrayInterface.aos_to_soa(p), true, alg, args...; kwargs...)
42+
return SimpleNonlinearSolve.__internal_solve_up(
43+
prob, sensealg, u0, true, ArrayInterface.aos_to_soa(p), true, alg, args...;
44+
kwargs...)
4045
end
4146

42-
function __internal_solve_up(prob::NonlinearProblem, sensealg,
47+
function SimpleNonlinearSolve.__internal_solve_up(
48+
prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, sensealg,
4349
u0::AbstractArray{<:TrackedReal}, u0_changed, p, p_changed, alg, args...; kwargs...)
44-
return __internal_solve_up(
45-
prob, sensealg, u0, true, ArrayInterface.aos_to_soa(p), true, alg, args...; kwargs...)
50+
return SimpleNonlinearSolve.__internal_solve_up(
51+
prob, sensealg, u0, true, ArrayInterface.aos_to_soa(p), true, alg, args...;
52+
kwargs...)
4653
end
4754

48-
ReverseDiff.@grad function __internal_solve_up(
49-
prob::NonlinearProblem, sensealg, u0, u0_changed, p, p_changed, alg, args...; kwargs...)
55+
ReverseDiff.@grad function SimpleNonlinearSolve.__internal_solve_up(
56+
prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, sensealg, u0,
57+
u0_changed, p, p_changed, alg, args...; kwargs...)
5058
out, ∇internal = DiffEqBase._solve_adjoint(
5159
prob, sensealg, ReverseDiff.value(u0), ReverseDiff.value(p),
52-
SciMLBase.ReverseDiffOriginator(), alg, args...; kwargs...)
53-
function ∇__internal_solve_up(_args...)
60+
ReverseDiffOriginator(), alg, args...; kwargs...)
61+
functionSimpleNonlinearSolve.__internal_solve_up(_args...)
5462
∂prob, ∂sensealg, ∂u0, ∂p, ∂originator, ∂args... = ∇internal(_args...)
5563
return (∂prob, ∂sensealg, ∂u0, nothing, ∂p, nothing, nothing, ∂args...)
5664
end
57-
return Array(out), ∇__internal_solve_up
65+
return Array(out), ∇SimpleNonlinearSolve.__internal_solve_up
5866
end
5967

6068
end

ext/SimpleNonlinearSolveStaticArraysExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module SimpleNonlinearSolveStaticArraysExt
22

3-
using SimpleNonlinearSolve
3+
using SimpleNonlinearSolve: SimpleNonlinearSolve
44

55
@inline SimpleNonlinearSolve.__is_extension_loaded(::Val{:StaticArrays}) = true
66

ext/SimpleNonlinearSolveTrackerExt.jl

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,36 @@
11
module SimpleNonlinearSolveTrackerExt
22

3-
using DiffEqBase, SciMLBase, SimpleNonlinearSolve, Tracker
3+
using DiffEqBase: DiffEqBase
4+
using SciMLBase: TrackerOriginator, NonlinearProblem, NonlinearLeastSquaresProblem
5+
using SimpleNonlinearSolve: SimpleNonlinearSolve
6+
using Tracker: Tracker, TrackedArray
47

5-
function SimpleNonlinearSolve.__internal_solve_up(prob::NonlinearProblem,
8+
function SimpleNonlinearSolve.__internal_solve_up(
9+
prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem},
610
sensealg, u0::TrackedArray, u0_changed, p, p_changed, alg, args...; kwargs...)
711
return Tracker.track(
812
SimpleNonlinearSolve.__internal_solve_up, prob, sensealg, u0, u0_changed,
913
p, p_changed, alg, args...; kwargs...)
1014
end
1115

1216
function SimpleNonlinearSolve.__internal_solve_up(
13-
prob::NonlinearProblem, sensealg, u0::TrackedArray, u0_changed,
14-
p::TrackedArray, p_changed, alg, args...; kwargs...)
17+
prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, sensealg,
18+
u0::TrackedArray, u0_changed, p::TrackedArray, p_changed, alg, args...; kwargs...)
1519
return Tracker.track(
1620
SimpleNonlinearSolve.__internal_solve_up, prob, sensealg, u0, u0_changed,
1721
p, p_changed, alg, args...; kwargs...)
1822
end
1923

20-
function SimpleNonlinearSolve.__internal_solve_up(prob::NonlinearProblem,
24+
function SimpleNonlinearSolve.__internal_solve_up(
25+
prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem},
2126
sensealg, u0, u0_changed, p::TrackedArray, p_changed, alg, args...; kwargs...)
2227
return Tracker.track(
2328
SimpleNonlinearSolve.__internal_solve_up, prob, sensealg, u0, u0_changed,
2429
p, p_changed, alg, args...; kwargs...)
2530
end
2631

27-
Tracker.@grad function SimpleNonlinearSolve.__internal_solve_up(_prob::NonlinearProblem,
32+
Tracker.@grad function SimpleNonlinearSolve.__internal_solve_up(
33+
_prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem},
2834
sensealg, u0_, u0_changed, p_, p_changed, alg, args...; kwargs...)
2935
u0, p = Tracker.data(u0_), Tracker.data(p_)
3036
prob = remake(_prob; u0, p)

ext/SimpleNonlinearSolveZygoteExt.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module SimpleNonlinearSolveZygoteExt
22

3-
import SimpleNonlinearSolve, Zygote
3+
using SimpleNonlinearSolve: SimpleNonlinearSolve
4+
using Zygote: Zygote
45

56
SimpleNonlinearSolve.__is_extension_loaded(::Val{:Zygote}) = true
67

src/SimpleNonlinearSolve.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,12 @@ import PrecompileTools: @compile_workload, @setup_workload, @recompile_invalidat
88

99
import DiffEqBase: AbstractNonlinearTerminationMode,
1010
AbstractSafeNonlinearTerminationMode,
11-
AbstractSafeBestNonlinearTerminationMode,
12-
NonlinearSafeTerminationReturnCode, get_termination_mode,
13-
NONLINEARSOLVE_DEFAULT_NORM
11+
AbstractSafeBestNonlinearTerminationMode, NONLINEARSOLVE_DEFAULT_NORM
1412
import DiffResults
1513
import ForwardDiff: Dual
1614
import MaybeInplace: @bb, setindex_trait, CanSetindex, CannotSetindex
1715
import SciMLBase: AbstractNonlinearAlgorithm, build_solution, isinplace, _unwrap_val
18-
import StaticArraysCore: StaticArray, SVector, SMatrix, SArray, MArray, MMatrix, Size
16+
import StaticArraysCore: StaticArray, SVector, SMatrix, SArray, MArray, Size
1917
end
2018

2119
@reexport using ADTypes, SciMLBase

src/utils.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ except `cache` (& `J` if not nothing) are mutated.
7777
function value_and_jacobian(ad, f::F, y, x::X, p, cache; J = nothing) where {F, X}
7878
if isinplace(f)
7979
_f = (du, u) -> f(du, u, p)
80-
if DiffEqBase.has_jac(f)
80+
if SciMLBase.has_jac(f)
8181
f.jac(J, x, p)
8282
_f(y, x)
8383
return y, J
@@ -97,7 +97,7 @@ function value_and_jacobian(ad, f::F, y, x::X, p, cache; J = nothing) where {F,
9797
end
9898
else
9999
_f = Base.Fix2(f, p)
100-
if DiffEqBase.has_jac(f)
100+
if SciMLBase.has_jac(f)
101101
return _f(x), f.jac(x, p)
102102
elseif ad isa AutoForwardDiff
103103
if ArrayInterface.can_setindex(x)
@@ -124,7 +124,7 @@ end
124124
function __polyester_forwarddiff_jacobian! end
125125

126126
function value_and_jacobian(ad, f::F, y, x::Number, p, cache; J = nothing) where {F}
127-
if DiffEqBase.has_jac(f)
127+
if SciMLBase.has_jac(f)
128128
return f(x, p), f.jac(x, p)
129129
elseif ad isa AutoForwardDiff
130130
T = typeof(__standard_tag(ad.tag, x))
@@ -152,7 +152,7 @@ function jacobian_cache(ad, f::F, y, x::X, p) where {F, X <: AbstractArray}
152152
if isinplace(f)
153153
_f = (du, u) -> f(du, u, p)
154154
J = similar(y, length(y), length(x))
155-
if DiffEqBase.has_jac(f)
155+
if SciMLBase.has_jac(f)
156156
return J, nothing
157157
elseif ad isa AutoForwardDiff || ad isa AutoPolyesterForwardDiff
158158
return J, __get_jacobian_config(ad, _f, y, x)
@@ -163,7 +163,7 @@ function jacobian_cache(ad, f::F, y, x::X, p) where {F, X <: AbstractArray}
163163
end
164164
else
165165
_f = Base.Fix2(f, p)
166-
if DiffEqBase.has_jac(f)
166+
if SciMLBase.has_jac(f)
167167
return nothing, nothing
168168
elseif ad isa AutoForwardDiff
169169
J = ArrayInterface.can_setindex(x) ? similar(y, length(y), length(x)) : nothing

0 commit comments

Comments
 (0)