Skip to content

Commit 13f1859

Browse files
baggepinnengdalle
andauthored
perf: use common subexpression elimination in AutoSymbolics (#759)
* use cse in AutoSymbolics closes #758 * Format --------- Co-authored-by: Guillaume Dalle <22795598+gdalle@users.noreply.github.com>
1 parent fca02b8 commit 13f1859

File tree

2 files changed

+23
-10
lines changed

2 files changed

+23
-10
lines changed

DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@ function DI.prepare_pushforward_nokwarg(
1818
step_der_var = derivative(f(x_var + t_var * dx_var, context_vars...), t_var)
1919
pf_var = substitute(step_der_var, Dict(t_var => zero(eltype(x))))
2020

21-
res = build_function(pf_var, x_var, dx_var, context_vars...; expression=Val(false))
21+
res = build_function(
22+
pf_var, x_var, dx_var, context_vars...; expression=Val(false), cse=true
23+
)
2224
(pf_exe, pf_exe!) = if res isa Tuple
2325
res
2426
elseif res isa RuntimeGeneratedFunction
@@ -102,7 +104,7 @@ function DI.prepare_derivative_nokwarg(
102104
context_vars = variablize(contexts)
103105
der_var = derivative(f(x_var, context_vars...), x_var)
104106

105-
res = build_function(der_var, x_var, context_vars...; expression=Val(false))
107+
res = build_function(der_var, x_var, context_vars...; expression=Val(false), cse=true)
106108
(der_exe, der_exe!) = if res isa Tuple
107109
res
108110
elseif res isa RuntimeGeneratedFunction
@@ -177,7 +179,9 @@ function DI.prepare_gradient_nokwarg(
177179
# Symbolic.gradient only accepts vectors
178180
grad_var = gradient(f(x_var, context_vars...), vec(x_var))
179181

180-
res = build_function(grad_var, vec(x_var), context_vars...; expression=Val(false))
182+
res = build_function(
183+
grad_var, vec(x_var), context_vars...; expression=Val(false), cse=true
184+
)
181185
(grad_exe, grad_exe!) = res
182186
return SymbolicsOneArgGradientPrep(_sig, grad_exe, grad_exe!)
183187
end
@@ -254,7 +258,7 @@ function DI.prepare_jacobian_nokwarg(
254258
jacobian(f(x_var, context_vars...), x_var)
255259
end
256260

257-
res = build_function(jac_var, x_var, context_vars...; expression=Val(false))
261+
res = build_function(jac_var, x_var, context_vars...; expression=Val(false), cse=true)
258262
(jac_exe, jac_exe!) = res
259263
return SymbolicsOneArgJacobianPrep(_sig, jac_exe, jac_exe!)
260264
end
@@ -333,7 +337,9 @@ function DI.prepare_hessian_nokwarg(
333337
hessian(f(x_var, context_vars...), vec(x_var))
334338
end
335339

336-
res = build_function(hess_var, vec(x_var), context_vars...; expression=Val(false))
340+
res = build_function(
341+
hess_var, vec(x_var), context_vars...; expression=Val(false), cse=true
342+
)
337343
(hess_exe, hess_exe!) = res
338344

339345
gradient_prep = DI.prepare_gradient_nokwarg(
@@ -420,7 +426,12 @@ function DI.prepare_hvp_nokwarg(
420426
hvp_vec_var = hess_var * vec(dx_var)
421427

422428
res = build_function(
423-
hvp_vec_var, vec(x_var), vec(dx_var), context_vars...; expression=Val(false)
429+
hvp_vec_var,
430+
vec(x_var),
431+
vec(dx_var),
432+
context_vars...;
433+
expression=Val(false),
434+
cse=true,
424435
)
425436
(hvp_exe, hvp_exe!) = res
426437

@@ -508,7 +519,7 @@ function DI.prepare_second_derivative_nokwarg(
508519
der_var = derivative(f(x_var, context_vars...), x_var)
509520
der2_var = derivative(der_var, x_var)
510521

511-
res = build_function(der2_var, x_var, context_vars...; expression=Val(false))
522+
res = build_function(der2_var, x_var, context_vars...; expression=Val(false), cse=true)
512523
(der2_exe, der2_exe!) = if res isa Tuple
513524
res
514525
elseif res isa RuntimeGeneratedFunction

DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@ function DI.prepare_pushforward_nokwarg(
2626
step_der_var = derivative(y_var, t_var)
2727
pf_var = substitute(step_der_var, Dict(t_var => zero(eltype(x))))
2828

29-
res = build_function(pf_var, x_var, dx_var, context_vars...; expression=Val(false))
29+
res = build_function(
30+
pf_var, x_var, dx_var, context_vars...; expression=Val(false), cse=true
31+
)
3032
(pushforward_exe, pushforward_exe!) = res
3133
return SymbolicsTwoArgPushforwardPrep(_sig, pushforward_exe, pushforward_exe!)
3234
end
@@ -114,7 +116,7 @@ function DI.prepare_derivative_nokwarg(
114116
f!(y_var, x_var, context_vars...)
115117
der_var = derivative(y_var, x_var)
116118

117-
res = build_function(der_var, x_var, context_vars...; expression=Val(false))
119+
res = build_function(der_var, x_var, context_vars...; expression=Val(false), cse=true)
118120
(der_exe, der_exe!) = res
119121
return SymbolicsTwoArgDerivativePrep(_sig, der_exe, der_exe!)
120122
end
@@ -201,7 +203,7 @@ function DI.prepare_jacobian_nokwarg(
201203
jacobian(y_var, x_var)
202204
end
203205

204-
res = build_function(jac_var, x_var, context_vars...; expression=Val(false))
206+
res = build_function(jac_var, x_var, context_vars...; expression=Val(false), cse=true)
205207
(jac_exe, jac_exe!) = res
206208
return SymbolicsTwoArgJacobianPrep(_sig, jac_exe, jac_exe!)
207209
end

0 commit comments

Comments
 (0)