Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Commit 1e4eb4b

Browse files
Merge branch 'main' into ys/itp
2 parents 780480d + f20f5e4 commit 1e4eb4b

15 files changed

+583
-238
lines changed

Project.toml

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,38 @@
11
name = "SimpleNonlinearSolve"
22
uuid = "727e6d20-b764-4bd8-a329-72de5adea6c7"
33
authors = ["SciML"]
4-
version = "0.1.16"
4+
version = "0.1.17"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
88
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
99
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
1010
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1111
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
12+
PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
13+
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
1214
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
13-
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1415
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
15-
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
1616
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
1717

1818
[weakdeps]
1919
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
2020

2121
[extensions]
22-
SimpleBatchedNonlinearSolveExt = "NNlib"
22+
SimpleNonlinearSolveNNlibExt = "NNlib"
2323

2424
[compat]
2525
ArrayInterface = "6, 7"
26-
DiffEqBase = "6.123.0"
26+
DiffEqBase = "6.126"
2727
FiniteDiff = "2"
2828
ForwardDiff = "0.10.3"
29-
NNlib = "0.8"
29+
NNlib = "0.8, 0.9"
30+
PackageExtensionCompat = "1"
31+
PrecompileTools = "1"
3032
Reexport = "0.2, 1"
31-
Requires = "1"
3233
SciMLBase = "1.73"
33-
PrecompileTools = "1"
3434
StaticArraysCore = "1.4"
3535
julia = "1.6"
3636

3737
[extras]
38-
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
3938
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
40-
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
41-
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
42-
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
43-
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
44-
45-
[targets]
46-
test = ["BenchmarkTools", "SafeTestsets", "Pkg", "Test", "StaticArrays", "NNlib"]

ext/SimpleBatchedNonlinearSolveExt.jl

Lines changed: 0 additions & 90 deletions
This file was deleted.

ext/SimpleNonlinearSolveNNlibExt.jl

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
module SimpleNonlinearSolveNNlibExt
2+
3+
using ArrayInterface, DiffEqBase, LinearAlgebra, NNlib, SimpleNonlinearSolve, SciMLBase
4+
import SimpleNonlinearSolve: _construct_batched_problem_structure,
5+
_get_storage, _init_𝓙, _result_from_storage, _get_tolerance, @maybeinplace
6+
7+
function __init__()
8+
SimpleNonlinearSolve.NNlibExtLoaded[] = true
9+
return
10+
end
11+
12+
@views function SciMLBase.__solve(prob::NonlinearProblem,
13+
alg::BatchedBroyden;
14+
abstol = nothing,
15+
reltol = nothing,
16+
maxiters = 1000,
17+
kwargs...)
18+
iip = isinplace(prob)
19+
20+
u, f, reconstruct = _construct_batched_problem_structure(prob)
21+
L, N = size(u)
22+
23+
tc = alg.termination_condition
24+
mode = DiffEqBase.get_termination_mode(tc)
25+
26+
storage = _get_storage(mode, u)
27+
28+
xₙ, xₙ₋₁, δx, δf = ntuple(_ -> copy(u), 4)
29+
T = eltype(u)
30+
31+
atol = _get_tolerance(abstol, tc.abstol, T)
32+
rtol = _get_tolerance(reltol, tc.reltol, T)
33+
termination_condition = tc(storage)
34+
35+
𝓙⁻¹ = _init_𝓙(xₙ) # L × L × N
36+
𝓙⁻¹f, xᵀ𝓙⁻¹δf, xᵀ𝓙⁻¹ = similar(𝓙⁻¹, L, N), similar(𝓙⁻¹, 1, N), similar(𝓙⁻¹, 1, L, N)
37+
38+
@maybeinplace iip fₙ₋₁=f(xₙ) u
39+
iip && (fₙ = copy(fₙ₋₁))
40+
for n in 1:maxiters
41+
batched_mul!(reshape(𝓙⁻¹f, L, 1, N), 𝓙⁻¹, reshape(fₙ₋₁, L, 1, N))
42+
xₙ .= xₙ₋₁ .- 𝓙⁻¹f
43+
44+
@maybeinplace iip fₙ=f(xₙ)
45+
δx .= xₙ .- xₙ₋₁
46+
δf .= fₙ .- fₙ₋₁
47+
48+
batched_mul!(reshape(𝓙⁻¹f, L, 1, N), 𝓙⁻¹, reshape(δf, L, 1, N))
49+
δxᵀ = reshape(δx, 1, L, N)
50+
51+
batched_mul!(reshape(xᵀ𝓙⁻¹δf, 1, 1, N), δxᵀ, reshape(𝓙⁻¹f, L, 1, N))
52+
batched_mul!(xᵀ𝓙⁻¹, δxᵀ, 𝓙⁻¹)
53+
δx .= (δx .- 𝓙⁻¹f) ./ (xᵀ𝓙⁻¹δf .+ T(1e-5))
54+
batched_mul!(𝓙⁻¹, reshape(δx, L, 1, N), xᵀ𝓙⁻¹, one(T), one(T))
55+
56+
if termination_condition(fₙ, xₙ, xₙ₋₁, atol, rtol)
57+
retcode, xₙ, fₙ = _result_from_storage(storage, xₙ, fₙ, f, mode, iip)
58+
return DiffEqBase.build_solution(prob,
59+
alg,
60+
reconstruct(xₙ),
61+
reconstruct(fₙ);
62+
retcode)
63+
end
64+
65+
xₙ₋₁ .= xₙ
66+
fₙ₋₁ .= fₙ
67+
end
68+
69+
if mode DiffEqBase.SAFE_BEST_TERMINATION_MODES
70+
xₙ = storage.u
71+
@maybeinplace iip fₙ=f(xₙ)
72+
end
73+
74+
return DiffEqBase.build_solution(prob,
75+
alg,
76+
reconstruct(xₙ),
77+
reconstruct(fₙ);
78+
retcode = ReturnCode.MaxIters)
79+
end
80+
81+
end

