@@ -12,7 +12,7 @@ function DI.prepare_pushforward_nokwarg(
12
12
strict:: Val , f, backend:: AutoFiniteDiff , x, tx:: NTuple , contexts:: Vararg{DI.Context,C} ;
13
13
) where {C}
14
14
_sig = DI. signature (f, backend, x, tx, contexts... ; strict)
15
- fc = DI. with_contexts (f, contexts... )
15
+ fc = DI. fix_tail (f, map (DI . unwrap, contexts) ... )
16
16
y = fc (x)
17
17
cache = if x isa Number || y isa Number
18
18
nothing
@@ -89,7 +89,7 @@ function DI.pushforward(
89
89
) where {SIG,C}
90
90
DI. check_prep (f, prep, backend, x, tx, contexts... )
91
91
(; relstep, absstep, dir) = prep
92
- fc = DI. with_contexts (f, contexts... )
92
+ fc = DI. fix_tail (f, map (DI . unwrap, contexts) ... )
93
93
ty = map (tx) do dx
94
94
finite_difference_jvp (fc, x, dx, prep. cache; relstep, absstep, dir)
95
95
end
@@ -106,7 +106,7 @@ function DI.value_and_pushforward(
106
106
) where {SIG,C}
107
107
DI. check_prep (f, prep, backend, x, tx, contexts... )
108
108
(; relstep, absstep, dir) = prep
109
- fc = DI. with_contexts (f, contexts... )
109
+ fc = DI. fix_tail (f, map (DI . unwrap, contexts) ... )
110
110
y = fc (x)
111
111
ty = map (tx) do dx
112
112
finite_difference_jvp (fc, x, dx, prep. cache, y; relstep, absstep, dir)
@@ -128,7 +128,7 @@ function DI.prepare_derivative_nokwarg(
128
128
strict:: Val , f, backend:: AutoFiniteDiff , x, contexts:: Vararg{DI.Context,C}
129
129
) where {C}
130
130
_sig = DI. signature (f, backend, x, contexts... ; strict)
131
- fc = DI. with_contexts (f, contexts... )
131
+ fc = DI. fix_tail (f, map (DI . unwrap, contexts) ... )
132
132
y = fc (x)
133
133
cache = if y isa Number
134
134
nothing
@@ -161,7 +161,7 @@ function DI.derivative(
161
161
) where {SIG,C}
162
162
DI. check_prep (f, prep, backend, x, contexts... )
163
163
(; relstep, absstep, dir) = prep
164
- fc = DI. with_contexts (f, contexts... )
164
+ fc = DI. fix_tail (f, map (DI . unwrap, contexts) ... )
165
165
return finite_difference_derivative (fc, x, fdtype (backend); relstep, absstep, dir)
166
166
end
167
167
@@ -174,7 +174,7 @@ function DI.value_and_derivative(
174
174
) where {SIG,C}
175
175
DI. check_prep (f, prep, backend, x, contexts... )
176
176
(; relstep, absstep, dir) = prep
177
- fc = DI. with_contexts (f, contexts... )
177
+ fc = DI. fix_tail (f, map (DI . unwrap, contexts) ... )
178
178
y = fc (x)
179
179
return (
180
180
y,
@@ -195,7 +195,7 @@ function DI.derivative(
195
195
) where {SIG,C}
196
196
DI. check_prep (f, prep, backend, x, contexts... )
197
197
(; relstep, absstep, dir) = prep
198
- fc = DI. with_contexts (f, contexts... )
198
+ fc = DI. fix_tail (f, map (DI . unwrap, contexts) ... )
199
199
return finite_difference_gradient (fc, x, prep. cache; relstep, absstep, dir)
200
200
end
201
201
@@ -209,7 +209,7 @@ function DI.derivative!(
209
209
) where {SIG,C}
210
210
DI. check_prep (f, prep, backend, x, contexts... )
211
211
(; relstep, absstep, dir) = prep
212
- fc = DI. with_contexts (f, contexts... )
212
+ fc = DI. fix_tail (f, map (DI . unwrap, contexts) ... )
213
213
return finite_difference_gradient! (der, fc, x, prep. cache; relstep, absstep, dir)
214
214
end
215
215
@@ -221,7 +221,7 @@ function DI.value_and_derivative(
221
221
contexts:: Vararg{DI.Context,C} ,
222
222
) where {SIG,C}
223
223
DI. check_prep (f, prep, backend, x, contexts... )
224
- fc = DI. with_contexts (f, contexts... )
224
+ fc = DI. fix_tail (f, map (DI . unwrap, contexts) ... )
225
225
(; relstep, absstep, dir) = prep
226
226
y = fc (x)
227
227
return (y, finite_difference_gradient (fc, x, prep. cache; relstep, absstep, dir))
@@ -237,7 +237,7 @@ function DI.value_and_derivative!(
237
237
) where {SIG,C}
238
238
DI. check_prep (f, prep, backend, x, contexts... )
239
239
(; relstep, absstep, dir) = prep
240
- fc = DI. with_contexts (f, contexts... )
240
+ fc = DI. fix_tail (f, map (DI . unwrap, contexts) ... )
241
241
return (
242
242
fc (x), finite_difference_gradient! (der, fc, x, prep. cache; relstep, absstep, dir)
243
243
)
@@ -257,7 +257,7 @@ function DI.prepare_gradient_nokwarg(
257
257
strict:: Val , f, backend:: AutoFiniteDiff , x, contexts:: Vararg{DI.Context,C}
258
258
) where {C}
259
259
_sig = DI. signature (f, backend, x, contexts... ; strict)
260
- fc = DI. with_contexts (f, contexts... )
260
+ fc = DI. fix_tail (f, map (DI . unwrap, contexts) ... )
261
261
y = fc (x)
262
262
df = zero (y) .* x
263
263
cache = GradientCache (df, x, fdtype (backend))
@@ -284,7 +284,7 @@ function DI.gradient(
284
284
) where {C}
285
285
DI. check_prep (f, prep, backend, x, contexts... )
286
286
(; relstep, absstep, dir) = prep
287
- fc = DI. with_contexts (f, contexts... )
287
+ fc = DI. fix_tail (f, map (DI . unwrap, contexts) ... )
288
288
return finite_difference_gradient (fc, x, prep. cache; relstep, absstep, dir)
289
289
end
290
290
@@ -297,7 +297,7 @@ function DI.value_and_gradient(
297
297
) where {C}
298
298
DI. check_prep (f, prep, backend, x, contexts... )
299
299
(; relstep, absstep, dir) = prep
300
- fc = DI. with_contexts (f, contexts... )
300
+ fc = DI. fix_tail (f, map (DI . unwrap, contexts) ... )
301
301
return fc (x), finite_difference_gradient (fc, x, prep. cache; relstep, absstep, dir)
302
302
end
303
303
@@ -311,7 +311,7 @@ function DI.gradient!(
311
311
) where {C}
312
312
DI. check_prep (f, prep, backend, x, contexts... )
313
313
(; relstep, absstep, dir) = prep
314
- fc = DI. with_contexts (f, contexts... )
314
+ fc = DI. fix_tail (f, map (DI . unwrap, contexts) ... )
315
315
return finite_difference_gradient! (grad, fc, x, prep. cache; relstep, absstep, dir)
316
316
end
317
317
@@ -325,7 +325,7 @@ function DI.value_and_gradient!(
325
325
) where {C}
326
326
DI. check_prep (f, prep, backend, x, contexts... )
327
327
(; relstep, absstep, dir) = prep
328
- fc = DI. with_contexts (f, contexts... )
328
+ fc = DI. fix_tail (f, map (DI . unwrap, contexts) ... )
329
329
return (
330
330
fc (x), finite_difference_gradient! (grad, fc, x, prep. cache; relstep, absstep, dir)
331
331
)
@@ -345,7 +345,7 @@ function DI.prepare_jacobian_nokwarg(
345
345
strict:: Val , f, backend:: AutoFiniteDiff , x, contexts:: Vararg{DI.Context,C}
346
346
) where {C}
347
347
_sig = DI. signature (f, backend, x, contexts... ; strict)
348
- fc = DI. with_contexts (f, contexts... )
348
+ fc = DI. fix_tail (f, map (DI . unwrap, contexts) ... )
349
349
y = fc (x)
350
350
x1 = similar (x)
351
351
fx = similar (y)
@@ -374,7 +374,7 @@ function DI.jacobian(
374
374
) where {C}
375
375
DI. check_prep (f, prep, backend, x, contexts... )
376
376
(; relstep, absstep, dir) = prep
377
- fc = DI. with_contexts (f, contexts... )
377
+ fc = DI. fix_tail (f, map (DI . unwrap, contexts) ... )
378
378
return finite_difference_jacobian (fc, x, prep. cache; relstep, absstep, dir)
379
379
end
380
380
@@ -386,7 +386,7 @@ function DI.value_and_jacobian(
386
386
contexts:: Vararg{DI.Context,C} ,
387
387
) where {C}
388
388
DI. check_prep (f, prep, backend, x, contexts... )
389
- fc = DI. with_contexts (f, contexts... )
389
+ fc = DI. fix_tail (f, map (DI . unwrap, contexts) ... )
390
390
(; relstep, absstep, dir) = prep
391
391
y = fc (x)
392
392
return (y, finite_difference_jacobian (fc, x, prep. cache, y; relstep, absstep, dir))
@@ -402,7 +402,7 @@ function DI.jacobian!(
402
402
) where {C}
403
403
DI. check_prep (f, prep, backend, x, contexts... )
404
404
(; relstep, absstep, dir) = prep
405
- fc = DI. with_contexts (f, contexts... )
405
+ fc = DI. fix_tail (f, map (DI . unwrap, contexts) ... )
406
406
return copyto! (
407
407
jac,
408
408
finite_difference_jacobian (
@@ -421,7 +421,7 @@ function DI.value_and_jacobian!(
421
421
) where {C}
422
422
DI. check_prep (f, prep, backend, x, contexts... )
423
423
(; relstep, absstep, dir) = prep
424
- fc = DI. with_contexts (f, contexts... )
424
+ fc = DI. fix_tail (f, map (DI . unwrap, contexts) ... )
425
425
y = fc (x)
426
426
return (
427
427
y,
@@ -450,7 +450,7 @@ function DI.prepare_hessian_nokwarg(
450
450
strict:: Val , f, backend:: AutoFiniteDiff , x, contexts:: Vararg{DI.Context,C}
451
451
) where {C}
452
452
_sig = DI. signature (f, backend, x, contexts... ; strict)
453
- fc = DI. with_contexts (f, contexts... )
453
+ fc = DI. fix_tail (f, map (DI . unwrap, contexts) ... )
454
454
y = fc (x)
455
455
df = zero (y) .* x
456
456
gradient_cache = GradientCache (df, x, fdtype (backend))
@@ -481,7 +481,7 @@ function DI.hessian(
481
481
) where {C}
482
482
DI. check_prep (f, prep, backend, x, contexts... )
483
483
(; relstep_h, absstep_h) = prep
484
- fc = DI. with_contexts (f, contexts... )
484
+ fc = DI. fix_tail (f, map (DI . unwrap, contexts) ... )
485
485
return finite_difference_hessian (
486
486
fc, x, prep. hessian_cache; relstep= relstep_h, absstep= absstep_h
487
487
)
@@ -497,7 +497,7 @@ function DI.hessian!(
497
497
) where {C}
498
498
DI. check_prep (f, prep, backend, x, contexts... )
499
499
(; relstep_h, absstep_h) = prep
500
- fc = DI. with_contexts (f, contexts... )
500
+ fc = DI. fix_tail (f, map (DI . unwrap, contexts) ... )
501
501
return finite_difference_hessian! (
502
502
hess, fc, x, prep. hessian_cache; relstep= relstep_h, absstep= absstep_h
503
503
)
@@ -512,7 +512,7 @@ function DI.value_gradient_and_hessian(
512
512
) where {C}
513
513
DI. check_prep (f, prep, backend, x, contexts... )
514
514
(; relstep_g, absstep_g, relstep_h, absstep_h) = prep
515
- fc = DI. with_contexts (f, contexts... )
515
+ fc = DI. fix_tail (f, map (DI . unwrap, contexts) ... )
516
516
grad = finite_difference_gradient (
517
517
fc, x, prep. gradient_cache; relstep= relstep_g, absstep= absstep_g
518
518
)
@@ -533,7 +533,7 @@ function DI.value_gradient_and_hessian!(
533
533
) where {C}
534
534
DI. check_prep (f, prep, backend, x, contexts... )
535
535
(; relstep_g, absstep_g, relstep_h, absstep_h) = prep
536
- fc = DI. with_contexts (f, contexts... )
536
+ fc = DI. fix_tail (f, map (DI . unwrap, contexts) ... )
537
537
finite_difference_gradient! (
538
538
grad, fc, x, prep. gradient_cache; relstep= relstep_g, absstep= absstep_g
539
539
)
0 commit comments