Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Commit 7e2d867

Browse files
authored
Merge pull request #147 from SciML/ap/explicit_imports
Improve Code Standards
2 parents 26b0ec5 + 9984298 commit 7e2d867

34 files changed

+436
-397
lines changed

.JuliaFormatter.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
style = "sciml"
22
format_markdown = true
33
annotate_untyped_fields_with_any = false
4-
format_docstrings = true
4+
format_docstrings = true
5+
join_lines_based_on_source = false

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ ChainRulesCore = "1.22"
4545
ConcreteStructs = "0.2.3"
4646
DiffEqBase = "6.149"
4747
DiffResults = "1.1"
48+
ExplicitImports = "1.5.0"
4849
FastClosures = "0.3.2"
4950
FiniteDiff = "2.22"
5051
ForwardDiff = "0.10.36"
@@ -73,6 +74,7 @@ AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a"
7374
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
7475
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
7576
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
77+
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
7678
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
7779
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
7880
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -91,4 +93,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
9193
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
9294

9395
[targets]
94-
test = ["Aqua", "AllocCheck", "DiffEqBase", "ForwardDiff", "LinearAlgebra", "LinearSolve", "NonlinearProblemLibrary", "Pkg", "Random", "ReTestItems", "SciMLSensitivity", "StaticArrays", "Zygote", "CUDA", "PolyesterForwardDiff", "Reexport", "Test", "FiniteDiff", "ReverseDiff", "Tracker"]
96+
test = ["AllocCheck", "Aqua", "CUDA", "DiffEqBase", "ExplicitImports", "FiniteDiff", "ForwardDiff", "LinearAlgebra", "LinearSolve", "NonlinearProblemLibrary", "Pkg", "PolyesterForwardDiff", "Random", "ReTestItems", "Reexport", "ReverseDiff", "SciMLSensitivity", "StaticArrays", "Test", "Tracker", "Zygote"]

ext/SimpleNonlinearSolveChainRulesCoreExt.jl

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,21 @@
11
module SimpleNonlinearSolveChainRulesCoreExt
22

3-
using ChainRulesCore, DiffEqBase, SciMLBase, SimpleNonlinearSolve
3+
using ChainRulesCore: ChainRulesCore, NoTangent
4+
using DiffEqBase: DiffEqBase
5+
using SciMLBase: ChainRulesOriginator, NonlinearProblem, NonlinearLeastSquaresProblem
6+
using SimpleNonlinearSolve: SimpleNonlinearSolve
47

58
# The expectation here is that no-one is using this directly inside a GPU kernel. We can
69
# eventually lift this requirement using a custom adjoint
710
function ChainRulesCore.rrule(::typeof(SimpleNonlinearSolve.__internal_solve_up),
8-
prob::NonlinearProblem, sensealg, u0, u0_changed, p, p_changed, alg, args...;
9-
kwargs...)
10-
out, ∇internal = DiffEqBase._solve_adjoint(prob, sensealg, u0, p,
11-
SciMLBase.ChainRulesOriginator(), alg, args...; kwargs...)
11+
prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem},
12+
sensealg, u0, u0_changed, p, p_changed, alg, args...; kwargs...)
13+
out, ∇internal = DiffEqBase._solve_adjoint(
14+
prob, sensealg, u0, p, ChainRulesOriginator(), alg, args...; kwargs...)
1215
function ∇__internal_solve_up(Δ)
1316
∂f, ∂prob, ∂sensealg, ∂u0, ∂p, ∂originator, ∂args... = ∇internal(Δ)
14-
return (∂f, ∂prob, ∂sensealg, ∂u0, NoTangent(), ∂p, NoTangent(), NoTangent(),
15-
∂args...)
17+
return (
18+
f, ∂prob, ∂sensealg, ∂u0, NoTangent(), ∂p, NoTangent(), NoTangent(), ∂args...)
1619
end
1720
return out, ∇__internal_solve_up
1821
end

ext/SimpleNonlinearSolvePolyesterForwardDiffExt.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,18 @@
11
module SimpleNonlinearSolvePolyesterForwardDiffExt
22

3-
using SimpleNonlinearSolve, PolyesterForwardDiff
3+
using PolyesterForwardDiff: PolyesterForwardDiff
4+
using SimpleNonlinearSolve: SimpleNonlinearSolve
45

56
@inline SimpleNonlinearSolve.__is_extension_loaded(::Val{:PolyesterForwardDiff}) = true
67