src/SimpleNonlinearSolve.jl

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,22 +10,19 @@ using DiffEqBase
1010

1111
@reexport using SciMLBase
1212

13-
if !isdefined(Base, :get_extension)
14-
using Requires
15-
end
16-
13+
using PackageExtensionCompat
1714
function __init__()
18-
@static if !isdefined(Base, :get_extension)
19-
@require NNlib="872c559c-99b0-510c-b3b7-b6c96a88d5cd" begin
20-
include("../ext/SimpleBatchedNonlinearSolveExt.jl")
21-
end
22-
end
15+
@require_extensions
2316
end
2417

18+
const NNlibExtLoaded = Ref{Bool}(false)
19+
2520
abstract type AbstractSimpleNonlinearSolveAlgorithm <: SciMLBase.AbstractNonlinearAlgorithm end
2621
abstract type AbstractBracketingAlgorithm <: AbstractSimpleNonlinearSolveAlgorithm end
2722
abstract type AbstractNewtonAlgorithm{CS, AD, FDT} <: AbstractSimpleNonlinearSolveAlgorithm end
2823
abstract type AbstractImmutableNonlinearSolver <: AbstractSimpleNonlinearSolveAlgorithm end
24+
abstract type AbstractBatchedNonlinearSolveAlgorithm <:
25+
AbstractSimpleNonlinearSolveAlgorithm end
2926

3027
include("utils.jl")
3128
include("bisection.jl")
@@ -43,6 +40,12 @@ include("halley.jl")
4340
include("alefeld.jl")
4441
include("itp.jl")
4542

43+
# Batched Solver Support
44+
include("batched/utils.jl")
45+
include("batched/raphson.jl")
46+
include("batched/dfsane.jl")
47+
include("batched/broyden.jl")
48+
4649
import PrecompileTools
4750

4851
PrecompileTools.@compile_workload begin
@@ -75,5 +78,6 @@ end
7578
# DiffEq styled algorithms
7679
export Bisection, Brent, Broyden, LBroyden, SimpleDFSane, Falsi, Halley, Klement,
7780
Ridder, SimpleNewtonRaphson, SimpleTrustRegion, Alefeld, Itp
81+
export BatchedBroyden, BatchedSimpleNewtonRaphson, BatchedSimpleDFSane
7882

7983
end # module

src/batched/broyden.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
struct BatchedBroyden{TC <: NLSolveTerminationCondition} <:
2+
AbstractBatchedNonlinearSolveAlgorithm
3+
termination_condition::TC
4+
end
5+
6+
# Implementation of solve using Package Extensions

0 commit comments

Comments
 (0)