Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Commit 283ca63

Browse files
Merge pull request #87 from SciML/matrix_resizing
Add matrix resizing and fix cases with u0 as a matrix
2 parents a82c317 + bae8815 commit 283ca63

File tree

7 files changed

+30
-16
lines changed

7 files changed

+30
-16
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SimpleNonlinearSolve"
22
uuid = "727e6d20-b764-4bd8-a329-72de5adea6c7"
33
authors = ["SciML"]
4-
version = "0.1.22"
4+
version = "0.1.23"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

src/broyden.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,12 +58,12 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden, args...;
5858
xₙ₋₁ = x
5959
fₙ₋₁ = fₙ
6060
for _ in 1:maxiters
61-
xₙ = xₙ₋₁ - J⁻¹ * fₙ₋₁
61+
xₙ = xₙ₋₁ - _restructure(xₙ₋₁, J⁻¹ * _vec(fₙ₋₁))
6262
fₙ = f(xₙ)
6363
Δxₙ = xₙ - xₙ₋₁
6464
Δfₙ = fₙ - fₙ₋₁
65-
J⁻¹Δfₙ = J⁻¹ * Δfₙ
66-
J⁻¹ += ((Δxₙ .- J⁻¹Δfₙ) ./ (Δxₙ' * J⁻¹Δfₙ)) * (Δxₙ' * J⁻¹)
65+
J⁻¹Δfₙ = _restructure(Δfₙ, J⁻¹ * _vec(Δfₙ))
66+
J⁻¹ += _restructure(J⁻¹, ((_vec(Δxₙ) .- _vec(J⁻¹Δfₙ)) ./ (_vec(Δxₙ)' * _vec(J⁻¹Δfₙ))) * (_vec(Δxₙ)' * J⁻¹))
6767

6868
if termination_condition(fₙ, xₙ, xₙ₋₁, atol, rtol)
6969
return SciMLBase.build_solution(prob, alg, xₙ, fₙ; retcode = ReturnCode.Success)

src/klement.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ function SciMLBase.__solve(prob::NonlinearProblem,
7575
F = lu(J, check = false)
7676
end
7777

78-
tmp = F \ fₙ₋₁
78+
tmp = _restructure(fₙ₋₁, F \ _vec(fₙ₋₁))
7979
xₙ = xₙ₋₁ - tmp
8080
fₙ = f(xₙ)
8181

@@ -92,10 +92,10 @@ function SciMLBase.__solve(prob::NonlinearProblem,
9292
Δfₙ = fₙ - fₙ₋₁
9393

9494
# Prevent division by 0
95-
denominator = max.(J' .^ 2 * Δxₙ .^ 2, 1e-9)
95+
denominator = _restructure(Δxₙ, max.(J' .^ 2 * _vec(Δxₙ) .^ 2, 1e-9))
9696

97-
k = (Δfₙ - J * Δxₙ) ./ denominator
98-
J += (k * Δxₙ' .* J) * J
97+
k = (Δfₙ - _restructure(Δxₙ, J * _vec(Δxₙ))) ./ denominator
98+
J += (_vec(k) * _vec(Δxₙ)' .* J) * J
9999

100100
xₙ₋₁ = xₙ
101101
fₙ₋₁ = fₙ

src/raphson.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ function SciMLBase.__solve(prob::NonlinearProblem,
100100
end
101101
iszero(fx) &&
102102
return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.Success)
103-
Δx = dfx \ fx
103+
Δx = _restructure(fx, dfx \ _vec(fx))
104104
x -= Δx
105105
if isapprox(x, xo, atol = atol, rtol = rtol)
106106
return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.Success)

src/utils.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,3 +82,10 @@ function dogleg_method(H, g, Δ)
8282
tau = (-dot_δsd_δN_δsd + sqrt(fact)) / dot_δN_δsd
8383
return δsd + tau * δN_δsd
8484
end
85+
86+
@inline _vec(v) = vec(v)
87+
@inline _vec(v::Number) = v
88+
@inline _vec(v::AbstractVector) = v
89+
90+
@inline _restructure(y::Number, x::Number) = x
91+
@inline _restructure(y, x) = ArrayInterface.restructure(y,x)

test/matrix_resizing_tests.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
using SimpleNonlinearSolve
2+
3+
ff(u, p) = u .* u .- p
4+
u0 = rand(2,2)
5+
p = 2.0
6+
vecprob = NonlinearProblem(ff, vec(u0), p)
7+
prob = NonlinearProblem(ff, u0, p)
8+
9+
for alg in (Klement(), Broyden(), SimpleNewtonRaphson())
10+
@test vec(solve(prob, alg).u) == solve(vecprob, alg).u
11+
end

test/runtests.jl

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,8 @@ const GROUP = get(ENV, "GROUP", "All")
44

55
@time begin
66
if GROUP == "All" || GROUP == "Core"
7-
@time @safetestset "Basic Tests + Some AD" begin
8-
include("basictests.jl")
9-
end
10-
11-
@time @safetestset "Inplace Tests" begin
12-
include("inplace.jl")
13-
end
7+
@time @safetestset "Basic Tests + Some AD" include("basictests.jl")
8+
@time @safetestset "Inplace Tests" include("inplace.jl")
9+
@time @safetestset "Matrix Resizing Tests" include("matrix_resizing_tests.jl")
1410
end
1511
end

0 commit comments

Comments
 (0)