@@ -18,7 +18,9 @@ function DI.prepare_pushforward_nokwarg(
18
18
step_der_var = derivative (f (x_var + t_var * dx_var, context_vars... ), t_var)
19
19
pf_var = substitute (step_der_var, Dict (t_var => zero (eltype (x))))
20
20
21
- res = build_function (pf_var, x_var, dx_var, context_vars... ; expression= Val (false ))
21
+ res = build_function (
22
+ pf_var, x_var, dx_var, context_vars... ; expression= Val (false ), cse= true
23
+ )
22
24
(pf_exe, pf_exe!) = if res isa Tuple
23
25
res
24
26
elseif res isa RuntimeGeneratedFunction
@@ -102,7 +104,7 @@ function DI.prepare_derivative_nokwarg(
102
104
context_vars = variablize (contexts)
103
105
der_var = derivative (f (x_var, context_vars... ), x_var)
104
106
105
- res = build_function (der_var, x_var, context_vars... ; expression= Val (false ))
107
+ res = build_function (der_var, x_var, context_vars... ; expression= Val (false ), cse = true )
106
108
(der_exe, der_exe!) = if res isa Tuple
107
109
res
108
110
elseif res isa RuntimeGeneratedFunction
@@ -177,7 +179,9 @@ function DI.prepare_gradient_nokwarg(
177
179
# Symbolic.gradient only accepts vectors
178
180
grad_var = gradient (f (x_var, context_vars... ), vec (x_var))
179
181
180
- res = build_function (grad_var, vec (x_var), context_vars... ; expression= Val (false ))
182
+ res = build_function (
183
+ grad_var, vec (x_var), context_vars... ; expression= Val (false ), cse= true
184
+ )
181
185
(grad_exe, grad_exe!) = res
182
186
return SymbolicsOneArgGradientPrep (_sig, grad_exe, grad_exe!)
183
187
end
@@ -254,7 +258,7 @@ function DI.prepare_jacobian_nokwarg(
254
258
jacobian (f (x_var, context_vars... ), x_var)
255
259
end
256
260
257
- res = build_function (jac_var, x_var, context_vars... ; expression= Val (false ))
261
+ res = build_function (jac_var, x_var, context_vars... ; expression= Val (false ), cse = true )
258
262
(jac_exe, jac_exe!) = res
259
263
return SymbolicsOneArgJacobianPrep (_sig, jac_exe, jac_exe!)
260
264
end
@@ -333,7 +337,9 @@ function DI.prepare_hessian_nokwarg(
333
337
hessian (f (x_var, context_vars... ), vec (x_var))
334
338
end
335
339
336
- res = build_function (hess_var, vec (x_var), context_vars... ; expression= Val (false ))
340
+ res = build_function (
341
+ hess_var, vec (x_var), context_vars... ; expression= Val (false ), cse= true
342
+ )
337
343
(hess_exe, hess_exe!) = res
338
344
339
345
gradient_prep = DI. prepare_gradient_nokwarg (
@@ -420,7 +426,12 @@ function DI.prepare_hvp_nokwarg(
420
426
hvp_vec_var = hess_var * vec (dx_var)
421
427
422
428
res = build_function (
423
- hvp_vec_var, vec (x_var), vec (dx_var), context_vars... ; expression= Val (false )
429
+ hvp_vec_var,
430
+ vec (x_var),
431
+ vec (dx_var),
432
+ context_vars... ;
433
+ expression= Val (false ),
434
+ cse= true ,
424
435
)
425
436
(hvp_exe, hvp_exe!) = res
426
437
@@ -508,7 +519,7 @@ function DI.prepare_second_derivative_nokwarg(
508
519
der_var = derivative (f (x_var, context_vars... ), x_var)
509
520
der2_var = derivative (der_var, x_var)
510
521
511
- res = build_function (der2_var, x_var, context_vars... ; expression= Val (false ))
522
+ res = build_function (der2_var, x_var, context_vars... ; expression= Val (false ), cse = true )
512
523
(der2_exe, der2_exe!) = if res isa Tuple
513
524
res
514
525
elseif res isa RuntimeGeneratedFunction
0 commit comments