7-
@inline function SimpleNonlinearSolve.__polyester_forwarddiff_jacobian!(f!::F, y, J, x,
8-
chunksize) where {F}
8+
@inline function SimpleNonlinearSolve.__polyester_forwarddiff_jacobian!(
9+
f!::F, y, J, x, chunksize) where {F}
910
PolyesterForwardDiff.threaded_jacobian!(f!, y, J, x, chunksize)
1011
return J
1112
end
1213

13-
@inline function SimpleNonlinearSolve.__polyester_forwarddiff_jacobian!(f::F, J, x,
14-
chunksize) where {F}
14+
@inline function SimpleNonlinearSolve.__polyester_forwarddiff_jacobian!(
15+
f::F, J, x, chunksize) where {F}
1516
PolyesterForwardDiff.threaded_jacobian!(f, J, x, chunksize)
1617
return J
1718
end
Lines changed: 53 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,60 +1,67 @@
11
module SimpleNonlinearSolveReverseDiffExt
22

3-
using ArrayInterface, DiffEqBase, ReverseDiff, SciMLBase, SimpleNonlinearSolve
4-
import ReverseDiff: TrackedArray, TrackedReal
3+
using ArrayInterface: ArrayInterface
4+
using DiffEqBase: DiffEqBase
5+
using ReverseDiff: ReverseDiff, TrackedArray, TrackedReal
6+
using SciMLBase: ReverseDiffOriginator, NonlinearProblem, NonlinearLeastSquaresProblem
7+
using SimpleNonlinearSolve: SimpleNonlinearSolve
58
import SimpleNonlinearSolve: __internal_solve_up
69

7-
function __internal_solve_up(
8-
prob::NonlinearProblem, sensealg, u0::TrackedArray, u0_changed,
9-
p::TrackedArray, p_changed, alg, args...; kwargs...)
10-
return ReverseDiff.track(__internal_solve_up, prob, sensealg, u0,
11-
u0_changed, p, p_changed, alg, args...; kwargs...)
12-
end
10+
for pType in (NonlinearProblem, NonlinearLeastSquaresProblem)
11+
@eval begin
12+
function __internal_solve_up(prob::$(pType), sensealg, u0::TrackedArray, u0_changed,
13+
p::TrackedArray, p_changed, alg, args...; kwargs...)
14+
return ReverseDiff.track(__internal_solve_up, prob, sensealg, u0,
15+
u0_changed, p, p_changed, alg, args...; kwargs...)
16+
end
1317

14-
function __internal_solve_up(
15-
prob::NonlinearProblem, sensealg, u0, u0_changed,
16-
p::TrackedArray, p_changed, alg, args...; kwargs...)
17-
return ReverseDiff.track(__internal_solve_up, prob, sensealg, u0,
18-
u0_changed, p, p_changed, alg, args...; kwargs...)
19-
end
18+
function __internal_solve_up(prob::$(pType), sensealg, u0, u0_changed,
19+
p::TrackedArray, p_changed, alg, args...; kwargs...)
20+
return ReverseDiff.track(__internal_solve_up, prob, sensealg, u0,
21+
u0_changed, p, p_changed, alg, args...; kwargs...)
22+
end
2023

21-
function __internal_solve_up(
22-
prob::NonlinearProblem, sensealg, u0::TrackedArray, u0_changed,
23-
p, p_changed, alg, args...; kwargs...)
24-
return ReverseDiff.track(__internal_solve_up, prob, sensealg, u0,
25-
u0_changed, p, p_changed, alg, args...; kwargs...)
26-
end
24+
function __internal_solve_up(prob::$(pType), sensealg, u0::TrackedArray,
25+
u0_changed, p, p_changed, alg, args...; kwargs...)
26+
return ReverseDiff.track(__internal_solve_up, prob, sensealg, u0,
27+
u0_changed, p, p_changed, alg, args...; kwargs...)
28+
end
2729

28-
function __internal_solve_up(prob::NonlinearProblem, sensealg,
29-
u0::AbstractArray{<:TrackedReal}, u0_changed, p::AbstractArray{<:TrackedReal},
30-
p_changed, alg, args...; kwargs...)
31-
return __internal_solve_up(
32-
prob, sensealg, ArrayInterface.aos_to_soa(u0), true,
33-
ArrayInterface.aos_to_soa(p), true, alg, args...; kwargs...)
34-
end
30+
function __internal_solve_up(
31+
prob::$(pType), sensealg, u0::AbstractArray{<:TrackedReal}, u0_changed,
32+
p::AbstractArray{<:TrackedReal}, p_changed, alg, args...; kwargs...)
33+
return __internal_solve_up(prob, sensealg, ArrayInterface.aos_to_soa(u0), true,
34+
ArrayInterface.aos_to_soa(p), true, alg, args...; kwargs...)
35+
end
3536

