@@ -7,59 +7,61 @@ using SciMLBase: ReverseDiffOriginator, NonlinearProblem, NonlinearLeastSquaresP
7
7
using SimpleNonlinearSolve: SimpleNonlinearSolve
8
8
import SimpleNonlinearSolve: __internal_solve_up
9
9
10
- function __internal_solve_up (
11
- prob:: Union{NonlinearProblem, NonlinearLeastSquaresProblem} , sensealg,
12
- u0:: TrackedArray , u0_changed, p:: TrackedArray , p_changed, alg, args... ; kwargs... )
13
- return ReverseDiff. track (__internal_solve_up, prob, sensealg, u0,
14
- u0_changed, p, p_changed, alg, args... ; kwargs... )
15
- end
10
+ for pType in (NonlinearProblem, NonlinearLeastSquaresProblem)
11
+ @eval begin
12
+ function __internal_solve_up (prob:: $ (pType), sensealg, u0:: TrackedArray , u0_changed,
13
+ p:: TrackedArray , p_changed, alg, args... ; kwargs... )
14
+ return ReverseDiff. track (__internal_solve_up, prob, sensealg, u0,
15
+ u0_changed, p, p_changed, alg, args... ; kwargs... )
16
+ end
16
17
17
- function __internal_solve_up (
18
- prob:: Union{NonlinearProblem, NonlinearLeastSquaresProblem} , sensealg,
19
- u0, u0_changed, p:: TrackedArray , p_changed, alg, args... ; kwargs... )
20
- return ReverseDiff. track (__internal_solve_up, prob, sensealg, u0,
21
- u0_changed, p, p_changed, alg, args... ; kwargs... )
22
- end
18
+ function __internal_solve_up (prob:: $ (pType), sensealg, u0, u0_changed,
19
+ p:: TrackedArray , p_changed, alg, args... ; kwargs... )
20
+ return ReverseDiff. track (__internal_solve_up, prob, sensealg, u0,
21
+ u0_changed, p, p_changed, alg, args... ; kwargs... )
22
+ end
23
23
24
- function __internal_solve_up (
25
- prob:: Union{NonlinearProblem, NonlinearLeastSquaresProblem} , sensealg,
26
- u0:: TrackedArray , u0_changed, p, p_changed, alg, args... ; kwargs... )
27
- return ReverseDiff. track (__internal_solve_up, prob, sensealg, u0,
28
- u0_changed, p, p_changed, alg, args... ; kwargs... )
29
- end
24
+ function __internal_solve_up (prob:: $ (pType), sensealg, u0:: TrackedArray ,
25
+ u0_changed, p, p_changed, alg, args... ; kwargs... )
26
+ return ReverseDiff. track (__internal_solve_up, prob, sensealg, u0,
27
+ u0_changed, p, p_changed, alg, args... ; kwargs... )
28
+ end
30
29
31
- function __internal_solve_up (prob :: Union{NonlinearProblem, NonlinearLeastSquaresProblem} ,
32
- sensealg, u0:: AbstractArray{<:TrackedReal} , u0_changed,
33
- p:: AbstractArray{<:TrackedReal} , p_changed, alg, args... ; kwargs... )
34
- return __internal_solve_up (prob, sensealg, ArrayInterface. aos_to_soa (u0), true ,
35
- ArrayInterface. aos_to_soa (p), true , alg, args... ; kwargs... )
36
- end
30
+ function __internal_solve_up (
31
+ prob :: $ (pType), sensealg, u0:: AbstractArray{<:TrackedReal} , u0_changed,
32
+ p:: AbstractArray{<:TrackedReal} , p_changed, alg, args... ; kwargs... )
33
+ return __internal_solve_up (prob, sensealg, ArrayInterface. aos_to_soa (u0), true ,
34
+ ArrayInterface. aos_to_soa (p), true , alg, args... ; kwargs... )
35
+ end
37
36
38
- function __internal_solve_up (
39
- prob :: Union{NonlinearProblem, NonlinearLeastSquaresProblem } , sensealg, u0,
40
- u0_changed, p :: AbstractArray{<:TrackedReal} , p_changed, alg, args ... ; kwargs ... )
41
- return __internal_solve_up ( prob, sensealg, u0, true , ArrayInterface. aos_to_soa (p),
42
- true , alg, args... ; kwargs... )
43
- end
37
+ function __internal_solve_up (prob :: $ (pType), sensealg, u0, u0_changed,
38
+ p :: AbstractArray{<:TrackedReal } , p_changed, alg, args ... ; kwargs ... )
39
+ return __internal_solve_up (
40
+ prob, sensealg, u0, true , ArrayInterface. aos_to_soa (p),
41
+ true , alg, args... ; kwargs... )
42
+ end
44
43
45
- function __internal_solve_up (prob:: Union{NonlinearProblem, NonlinearLeastSquaresProblem} ,
46
- sensealg, u0:: AbstractArray{<:TrackedReal} ,
47
- u0_changed, p, p_changed, alg, args... ; kwargs... )
48
- return __internal_solve_up (prob, sensealg, u0, true , ArrayInterface. aos_to_soa (p),
49
- true , alg, args... ; kwargs... )
50
- end
44
+ function __internal_solve_up (
45
+ prob:: $ (pType), sensealg, u0:: AbstractArray{<:TrackedReal} ,
46
+ u0_changed, p, p_changed, alg, args... ; kwargs... )
47
+ return __internal_solve_up (
48
+ prob, sensealg, u0, true , ArrayInterface. aos_to_soa (p),
49
+ true , alg, args... ; kwargs... )
50
+ end
51
51
52
- ReverseDiff. @grad function __internal_solve_up (
53
- prob:: Union{NonlinearProblem, NonlinearLeastSquaresProblem} ,
54
- sensealg, u0, u0_changed, p, p_changed, alg, args... ; kwargs... )
55
- out, ∇internal = DiffEqBase. _solve_adjoint (
56
- prob, sensealg, ReverseDiff. value (u0), ReverseDiff. value (p),
57
- ReverseDiffOriginator (), alg, args... ; kwargs... )
58
- function ∇__internal_solve_up (_args... )
59
- ∂prob, ∂sensealg, ∂u0, ∂p, ∂originator, ∂args... = ∇internal (_args... )
60
- return (∂prob, ∂sensealg, ∂u0, nothing , ∂p, nothing , nothing , ∂args... )
52
+ ReverseDiff. @grad function __internal_solve_up (
53
+ prob:: $ (pType), sensealg, u0, u0_changed,
54
+ p, p_changed, alg, args... ; kwargs... )
55
+ out, ∇internal = DiffEqBase. _solve_adjoint (
56
+ prob, sensealg, ReverseDiff. value (u0), ReverseDiff. value (p),
57
+ ReverseDiffOriginator (), alg, args... ; kwargs... )
58
+ function ∇__internal_solve_up (_args... )
59
+ ∂prob, ∂sensealg, ∂u0, ∂p, ∂originator, ∂args... = ∇internal (_args... )
60
+ return (∂prob, ∂sensealg, ∂u0, nothing , ∂p, nothing , nothing , ∂args... )
61
+ end
62
+ return Array (out), ∇__internal_solve_up
63
+ end
61
64
end
62
- return Array (out), ∇__internal_solve_up
63
65
end
64
66
65
67
end
0 commit comments