|
1 | 1 | module SimpleNonlinearSolveReverseDiffExt
|
2 | 2 |
|
3 |
| -using ArrayInterface, DiffEqBase, ReverseDiff, SciMLBase, SimpleNonlinearSolve |
4 |
| -import ReverseDiff: TrackedArray, TrackedReal |
5 |
| -import SimpleNonlinearSolve: __internal_solve_up |
| 3 | +using ArrayInterface: ArrayInterface |
| 4 | +using DiffEqBase: DiffEqBase |
| 5 | +using ReverseDiff: ReverseDiff, TrackedArray, TrackedReal |
| 6 | +using SciMLBase: ReverseDiffOriginator, NonlinearProblem, NonlinearLeastSquaresProblem |
| 7 | +using SimpleNonlinearSolve: SimpleNonlinearSolve |
6 | 8 |
|
7 |
| -function __internal_solve_up( |
8 |
| - prob::NonlinearProblem, sensealg, u0::TrackedArray, u0_changed, |
9 |
| - p::TrackedArray, p_changed, alg, args...; kwargs...) |
10 |
| - return ReverseDiff.track(__internal_solve_up, prob, sensealg, u0, |
| 9 | +function SimpleNonlinearSolve.__internal_solve_up( |
| 10 | + prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, sensealg, |
| 11 | + u0::TrackedArray, u0_changed, p::TrackedArray, p_changed, alg, args...; kwargs...) |
| 12 | + return ReverseDiff.track(SimpleNonlinearSolve.__internal_solve_up, prob, sensealg, u0, |
11 | 13 | u0_changed, p, p_changed, alg, args...; kwargs...)
|
12 | 14 | end
|
13 | 15 |
|
14 |
| -function __internal_solve_up( |
15 |
| - prob::NonlinearProblem, sensealg, u0, u0_changed, |
| 16 | +function SimpleNonlinearSolve.__internal_solve_up( |
| 17 | + prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, sensealg, u0, u0_changed, |
16 | 18 | p::TrackedArray, p_changed, alg, args...; kwargs...)
|
17 |
| - return ReverseDiff.track(__internal_solve_up, prob, sensealg, u0, |
| 19 | + return ReverseDiff.track(SimpleNonlinearSolve.__internal_solve_up, prob, sensealg, u0, |
18 | 20 | u0_changed, p, p_changed, alg, args...; kwargs...)
|
19 | 21 | end
|
20 | 22 |
|
21 |
| -function __internal_solve_up( |
22 |
| - prob::NonlinearProblem, sensealg, u0::TrackedArray, u0_changed, |
23 |
| - p, p_changed, alg, args...; kwargs...) |
24 |
| - return ReverseDiff.track(__internal_solve_up, prob, sensealg, u0, |
| 23 | +function SimpleNonlinearSolve.__internal_solve_up( |
| 24 | + prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, sensealg, |
| 25 | + u0::TrackedArray, u0_changed, p, p_changed, alg, args...; kwargs...) |
| 26 | + return ReverseDiff.track(SimpleNonlinearSolve.__internal_solve_up, prob, sensealg, u0, |
25 | 27 | u0_changed, p, p_changed, alg, args...; kwargs...)
|
26 | 28 | end
|
27 | 29 |
|
28 |
| -function __internal_solve_up(prob::NonlinearProblem, sensealg, |
| 30 | +function SimpleNonlinearSolve.__internal_solve_up( |
| 31 | + prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, sensealg, |
29 | 32 | u0::AbstractArray{<:TrackedReal}, u0_changed, p::AbstractArray{<:TrackedReal},
|
30 | 33 | p_changed, alg, args...; kwargs...)
|
31 |
| - return __internal_solve_up( |
| 34 | + return SimpleNonlinearSolve.__internal_solve_up( |
32 | 35 | prob, sensealg, ArrayInterface.aos_to_soa(u0), true,
|
33 | 36 | ArrayInterface.aos_to_soa(p), true, alg, args...; kwargs...)
|
34 | 37 | end
|
35 | 38 |
|
36 |
| -function __internal_solve_up(prob::NonlinearProblem, sensealg, u0, u0_changed, |
| 39 | +function SimpleNonlinearSolve.__internal_solve_up( |
| 40 | + prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, sensealg, u0, u0_changed, |
37 | 41 | p::AbstractArray{<:TrackedReal}, p_changed, alg, args...; kwargs...)
|
38 |
| - return __internal_solve_up( |
39 |
| - prob, sensealg, u0, true, ArrayInterface.aos_to_soa(p), true, alg, args...; kwargs...) |
| 42 | + return SimpleNonlinearSolve.__internal_solve_up( |
| 43 | + prob, sensealg, u0, true, ArrayInterface.aos_to_soa(p), true, alg, args...; |
| 44 | + kwargs...) |
40 | 45 | end
|
41 | 46 |
|
42 |
| -function __internal_solve_up(prob::NonlinearProblem, sensealg, |
| 47 | +function SimpleNonlinearSolve.__internal_solve_up( |
| 48 | + prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, sensealg, |
43 | 49 | u0::AbstractArray{<:TrackedReal}, u0_changed, p, p_changed, alg, args...; kwargs...)
|
44 |
| - return __internal_solve_up( |
45 |
| - prob, sensealg, u0, true, ArrayInterface.aos_to_soa(p), true, alg, args...; kwargs...) |
| 50 | + return SimpleNonlinearSolve.__internal_solve_up( |
| 51 | + prob, sensealg, u0, true, ArrayInterface.aos_to_soa(p), true, alg, args...; |
| 52 | + kwargs...) |
46 | 53 | end
|
47 | 54 |
|
48 |
| -ReverseDiff.@grad function __internal_solve_up( |
49 |
| - prob::NonlinearProblem, sensealg, u0, u0_changed, p, p_changed, alg, args...; kwargs...) |
| 55 | +ReverseDiff.@grad function SimpleNonlinearSolve.__internal_solve_up( |
| 56 | + prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, sensealg, u0, |
| 57 | + u0_changed, p, p_changed, alg, args...; kwargs...) |
50 | 58 | out, ∇internal = DiffEqBase._solve_adjoint(
|
51 | 59 | prob, sensealg, ReverseDiff.value(u0), ReverseDiff.value(p),
|
52 |
| - SciMLBase.ReverseDiffOriginator(), alg, args...; kwargs...) |
53 |
| - function ∇__internal_solve_up(_args...) |
| 60 | + ReverseDiffOriginator(), alg, args...; kwargs...) |
| 61 | + function ∇SimpleNonlinearSolve.__internal_solve_up(_args...) |
54 | 62 | ∂prob, ∂sensealg, ∂u0, ∂p, ∂originator, ∂args... = ∇internal(_args...)
|
55 | 63 | return (∂prob, ∂sensealg, ∂u0, nothing, ∂p, nothing, nothing, ∂args...)
|
56 | 64 | end
|
57 |
| - return Array(out), ∇__internal_solve_up |
| 65 | + return Array(out), ∇SimpleNonlinearSolve.__internal_solve_up |
58 | 66 | end
|
59 | 67 |
|
60 | 68 | end
|
0 commit comments