36-
function __internal_solve_up(prob::NonlinearProblem, sensealg, u0, u0_changed,
37-
p::AbstractArray{<:TrackedReal}, p_changed, alg, args...; kwargs...)
38-
return __internal_solve_up(
39-
prob, sensealg, u0, true, ArrayInterface.aos_to_soa(p), true, alg, args...; kwargs...)
40-
end
37+
function __internal_solve_up(prob::$(pType), sensealg, u0, u0_changed,
38+
p::AbstractArray{<:TrackedReal}, p_changed, alg, args...; kwargs...)
39+
return __internal_solve_up(
40+
prob, sensealg, u0, true, ArrayInterface.aos_to_soa(p),
41+
true, alg, args...; kwargs...)
42+
end
4143

42-
function __internal_solve_up(prob::NonlinearProblem, sensealg,
43-
u0::AbstractArray{<:TrackedReal}, u0_changed, p, p_changed, alg, args...; kwargs...)
44-
return __internal_solve_up(
45-
prob, sensealg, u0, true, ArrayInterface.aos_to_soa(p), true, alg, args...; kwargs...)
46-
end
44+
function __internal_solve_up(
45+
prob::$(pType), sensealg, u0::AbstractArray{<:TrackedReal},
46+
u0_changed, p, p_changed, alg, args...; kwargs...)
47+
return __internal_solve_up(
48+
prob, sensealg, u0, true, ArrayInterface.aos_to_soa(p),
49+
true, alg, args...; kwargs...)
50+
end
4751

48-
ReverseDiff.@grad function __internal_solve_up(
49-
prob::NonlinearProblem, sensealg, u0, u0_changed, p, p_changed, alg, args...; kwargs...)
50-
out, ∇internal = DiffEqBase._solve_adjoint(
51-
prob, sensealg, ReverseDiff.value(u0), ReverseDiff.value(p),
52-
SciMLBase.ReverseDiffOriginator(), alg, args...; kwargs...)
53-
function ∇__internal_solve_up(_args...)
54-
∂prob, ∂sensealg, ∂u0, ∂p, ∂originator, ∂args... = ∇internal(_args...)
55-
return (∂prob, ∂sensealg, ∂u0, nothing, ∂p, nothing, nothing, ∂args...)
52+
ReverseDiff.@grad function __internal_solve_up(
53+
prob::$(pType), sensealg, u0, u0_changed,
54+
p, p_changed, alg, args...; kwargs...)
55+
out, ∇internal = DiffEqBase._solve_adjoint(
56+
prob, sensealg, ReverseDiff.value(u0), ReverseDiff.value(p),
57+
ReverseDiffOriginator(), alg, args...; kwargs...)
58+
function ∇__internal_solve_up(_args...)
59+
∂prob, ∂sensealg, ∂u0, ∂p, ∂originator, ∂args... = ∇internal(_args...)
60+
return (∂prob, ∂sensealg, ∂u0, nothing, ∂p, nothing, nothing, ∂args...)
61+
end
62+
return Array(out), ∇__internal_solve_up
63+
end
5664
end
57-
return Array(out), ∇__internal_solve_up
5865
end
5966

6067
end

ext/SimpleNonlinearSolveStaticArraysExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module SimpleNonlinearSolveStaticArraysExt
22

3-
using SimpleNonlinearSolve
3+
using SimpleNonlinearSolve: SimpleNonlinearSolve
44

55
@inline SimpleNonlinearSolve.__is_extension_loaded(::Val{:StaticArrays}) = true
66

ext/SimpleNonlinearSolveTrackerExt.jl

Lines changed: 43 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,49 @@
11
module SimpleNonlinearSolveTrackerExt
22

