Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Commit 83280b6

Browse files
Merge pull request #64 from utkarsh530/u/jac
Add support for custom jacobians
2 parents 976758f + bc57490 commit 83280b6

File tree

3 files changed

+25
-2
lines changed

3 files changed

+25
-2
lines changed

src/raphson.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,10 @@ function SciMLBase.__solve(prob::NonlinearProblem,
6060
end
6161

6262
for i in 1:maxiters
63-
if alg_autodiff(alg)
63+
if DiffEqBase.has_jac(prob.f)
64+
dfx = prob.f.jac(x, prob.p)
65+
fx = f(x)
66+
elseif alg_autodiff(alg)
6467
fx, dfx = value_derivative(f, x)
6568
elseif x isa AbstractArray
6669
fx = f(x)

src/trustRegion.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,11 @@ function SciMLBase.__solve(prob::NonlinearProblem,
112112
real(oneunit(eltype(T))) * (eps(real(one(eltype(T)))))^(4 // 5)
113113
rtol = reltol !== nothing ? reltol : eps(real(one(eltype(T))))^(4 // 5)
114114

115-
if alg_autodiff(alg)
115+
116+
if DiffEqBase.has_jac(prob.f)
117+
∇f = prob.f.jac(x, prob.p)
118+
F = f(x)
119+
elseif alg_autodiff(alg)
116120
F, ∇f = value_derivative(f, x)
117121
elseif x isa AbstractArray
118122
F = f(x)

test/basictests.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ using SimpleNonlinearSolve
22
using StaticArrays
33
using BenchmarkTools
44
using DiffEqBase
5+
using LinearAlgebra
56
using Test
67

78
const BATCHED_BROYDEN_SOLVERS = Broyden[]
@@ -489,3 +490,18 @@ for alg in (BATCHED_BROYDEN_SOLVERS..., BATCHED_LBROYDEN_SOLVERS...)
489490
@test sol.retcode == ReturnCode.Success
490491
@test abs.(sol.u) sqrt.(p)
491492
end
493+
494+
## User specified Jacobian
495+
496+
f, u0 = (u, p) -> u .* u .- p, randn(3)
497+
498+
f_jac(u, p) = begin diagm(2 * u) end
499+
500+
p = [2.0, 1.0, 5.0];
501+
502+
probN = NonlinearProblem(NonlinearFunction(f, jac = f_jac), u0, p)
503+
504+
for alg in (SimpleNewtonRaphson(), SimpleTrustRegion())
505+
sol = solve(probN, alg)
506+
@test abs.(sol.u) sqrt.(p)
507+
end

0 commit comments

Comments
 (0)