Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Commit 3422e1c

Browse files
Merge pull request #76 from avik-pal/ap/inplace_raphson
Add support for inplace BatchedSimpleNewtonRaphson
2 parents cad98b6 + 0576273 commit 3422e1c

File tree

4 files changed

+43
-15
lines changed

4 files changed

+43
-15
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SimpleNonlinearSolve"
22
uuid = "727e6d20-b764-4bd8-a329-72de5adea6c7"
33
authors = ["SciML"]
4-
version = "0.1.18"
4+
version = "0.1.19"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

src/batched/raphson.jl

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ end
2020
function SciMLBase.__solve(prob::NonlinearProblem, alg::BatchedSimpleNewtonRaphson;
2121
abstol = nothing, reltol = nothing, maxiters = 1000, kwargs...)
2222
iip = SciMLBase.isinplace(prob)
23-
@assert !iip "BatchedSimpleNewtonRaphson currently only supports out-of-place nonlinear problems."
23+
iip &&
24+
@assert alg_autodiff(alg) "Inplace BatchedSimpleNewtonRaphson currently only supports autodiff."
2425
u, f, reconstruct = _construct_batched_problem_structure(prob)
2526

2627
tc = alg.termination_condition
@@ -35,12 +36,26 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::BatchedSimpleNewtonRaphs
3536
rtol = _get_tolerance(reltol, tc.reltol, T)
3637
termination_condition = tc(storage)
3738

39+
if iip
40+
𝓙 = similar(xₙ, length(xₙ), length(xₙ))
41+
fₙ = similar(xₙ)
42+
jac_cfg = ForwardDiff.JacobianConfig(f, fₙ, xₙ)
43+
end
44+
3845
for i in 1:maxiters
39-
if alg_autodiff(alg)
40-
fₙ, 𝓙 = value_derivative(f, xₙ)
46+
if iip
47+
value_derivative!(𝓙, fₙ, f, xₙ, jac_cfg)
4148
else
42-
fₙ = f(xₙ)
43-
𝓙 = FiniteDiff.finite_difference_jacobian(f, xₙ, diff_type(alg), eltype(xₙ), fₙ)
49+
if alg_autodiff(alg)
50+
fₙ, 𝓙 = value_derivative(f, xₙ)
51+
else
52+
fₙ = f(xₙ)
53+
𝓙 = FiniteDiff.finite_difference_jacobian(f,
54+
xₙ,
55+
diff_type(alg),
56+
eltype(xₙ),
57+
fₙ)
58+
end
4459
end
4560

4661
iszero(fₙ) && return DiffEqBase.build_solution(prob,
@@ -66,7 +81,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::BatchedSimpleNewtonRaphs
6681

6782
if mode DiffEqBase.SAFE_BEST_TERMINATION_MODES
6883
xₙ = storage.u
69-
fₙ = f(xₙ)
84+
@maybeinplace iip fₙ=f(xₙ)
7085
end
7186

7287
return DiffEqBase.build_solution(prob,

src/utils.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,20 @@ function value_derivative(f::F, x::R) where {F, R}
3030
end
3131
value_derivative(f::F, x::AbstractArray) where {F} = f(x), ForwardDiff.jacobian(f, x)
3232

33+
"""
34+
value_derivative!(J, y, f!, x, cfg = JacobianConfig(f!, y, x))
35+
36+
Inplace version of [`SimpleNonlinearSolve.value_derivative`](@ref).
37+
"""
38+
function value_derivative!(J::AbstractMatrix,
39+
y::AbstractArray,
40+
f!::F,
41+
x::AbstractArray,
42+
cfg::ForwardDiff.JacobianConfig = ForwardDiff.JacobianConfig(f!, y, x)) where {F}
43+
ForwardDiff.jacobian!(J, f!, y, x, cfg)
44+
return y, J
45+
end
46+
3347
value(x) = x
3448
value(x::Dual) = ForwardDiff.value(x)
3549
value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x)

test/inplace.jl

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,29 +2,28 @@ using SimpleNonlinearSolve,
22
StaticArrays, BenchmarkTools, DiffEqBase, LinearAlgebra, Test,
33
NNlib
44

5-
# Supported Solvers: BatchedBroyden, BatchedSimpleDFSane
5+
# Supported Solvers: BatchedBroyden, BatchedSimpleDFSane, BatchedSimpleNewtonRaphson
66
function f!(du::AbstractArray{<:Number, N},
77
u::AbstractArray{<:Number, N},
88
p::AbstractVector) where {N}
99
u_ = reshape(u, :, size(u, N))
10-
du .= reshape(sum(abs2, u_; dims = 1) .- reshape(p, 1, :),
11-
ntuple(_ -> 1, N - 1)...,
12-
size(u, N))
10+
du .= reshape(sum(abs2, u_; dims = 1) .- u_ .- reshape(p, 1, :), size(u))
1311
return du
1412
end
1513

1614
function f!(du::AbstractMatrix, u::AbstractMatrix, p::AbstractVector)
17-
du .= sum(abs2, u; dims = 1) .- reshape(p, 1, :)
15+
du .= sum(abs2, u; dims = 1) .- u .- reshape(p, 1, :)
1816
return du
1917
end
2018

2119
function f!(du::AbstractVector, u::AbstractVector, p::AbstractVector)
22-
du .= sum(abs2, u) .- p
20+
du .= sum(abs2, u) .- u .- p
2321
return du
2422
end
2523

26-
@testset "Solver: $(nameof(typeof(solver)))" for solver in (Broyden(batched = true),
27-
SimpleDFSane(batched = true))
24+
@testset "Solver: $(nameof(typeof(solver)))" for solver in (Broyden(; batched = true),
25+
SimpleDFSane(; batched = true),
26+
SimpleNewtonRaphson(; batched = true))
2827
@testset "T: $T" for T in (Float32, Float64)
2928
p = rand(T, 5)
3029
@testset "size(u0): $sz" for sz in ((2, 5), (1, 5), (2, 3, 5))

0 commit comments

Comments
 (0)