3-
using DiffEqBase, SciMLBase, SimpleNonlinearSolve, Tracker
4-
5-
function SimpleNonlinearSolve.__internal_solve_up(prob::NonlinearProblem,
6-
sensealg, u0::TrackedArray, u0_changed, p, p_changed, alg, args...; kwargs...)
7-
return Tracker.track(
8-
SimpleNonlinearSolve.__internal_solve_up, prob, sensealg, u0, u0_changed,
9-
p, p_changed, alg, args...; kwargs...)
10-
end
11-
12-
function SimpleNonlinearSolve.__internal_solve_up(
13-
prob::NonlinearProblem, sensealg, u0::TrackedArray, u0_changed,
14-
p::TrackedArray, p_changed, alg, args...; kwargs...)
15-
return Tracker.track(
16-
SimpleNonlinearSolve.__internal_solve_up, prob, sensealg, u0, u0_changed,
17-
p, p_changed, alg, args...; kwargs...)
18-
end
19-
20-
function SimpleNonlinearSolve.__internal_solve_up(prob::NonlinearProblem,
21-
sensealg, u0, u0_changed, p::TrackedArray, p_changed, alg, args...; kwargs...)
22-
return Tracker.track(
23-
SimpleNonlinearSolve.__internal_solve_up, prob, sensealg, u0, u0_changed,
24-
p, p_changed, alg, args...; kwargs...)
25-
end
26-
27-
Tracker.@grad function SimpleNonlinearSolve.__internal_solve_up(_prob::NonlinearProblem,
28-
sensealg, u0_, u0_changed, p_, p_changed, alg, args...; kwargs...)
29-
u0, p = Tracker.data(u0_), Tracker.data(p_)
30-
prob = remake(_prob; u0, p)
31-
out, ∇internal = DiffEqBase._solve_adjoint(prob, sensealg, u0, p,
32-
SciMLBase.TrackerOriginator(), alg, args...; kwargs...)
33-
34-
function ∇__internal_solve_up(Δ)
35-
∂prob, ∂sensealg, ∂u0, ∂p, ∂originator, ∂args... = ∇internal(Δ)
36-
return (∂prob, ∂sensealg, ∂u0, nothing, ∂p, nothing, nothing, ∂args...)
3+
using DiffEqBase: DiffEqBase
4+
using SciMLBase: TrackerOriginator, NonlinearProblem, NonlinearLeastSquaresProblem, remake
5+
using SimpleNonlinearSolve: SimpleNonlinearSolve
6+
using Tracker: Tracker, TrackedArray
7+
8+
for pType in (NonlinearProblem, NonlinearLeastSquaresProblem)
9+
@eval begin
10+
function SimpleNonlinearSolve.__internal_solve_up(
11+
prob::$(pType), sensealg, u0::TrackedArray,
12+
u0_changed, p, p_changed, alg, args...; kwargs...)
13+
return Tracker.track(SimpleNonlinearSolve.__internal_solve_up, prob, sensealg,
14+
u0, u0_changed, p, p_changed, alg, args...; kwargs...)
15+
end
16+
17+
function SimpleNonlinearSolve.__internal_solve_up(
18+
prob::$(pType), sensealg, u0::TrackedArray, u0_changed,
19+
p::TrackedArray, p_changed, alg, args...; kwargs...)
20+
return Tracker.track(SimpleNonlinearSolve.__internal_solve_up, prob, sensealg,
21+
u0, u0_changed, p, p_changed, alg, args...; kwargs...)
22+
end
23+
24+
function SimpleNonlinearSolve.__internal_solve_up(
25+
prob::$(pType), sensealg, u0, u0_changed,
26+
p::TrackedArray, p_changed, alg, args...; kwargs...)
27+
return Tracker.track(SimpleNonlinearSolve.__internal_solve_up, prob, sensealg,
28+
u0, u0_changed, p, p_changed, alg, args...; kwargs...)
29+
end
30+
31+
Tracker.@grad function SimpleNonlinearSolve.__internal_solve_up(
32+
_prob::$(pType), sensealg, u0_, u0_changed,
33+
p_, p_changed, alg, args...; kwargs...)
34+
u0, p = Tracker.data(u0_), Tracker.data(p_)
35+
prob = remake(_prob; u0, p)
36+
out, ∇internal = DiffEqBase._solve_adjoint(
37+
prob, sensealg, u0, p, TrackerOriginator(), alg, args...; kwargs...)
38+
39+
function ∇__internal_solve_up(Δ)
40+
∂prob, ∂sensealg, ∂u0, ∂p, ∂originator, ∂args... = ∇internal(Δ)
41+
return (∂prob, ∂sensealg, ∂u0, nothing, ∂p, nothing, nothing, ∂args...)
42+
end
43+
44+
return out, ∇__internal_solve_up
45+
end
3746
end
38-
39-
return out, ∇__internal_solve_up
4047
end
4148

4249
end

ext/SimpleNonlinearSolveZygoteExt.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module SimpleNonlinearSolveZygoteExt
22

3-
import SimpleNonlinearSolve, Zygote
3+
using SimpleNonlinearSolve: SimpleNonlinearSolve
4+
using Zygote: Zygote
45

56
SimpleNonlinearSolve.__is_extension_loaded(::Val{:Zygote}) = true
67

0 commit comments

Comments
 (0)