@@ -5,65 +5,61 @@ using DiffEqBase: DiffEqBase
5
5
using ReverseDiff: ReverseDiff, TrackedArray, TrackedReal
6
6
using SciMLBase: ReverseDiffOriginator, NonlinearProblem, NonlinearLeastSquaresProblem
7
7
using SimpleNonlinearSolve: SimpleNonlinearSolve
8
+ import SimpleNonlinearSolve: __internal_solve_up
8
9
9
- function SimpleNonlinearSolve . __internal_solve_up (
10
+ function __internal_solve_up (
10
11
prob:: Union{NonlinearProblem, NonlinearLeastSquaresProblem} , sensealg,
11
12
u0:: TrackedArray , u0_changed, p:: TrackedArray , p_changed, alg, args... ; kwargs... )
12
- return ReverseDiff. track (SimpleNonlinearSolve . __internal_solve_up, prob, sensealg,
13
- u0, u0_changed, p, p_changed, alg, args... ; kwargs... )
13
+ return ReverseDiff. track (__internal_solve_up, prob, sensealg, u0 ,
14
+ u0_changed, p, p_changed, alg, args... ; kwargs... )
14
15
end
15
16
16
- function SimpleNonlinearSolve . __internal_solve_up (
17
+ function __internal_solve_up (
17
18
prob:: Union{NonlinearProblem, NonlinearLeastSquaresProblem} , sensealg,
18
19
u0, u0_changed, p:: TrackedArray , p_changed, alg, args... ; kwargs... )
19
- return ReverseDiff. track (SimpleNonlinearSolve . __internal_solve_up, prob, sensealg,
20
- u0, u0_changed, p, p_changed, alg, args... ; kwargs... )
20
+ return ReverseDiff. track (__internal_solve_up, prob, sensealg, u0 ,
21
+ u0_changed, p, p_changed, alg, args... ; kwargs... )
21
22
end
22
23
23
- function SimpleNonlinearSolve . __internal_solve_up (
24
+ function __internal_solve_up (
24
25
prob:: Union{NonlinearProblem, NonlinearLeastSquaresProblem} , sensealg,
25
26
u0:: TrackedArray , u0_changed, p, p_changed, alg, args... ; kwargs... )
26
- return ReverseDiff. track (SimpleNonlinearSolve . __internal_solve_up, prob, sensealg,
27
- u0, u0_changed, p, p_changed, alg, args... ; kwargs... )
27
+ return ReverseDiff. track (__internal_solve_up, prob, sensealg, u0 ,
28
+ u0_changed, p, p_changed, alg, args... ; kwargs... )
28
29
end
29
30
30
- function SimpleNonlinearSolve. __internal_solve_up (
31
- prob:: Union{NonlinearProblem, NonlinearLeastSquaresProblem} ,
31
+ function __internal_solve_up (prob:: Union{NonlinearProblem, NonlinearLeastSquaresProblem} ,
32
32
sensealg, u0:: AbstractArray{<:TrackedReal} , u0_changed,
33
33
p:: AbstractArray{<:TrackedReal} , p_changed, alg, args... ; kwargs... )
34
- return SimpleNonlinearSolve. __internal_solve_up (
35
- prob, sensealg, ArrayInterface. aos_to_soa (u0), true ,
34
+ return __internal_solve_up (prob, sensealg, ArrayInterface. aos_to_soa (u0), true ,
36
35
ArrayInterface. aos_to_soa (p), true , alg, args... ; kwargs... )
37
36
end
38
37
39
- function SimpleNonlinearSolve . __internal_solve_up (
38
+ function __internal_solve_up (
40
39
prob:: Union{NonlinearProblem, NonlinearLeastSquaresProblem} , sensealg, u0,
41
40
u0_changed, p:: AbstractArray{<:TrackedReal} , p_changed, alg, args... ; kwargs... )
42
- return SimpleNonlinearSolve. __internal_solve_up (
43
- prob, sensealg, u0, true , ArrayInterface. aos_to_soa (p),
41
+ return __internal_solve_up (prob, sensealg, u0, true , ArrayInterface. aos_to_soa (p),
44
42
true , alg, args... ; kwargs... )
45
43
end
46
44
47
- function SimpleNonlinearSolve. __internal_solve_up (
48
- prob:: Union{NonlinearProblem, NonlinearLeastSquaresProblem} ,
45
+ function __internal_solve_up (prob:: Union{NonlinearProblem, NonlinearLeastSquaresProblem} ,
49
46
sensealg, u0:: AbstractArray{<:TrackedReal} ,
50
47
u0_changed, p, p_changed, alg, args... ; kwargs... )
51
- return SimpleNonlinearSolve. __internal_solve_up (
52
- prob, sensealg, u0, true , ArrayInterface. aos_to_soa (p),
48
+ return __internal_solve_up (prob, sensealg, u0, true , ArrayInterface. aos_to_soa (p),
53
49
true , alg, args... ; kwargs... )
54
50
end
55
51
56
- ReverseDiff. @grad function SimpleNonlinearSolve . __internal_solve_up (
52
+ ReverseDiff. @grad function __internal_solve_up (
57
53
prob:: Union{NonlinearProblem, NonlinearLeastSquaresProblem} ,
58
54
sensealg, u0, u0_changed, p, p_changed, alg, args... ; kwargs... )
59
55
out, ∇internal = DiffEqBase. _solve_adjoint (
60
56
prob, sensealg, ReverseDiff. value (u0), ReverseDiff. value (p),
61
57
ReverseDiffOriginator (), alg, args... ; kwargs... )
62
- function ∇SimpleNonlinearSolve . __internal_solve_up (_args... )
58
+ function ∇__internal_solve_up (_args... )
63
59
∂prob, ∂sensealg, ∂u0, ∂p, ∂originator, ∂args... = ∇internal (_args... )
64
60
return (∂prob, ∂sensealg, ∂u0, nothing , ∂p, nothing , nothing , ∂args... )
65
61
end
66
- return Array (out), ∇SimpleNonlinearSolve . __internal_solve_up
62
+ return Array (out), ∇__internal_solve_up
67
63
end
68
64
69
65
end
0 commit comments