|
1 | 1 | module SimpleNonlinearSolveReverseDiffExt
|
2 | 2 |
|
3 |
| -using ArrayInterface, DiffEqBase, ReverseDiff, SciMLBase, SimpleNonlinearSolve |
4 |
| -import ReverseDiff: TrackedArray, TrackedReal |
| 3 | +using ArrayInterface: ArrayInterface |
| 4 | +using DiffEqBase: DiffEqBase |
| 5 | +using ReverseDiff: ReverseDiff, TrackedArray, TrackedReal |
| 6 | +using SciMLBase: ReverseDiffOriginator, NonlinearProblem, NonlinearLeastSquaresProblem |
| 7 | +using SimpleNonlinearSolve: SimpleNonlinearSolve |
5 | 8 | import SimpleNonlinearSolve: __internal_solve_up
|
6 | 9 |
|
7 |
| -function __internal_solve_up( |
8 |
| - prob::NonlinearProblem, sensealg, u0::TrackedArray, u0_changed, |
9 |
| - p::TrackedArray, p_changed, alg, args...; kwargs...) |
10 |
| - return ReverseDiff.track(__internal_solve_up, prob, sensealg, u0, |
11 |
| - u0_changed, p, p_changed, alg, args...; kwargs...) |
12 |
| -end |
| 10 | +for pType in (NonlinearProblem, NonlinearLeastSquaresProblem) |
| 11 | + @eval begin |
| 12 | + function __internal_solve_up(prob::$(pType), sensealg, u0::TrackedArray, u0_changed, |
| 13 | + p::TrackedArray, p_changed, alg, args...; kwargs...) |
| 14 | + return ReverseDiff.track(__internal_solve_up, prob, sensealg, u0, |
| 15 | + u0_changed, p, p_changed, alg, args...; kwargs...) |
| 16 | + end |
13 | 17 |
|
14 |
| -function __internal_solve_up( |
15 |
| - prob::NonlinearProblem, sensealg, u0, u0_changed, |
16 |
| - p::TrackedArray, p_changed, alg, args...; kwargs...) |
17 |
| - return ReverseDiff.track(__internal_solve_up, prob, sensealg, u0, |
18 |
| - u0_changed, p, p_changed, alg, args...; kwargs...) |
19 |
| -end |
| 18 | + function __internal_solve_up(prob::$(pType), sensealg, u0, u0_changed, |
| 19 | + p::TrackedArray, p_changed, alg, args...; kwargs...) |
| 20 | + return ReverseDiff.track(__internal_solve_up, prob, sensealg, u0, |
| 21 | + u0_changed, p, p_changed, alg, args...; kwargs...) |
| 22 | + end |
20 | 23 |
|
21 |
| -function __internal_solve_up( |
22 |
| - prob::NonlinearProblem, sensealg, u0::TrackedArray, u0_changed, |
23 |
| - p, p_changed, alg, args...; kwargs...) |
24 |
| - return ReverseDiff.track(__internal_solve_up, prob, sensealg, u0, |
25 |
| - u0_changed, p, p_changed, alg, args...; kwargs...) |
26 |
| -end |
| 24 | + function __internal_solve_up(prob::$(pType), sensealg, u0::TrackedArray, |
| 25 | + u0_changed, p, p_changed, alg, args...; kwargs...) |
| 26 | + return ReverseDiff.track(__internal_solve_up, prob, sensealg, u0, |
| 27 | + u0_changed, p, p_changed, alg, args...; kwargs...) |
| 28 | + end |
27 | 29 |
|
28 |
| -function __internal_solve_up(prob::NonlinearProblem, sensealg, |
29 |
| - u0::AbstractArray{<:TrackedReal}, u0_changed, p::AbstractArray{<:TrackedReal}, |
30 |
| - p_changed, alg, args...; kwargs...) |
31 |
| - return __internal_solve_up( |
32 |
| - prob, sensealg, ArrayInterface.aos_to_soa(u0), true, |
33 |
| - ArrayInterface.aos_to_soa(p), true, alg, args...; kwargs...) |
34 |
| -end |
| 30 | + function __internal_solve_up( |
| 31 | + prob::$(pType), sensealg, u0::AbstractArray{<:TrackedReal}, u0_changed, |
| 32 | + p::AbstractArray{<:TrackedReal}, p_changed, alg, args...; kwargs...) |
| 33 | + return __internal_solve_up(prob, sensealg, ArrayInterface.aos_to_soa(u0), true, |
| 34 | + ArrayInterface.aos_to_soa(p), true, alg, args...; kwargs...) |
| 35 | + end |
35 | 36 |
|
36 |
| -function __internal_solve_up(prob::NonlinearProblem, sensealg, u0, u0_changed, |
37 |
| - p::AbstractArray{<:TrackedReal}, p_changed, alg, args...; kwargs...) |
38 |
| - return __internal_solve_up( |
39 |
| - prob, sensealg, u0, true, ArrayInterface.aos_to_soa(p), true, alg, args...; kwargs...) |
40 |
| -end |
| 37 | + function __internal_solve_up(prob::$(pType), sensealg, u0, u0_changed, |
| 38 | + p::AbstractArray{<:TrackedReal}, p_changed, alg, args...; kwargs...) |
| 39 | + return __internal_solve_up( |
| 40 | + prob, sensealg, u0, true, ArrayInterface.aos_to_soa(p), |
| 41 | + true, alg, args...; kwargs...) |
| 42 | + end |
41 | 43 |
|
42 |
| -function __internal_solve_up(prob::NonlinearProblem, sensealg, |
43 |
| - u0::AbstractArray{<:TrackedReal}, u0_changed, p, p_changed, alg, args...; kwargs...) |
44 |
| - return __internal_solve_up( |
45 |
| - prob, sensealg, u0, true, ArrayInterface.aos_to_soa(p), true, alg, args...; kwargs...) |
46 |
| -end |
| 44 | + function __internal_solve_up( |
| 45 | + prob::$(pType), sensealg, u0::AbstractArray{<:TrackedReal}, |
| 46 | + u0_changed, p, p_changed, alg, args...; kwargs...) |
| 47 | + return __internal_solve_up( |
| 48 | + prob, sensealg, u0, true, ArrayInterface.aos_to_soa(p), |
| 49 | + true, alg, args...; kwargs...) |
| 50 | + end |
47 | 51 |
|
48 |
| -ReverseDiff.@grad function __internal_solve_up( |
49 |
| - prob::NonlinearProblem, sensealg, u0, u0_changed, p, p_changed, alg, args...; kwargs...) |
50 |
| - out, ∇internal = DiffEqBase._solve_adjoint( |
51 |
| - prob, sensealg, ReverseDiff.value(u0), ReverseDiff.value(p), |
52 |
| - SciMLBase.ReverseDiffOriginator(), alg, args...; kwargs...) |
53 |
| - function ∇__internal_solve_up(_args...) |
54 |
| - ∂prob, ∂sensealg, ∂u0, ∂p, ∂originator, ∂args... = ∇internal(_args...) |
55 |
| - return (∂prob, ∂sensealg, ∂u0, nothing, ∂p, nothing, nothing, ∂args...) |
| 52 | + ReverseDiff.@grad function __internal_solve_up( |
| 53 | + prob::$(pType), sensealg, u0, u0_changed, |
| 54 | + p, p_changed, alg, args...; kwargs...) |
| 55 | + out, ∇internal = DiffEqBase._solve_adjoint( |
| 56 | + prob, sensealg, ReverseDiff.value(u0), ReverseDiff.value(p), |
| 57 | + ReverseDiffOriginator(), alg, args...; kwargs...) |
| 58 | + function ∇__internal_solve_up(_args...) |
| 59 | + ∂prob, ∂sensealg, ∂u0, ∂p, ∂originator, ∂args... = ∇internal(_args...) |
| 60 | + return (∂prob, ∂sensealg, ∂u0, nothing, ∂p, nothing, nothing, ∂args...) |
| 61 | + end |
| 62 | + return Array(out), ∇__internal_solve_up |
| 63 | + end |
56 | 64 | end
|
57 |
| - return Array(out), ∇__internal_solve_up |
58 | 65 | end
|
59 | 66 |
|
60 | 67 | end
|
0 commit comments