Skip to content

Commit fad13e6

Browse files
committed
Update
1 parent ed9626f commit fad13e6

File tree

3 files changed

+24
-10
lines changed

3 files changed

+24
-10
lines changed

docs/src/submodules/Nonlinear/SymbolicAD.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ julia> f = MOI.ScalarNonlinearFunction(:sin, Any[x])
185185
sin(MOI.VariableIndex(1))
186186
187187
julia> MOI.Nonlinear.SymbolicAD.derivative(f, x)
188-
*(cos(MOI.VariableIndex(1)), (true))
188+
cos(MOI.VariableIndex(1)
189189
```
190190

191191
Note that the resultant expression can often be simplified. Thus, in most cases
@@ -196,14 +196,14 @@ using it in other places:
196196
julia> x = MOI.VariableIndex(1)
197197
MOI.VariableIndex(1)
198198
199-
julia> f = MOI.ScalarNonlinearFunction(:sin, Any[x])
200-
sin(MOI.VariableIndex(1))
199+
julia> f = MOI.ScalarNonlinearFunction(:sin, Any[x + 1.0])
200+
sin(1.0 + 1.0 MOI.VariableIndex(1))
201201
202202
julia> df_dx = MOI.Nonlinear.SymbolicAD.derivative(f, x)
203-
*(cos(MOI.VariableIndex(1)), (true))
203+
*(cos(1.0 + 1.0 MOI.VariableIndex(1)), 1.0)
204204
205205
julia> MOI.Nonlinear.SymbolicAD.simplify!(df_dx)
206-
cos(MOI.VariableIndex(1))
206+
cos(1.0 + 1.0 MOI.VariableIndex(1))
207207
```
208208

209209
## `gradient_and_hessian`

src/Nonlinear/SymbolicAD/SymbolicAD.jl

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -558,6 +558,14 @@ that the user would never write themselves.
558558
"""
559559
const __DERIVATIVE__ = "__DERIVATIVE__"
560560

561+
# This function helps simplify df_du * du_dx in the commonn case that `du_dx`
562+
# is `true` (when u = x), or `false` (when x ∉ u).
563+
function _univariate_chain_rule(df_du, du_dx)
564+
return MOI.ScalarNonlinearFunction(:*, Any[df_du, du_dx])
565+
end
566+
567+
_univariate_chain_rule(df_du, du_dx::Bool) = ifelse(du_dx, df_du, du_dx)
568+
561569
function derivative(f::MOI.ScalarNonlinearFunction, x::MOI.VariableIndex)
562570
if length(f.args) == 1
563571
u = only(f.args)
@@ -571,28 +579,28 @@ function derivative(f::MOI.ScalarNonlinearFunction, x::MOI.VariableIndex)
571579
:ifelse,
572580
Any[MOI.ScalarNonlinearFunction(:>=, Any[u, 0]), 1, -1],
573581
)
574-
return MOI.ScalarNonlinearFunction(:*, Any[df_du, du_dx])
582+
return _univariate_chain_rule(df_du, du_dx)
575583
elseif f.head == :sign
576584
return false
577585
elseif f.head == :deg2rad
578586
df_du = deg2rad(1)
579-
return MOI.ScalarNonlinearFunction(:*, Any[df_du, du_dx])
587+
return _univariate_chain_rule(df_du, du_dx)
580588
elseif f.head == :rad2deg
581589
df_du = rad2deg(1)
582-
return MOI.ScalarNonlinearFunction(:*, Any[df_du, du_dx])
590+
return _univariate_chain_rule(df_du, du_dx)
583591
end
584592
for (key, df, _) in MOI.Nonlinear.SYMBOLIC_UNIVARIATE_EXPRESSIONS
585593
if key == f.head
586594
# The chain rule: d(f(g(x))) / dx = f'(g(x)) * g'(x)
587595
df_du = _replace_expression(copy(df), u)
588-
return MOI.ScalarNonlinearFunction(:*, Any[df_du, du_dx])
596+
return _univariate_chain_rule(df_du, du_dx)
589597
end
590598
end
591599
# Delay derivative until evaluation. This may result in a later
592600
# UnsupportedNonlinearOperator error, but we can't tell just yet.
593601
d_op = Symbol(__DERIVATIVE__ * "$(f.head)")
594602
df_du = MOI.ScalarNonlinearFunction(d_op, Any[u])
595-
return MOI.ScalarNonlinearFunction(:*, Any[df_du, du_dx])
603+
return _univariate_chain_rule(df_du, du_dx)
596604
end
597605
if f.head == :+
598606
# d/dx(+(args...)) = +(d/dx args)

test/Nonlinear/SymbolicAD.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,12 @@ function test_derivative()
134134
return
135135
end
136136

137+
function test_derivative_univariate_simplification()
138+
x = MOI.VariableIndex(1)
139+
@test SymbolicAD.derivative(op(:sin, x), x) op(:cos, x)
140+
return
141+
end
142+
137143
function test_derivative_error()
138144
x = MOI.VariableIndex(1)
139145
f = MOI.ScalarNonlinearFunction(:foo, Any[x, x])

0 commit comments

Comments
 (0)