Skip to content

Commit e93d16a

Browse files
gerlerogdalle
andauthored
Add dual number–based second derivatives for ForwardDiff (#310)
* Add dual number-based implementations for ForwardDiff scalar derivatives * Use functions from utils.jl * Fixup * Reorder * Add preparation to see error * Add derivative method * Add derivative! and value_and_derivative! methods * Fixup * Drop derivative and derivative! * Add prepare_second_derivative and value_derivative_and_second_derivative! * Fixup * Drop derivative stuff * Add second_derivative and second_derivative! methods * Fixup * Fixup 2 --------- Co-authored-by: Guillaume Dalle <22795598+gdalle@users.noreply.github.com>
1 parent a997154 commit e93d16a

File tree

2 files changed

+53
-0
lines changed

2 files changed

+53
-0
lines changed

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ using DifferentiationInterface:
1010
HessianExtras,
1111
JacobianExtras,
1212
NoDerivativeExtras,
13+
NoSecondDerivativeExtras,
1314
PushforwardExtras
1415
using ForwardDiff.DiffResults: DiffResults, DiffResult, GradientResult, MutableDiffResult
1516
using ForwardDiff:

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,58 @@ function DI.pushforward!(
6161
return dy
6262
end
6363

64+
## Second derivative
65+
66+
function DI.prepare_second_derivative(f::F, backend::AutoForwardDiff, x) where {F}
67+
return NoSecondDerivativeExtras()
68+
end
69+
70+
function DI.second_derivative(
71+
f::F, backend::AutoForwardDiff, x, ::NoSecondDerivativeExtras
72+
) where {F}
73+
T = tag_type(f, backend, x)
74+
xdual = make_dual(T, x, one(x))
75+
T2 = tag_type(f, backend, xdual)
76+
ydual = f(make_dual(T2, xdual, one(xdual)))
77+
return myderivative(T, myderivative(T2, ydual))
78+
end
79+
80+
function DI.second_derivative!(
81+
f::F, der2, backend::AutoForwardDiff, x, ::NoSecondDerivativeExtras
82+
) where {F}
83+
T = tag_type(f, backend, x)
84+
xdual = make_dual(T, x, one(x))
85+
T2 = tag_type(f, backend, xdual)
86+
ydual = f(make_dual(T2, xdual, one(xdual)))
87+
return myderivative!(T, der2, myderivative(T2, ydual))
88+
end
89+
90+
function DI.value_derivative_and_second_derivative(
91+
f::F, backend::AutoForwardDiff, x, ::NoSecondDerivativeExtras
92+
) where {F}
93+
T = tag_type(f, backend, x)
94+
xdual = make_dual(T, x, one(x))
95+
T2 = tag_type(f, backend, xdual)
96+
ydual = f(make_dual(T2, xdual, one(xdual)))
97+
y = myvalue(T, myvalue(T2, ydual))
98+
der = myderivative(T, myvalue(T2, ydual))
99+
der2 = myderivative(T, myderivative(T2, ydual))
100+
return y, der, der2
101+
end
102+
103+
function DI.value_derivative_and_second_derivative!(
104+
f::F, der, der2, backend::AutoForwardDiff, x, ::NoSecondDerivativeExtras
105+
) where {F}
106+
T = tag_type(f, backend, x)
107+
xdual = make_dual(T, x, one(x))
108+
T2 = tag_type(f, backend, xdual)
109+
ydual = f(make_dual(T2, xdual, one(xdual)))
110+
y = myvalue(T, myvalue(T2, ydual))
111+
myderivative!(T, der, myvalue(T2, ydual))
112+
myderivative!(T, der2, myderivative(T2, ydual))
113+
return y, der, der2
114+
end
115+
64116
## Gradient
65117

66118
struct ForwardDiffGradientExtras{C} <: GradientExtras

0 commit comments

Comments
 (0)