Skip to content

Commit e57dd0c

Browse files
authored
feat: add new ConstantOrCache context (#749)
* feat: add new `ConstantOrCache` context * Coverage and docs * Fix Enzyme
1 parent 7d252ad commit e57dd0c

File tree

32 files changed

+351
-196
lines changed

32 files changed

+351
-196
lines changed

DifferentiationInterface/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DifferentiationInterface"
22
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
33
authors = ["Guillaume Dalle", "Adrian Hill"]
4-
version = "0.6.46"
4+
version = "0.6.47"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

DifferentiationInterface/docs/src/api.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ DifferentiationInterface
1414
Context
1515
Constant
1616
Cache
17+
ConstantOrCache
1718
```
1819

1920
## First order

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ force_annotation(f::F) where {F} = Const(f)
5454
end
5555

5656
@inline function _translate(
57-
backend::AutoEnzyme, ::Mode, ::Val{B}, c::Union{DI.Cache,DI.PrepContext}
57+
backend::AutoEnzyme, ::Mode, ::Val{B}, c::Union{DI.Cache,DI.GeneralizedConstantOrCache}
5858
) where {B}
5959
if B == 1
6060
return Duplicated(DI.unwrap(c), make_zero(DI.unwrap(c)))

DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/onearg.jl

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ function DI.prepare_pushforward_nokwarg(
1212
strict::Val, f, backend::AutoFiniteDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C};
1313
) where {C}
1414
_sig = DI.signature(f, backend, x, tx, contexts...; strict)
15-
fc = DI.with_contexts(f, contexts...)
15+
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
1616
y = fc(x)
1717
cache = if x isa Number || y isa Number
1818
nothing
@@ -89,7 +89,7 @@ function DI.pushforward(
8989
) where {SIG,C}
9090
DI.check_prep(f, prep, backend, x, tx, contexts...)
9191
(; relstep, absstep, dir) = prep
92-
fc = DI.with_contexts(f, contexts...)
92+
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
9393
ty = map(tx) do dx
9494
finite_difference_jvp(fc, x, dx, prep.cache; relstep, absstep, dir)
9595
end
@@ -106,7 +106,7 @@ function DI.value_and_pushforward(
106106
) where {SIG,C}
107107
DI.check_prep(f, prep, backend, x, tx, contexts...)
108108
(; relstep, absstep, dir) = prep
109-
fc = DI.with_contexts(f, contexts...)
109+
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
110110
y = fc(x)
111111
ty = map(tx) do dx
112112
finite_difference_jvp(fc, x, dx, prep.cache, y; relstep, absstep, dir)
@@ -128,7 +128,7 @@ function DI.prepare_derivative_nokwarg(
128128
strict::Val, f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}
129129
) where {C}
130130
_sig = DI.signature(f, backend, x, contexts...; strict)
131-
fc = DI.with_contexts(f, contexts...)
131+
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
132132
y = fc(x)
133133
cache = if y isa Number
134134
nothing
@@ -161,7 +161,7 @@ function DI.derivative(
161161
) where {SIG,C}
162162
DI.check_prep(f, prep, backend, x, contexts...)
163163
(; relstep, absstep, dir) = prep
164-
fc = DI.with_contexts(f, contexts...)
164+
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
165165
return finite_difference_derivative(fc, x, fdtype(backend); relstep, absstep, dir)
166166
end
167167

@@ -174,7 +174,7 @@ function DI.value_and_derivative(
174174
) where {SIG,C}
175175
DI.check_prep(f, prep, backend, x, contexts...)
176176
(; relstep, absstep, dir) = prep
177-
fc = DI.with_contexts(f, contexts...)
177+
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
178178
y = fc(x)
179179
return (
180180
y,
@@ -195,7 +195,7 @@ function DI.derivative(
195195
) where {SIG,C}
196196
DI.check_prep(f, prep, backend, x, contexts...)
197197
(; relstep, absstep, dir) = prep
198-
fc = DI.with_contexts(f, contexts...)
198+
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
199199
return finite_difference_gradient(fc, x, prep.cache; relstep, absstep, dir)
200200
end
201201

@@ -209,7 +209,7 @@ function DI.derivative!(
209209
) where {SIG,C}
210210
DI.check_prep(f, prep, backend, x, contexts...)
211211
(; relstep, absstep, dir) = prep
212-
fc = DI.with_contexts(f, contexts...)
212+
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
213213
return finite_difference_gradient!(der, fc, x, prep.cache; relstep, absstep, dir)
214214
end
215215

@@ -221,7 +221,7 @@ function DI.value_and_derivative(
221221
contexts::Vararg{DI.Context,C},
222222
) where {SIG,C}
223223
DI.check_prep(f, prep, backend, x, contexts...)
224-
fc = DI.with_contexts(f, contexts...)
224+
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
225225
(; relstep, absstep, dir) = prep
226226
y = fc(x)
227227
return (y, finite_difference_gradient(fc, x, prep.cache; relstep, absstep, dir))
@@ -237,7 +237,7 @@ function DI.value_and_derivative!(
237237
) where {SIG,C}
238238
DI.check_prep(f, prep, backend, x, contexts...)
239239
(; relstep, absstep, dir) = prep
240-
fc = DI.with_contexts(f, contexts...)
240+
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
241241
return (
242242
fc(x), finite_difference_gradient!(der, fc, x, prep.cache; relstep, absstep, dir)
243243
)
@@ -257,7 +257,7 @@ function DI.prepare_gradient_nokwarg(
257257
strict::Val, f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}
258258
) where {C}
259259
_sig = DI.signature(f, backend, x, contexts...; strict)
260-
fc = DI.with_contexts(f, contexts...)
260+
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
261261
y = fc(x)
262262
df = zero(y) .* x
263263
cache = GradientCache(df, x, fdtype(backend))
@@ -284,7 +284,7 @@ function DI.gradient(
284284
) where {C}
285285
DI.check_prep(f, prep, backend, x, contexts...)
286286
(; relstep, absstep, dir) = prep
287-
fc = DI.with_contexts(f, contexts...)
287+
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
288288
return finite_difference_gradient(fc, x, prep.cache; relstep, absstep, dir)
289289
end
290290

@@ -297,7 +297,7 @@ function DI.value_and_gradient(
297297
) where {C}
298298
DI.check_prep(f, prep, backend, x, contexts...)
299299
(; relstep, absstep, dir) = prep
300-
fc = DI.with_contexts(f, contexts...)
300+
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
301301
return fc(x), finite_difference_gradient(fc, x, prep.cache; relstep, absstep, dir)
302302
end
303303

@@ -311,7 +311,7 @@ function DI.gradient!(
311311
) where {C}
312312
DI.check_prep(f, prep, backend, x, contexts...)
313313
(; relstep, absstep, dir) = prep
314-
fc = DI.with_contexts(f, contexts...)
314+
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
315315
return finite_difference_gradient!(grad, fc, x, prep.cache; relstep, absstep, dir)
316316
end
317317

@@ -325,7 +325,7 @@ function DI.value_and_gradient!(
325325
) where {C}
326326
DI.check_prep(f, prep, backend, x, contexts...)
327327
(; relstep, absstep, dir) = prep
328-
fc = DI.with_contexts(f, contexts...)
328+
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
329329
return (
330330
fc(x), finite_difference_gradient!(grad, fc, x, prep.cache; relstep, absstep, dir)
331331
)
@@ -345,7 +345,7 @@ function DI.prepare_jacobian_nokwarg(
345345
strict::Val, f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}
346346
) where {C}
347347
_sig = DI.signature(f, backend, x, contexts...; strict)
348-
fc = DI.with_contexts(f, contexts...)
348+
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
349349
y = fc(x)
350350
x1 = similar(x)
351351
fx = similar(y)
@@ -374,7 +374,7 @@ function DI.jacobian(
374374
) where {C}
375375
DI.check_prep(f, prep, backend, x, contexts...)
376376
(; relstep, absstep, dir) = prep
377-
fc = DI.with_contexts(f, contexts...)
377+
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
378378
return finite_difference_jacobian(fc, x, prep.cache; relstep, absstep, dir)
379379
end
380380

@@ -386,7 +386,7 @@ function DI.value_and_jacobian(
386386
contexts::Vararg{DI.Context,C},
387387
) where {C}
388388
DI.check_prep(f, prep, backend, x, contexts...)
389-
fc = DI.with_contexts(f, contexts...)
389+
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
390390
(; relstep, absstep, dir) = prep
391391
y = fc(x)
392392
return (y, finite_difference_jacobian(fc, x, prep.cache, y; relstep, absstep, dir))
@@ -402,7 +402,7 @@ function DI.jacobian!(
402402
) where {C}
403403
DI.check_prep(f, prep, backend, x, contexts...)
404404
(; relstep, absstep, dir) = prep
405-
fc = DI.with_contexts(f, contexts...)
405+
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
406406
return copyto!(
407407
jac,
408408
finite_difference_jacobian(
@@ -421,7 +421,7 @@ function DI.value_and_jacobian!(
421421
) where {C}
422422
DI.check_prep(f, prep, backend, x, contexts...)
423423
(; relstep, absstep, dir) = prep
424-
fc = DI.with_contexts(f, contexts...)
424+
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
425425
y = fc(x)
426426
return (
427427
y,
@@ -450,7 +450,7 @@ function DI.prepare_hessian_nokwarg(
450450
strict::Val, f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}
451451
) where {C}
452452
_sig = DI.signature(f, backend, x, contexts...; strict)
453-
fc = DI.with_contexts(f, contexts...)
453+
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
454454
y = fc(x)
455455
df = zero(y) .* x
456456
gradient_cache = GradientCache(df, x, fdtype(backend))
@@ -481,7 +481,7 @@ function DI.hessian(
481481
) where {C}
482482
DI.check_prep(f, prep, backend, x, contexts...)
483483
(; relstep_h, absstep_h) = prep
484-
fc = DI.with_contexts(f, contexts...)
484+
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
485485
return finite_difference_hessian(
486486
fc, x, prep.hessian_cache; relstep=relstep_h, absstep=absstep_h
487487
)
@@ -497,7 +497,7 @@ function DI.hessian!(
497497
) where {C}
498498
DI.check_prep(f, prep, backend, x, contexts...)
499499
(; relstep_h, absstep_h) = prep
500-
fc = DI.with_contexts(f, contexts...)
500+
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
501501
return finite_difference_hessian!(
502502
hess, fc, x, prep.hessian_cache; relstep=relstep_h, absstep=absstep_h
503503
)
@@ -512,7 +512,7 @@ function DI.value_gradient_and_hessian(
512512
) where {C}
513513
DI.check_prep(f, prep, backend, x, contexts...)
514514
(; relstep_g, absstep_g, relstep_h, absstep_h) = prep
515-
fc = DI.with_contexts(f, contexts...)
515+
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
516516
grad = finite_difference_gradient(
517517
fc, x, prep.gradient_cache; relstep=relstep_g, absstep=absstep_g
518518
)
@@ -533,7 +533,7 @@ function DI.value_gradient_and_hessian!(
533533
) where {C}
534534
DI.check_prep(f, prep, backend, x, contexts...)
535535
(; relstep_g, absstep_g, relstep_h, absstep_h) = prep
536-
fc = DI.with_contexts(f, contexts...)
536+
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
537537
finite_difference_gradient!(
538538
grad, fc, x, prep.gradient_cache; relstep=relstep_g, absstep=absstep_g
539539
)

DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/twoarg.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ function DI.pushforward(
8080
) where {SIG,C}
8181
DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
8282
(; relstep, absstep, dir) = prep
83-
fc! = DI.with_contexts(f!, contexts...)
83+
fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...)
8484
ty = map(tx) do dx
8585
dy = similar(y)
8686
finite_difference_jvp!(dy, fc!, x, dx, prep.cache; relstep, absstep, dir)
@@ -100,7 +100,7 @@ function DI.value_and_pushforward(
100100
) where {SIG,C}
101101
DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
102102
(; relstep, absstep, dir) = prep
103-
fc! = DI.with_contexts(f!, contexts...)
103+
fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...)
104104
ty = map(tx) do dx
105105
dy = similar(y)
106106
finite_difference_jvp!(dy, fc!, x, dx, prep.cache; relstep, absstep, dir)
@@ -122,7 +122,7 @@ function DI.pushforward!(
122122
) where {SIG,C}
123123
DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
124124
(; relstep, absstep, dir) = prep
125-
fc! = DI.with_contexts(f!, contexts...)
125+
fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...)
126126
for b in eachindex(tx, ty)
127127
dx, dy = tx[b], ty[b]
128128
finite_difference_jvp!(dy, fc!, x, dx, prep.cache; relstep, absstep, dir)
@@ -142,7 +142,7 @@ function DI.value_and_pushforward!(
142142
) where {SIG,C}
143143
DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
144144
(; relstep, absstep, dir) = prep
145-
fc! = DI.with_contexts(f!, contexts...)
145+
fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...)
146146
for b in eachindex(tx, ty)
147147
dx, dy = tx[b], ty[b]
148148
finite_difference_jvp!(dy, fc!, x, dx, prep.cache; relstep, absstep, dir)
@@ -214,7 +214,7 @@ function DI.value_and_derivative(
214214
) where {C}
215215
DI.check_prep(f!, y, prep, backend, x, contexts...)
216216
(; relstep, absstep, dir) = prep
217-
fc! = DI.with_contexts(f!, contexts...)
217+
fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...)
218218
fc!(y, x)
219219
der = finite_difference_gradient(fc!, x, prep.cache; relstep, absstep, dir)
220220
return y, der
@@ -231,7 +231,7 @@ function DI.value_and_derivative!(
231231
) where {C}
232232
DI.check_prep(f!, y, prep, backend, x, contexts...)
233233
(; relstep, absstep, dir) = prep
234-
fc! = DI.with_contexts(f!, contexts...)
234+
fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...)
235235
fc!(y, x)
236236
finite_difference_gradient!(der, fc!, x, prep.cache; relstep, absstep, dir)
237237
return y, der
@@ -247,7 +247,7 @@ function DI.derivative(
247247
) where {C}
248248
DI.check_prep(f!, y, prep, backend, x, contexts...)
249249
(; relstep, absstep, dir) = prep
250-
fc! = DI.with_contexts(f!, contexts...)
250+
fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...)
251251
fc!(y, x)
252252
der = finite_difference_gradient(fc!, x, prep.cache; relstep, absstep, dir)
253253
return der
@@ -264,7 +264,7 @@ function DI.derivative!(
264264
) where {C}
265265
DI.check_prep(f!, y, prep, backend, x, contexts...)
266266
(; relstep, absstep, dir) = prep
267-
fc! = DI.with_contexts(f!, contexts...)
267+
fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...)
268268
finite_difference_gradient!(der, fc!, x, prep.cache; relstep, absstep, dir)
269269
return der
270270
end
@@ -336,7 +336,7 @@ function DI.value_and_jacobian(
336336
) where {C}
337337
DI.check_prep(f!, y, prep, backend, x, contexts...)
338338
(; relstep, absstep, dir) = prep
339-
fc! = DI.with_contexts(f!, contexts...)
339+
fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...)
340340
jac = similar(y, length(y), length(x))
341341
finite_difference_jacobian!(jac, fc!, x, prep.cache; relstep, absstep, dir)
342342
fc!(y, x)
@@ -354,7 +354,7 @@ function DI.value_and_jacobian!(
354354
) where {C}
355355
DI.check_prep(f!, y, prep, backend, x, contexts...)
356356
(; relstep, absstep, dir) = prep
357-
fc! = DI.with_contexts(f!, contexts...)
357+
fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...)
358358
finite_difference_jacobian!(jac, fc!, x, prep.cache; relstep, absstep, dir)
359359
fc!(y, x)
360360
return y, jac
@@ -370,7 +370,7 @@ function DI.jacobian(
370370
) where {C}
371371
DI.check_prep(f!, y, prep, backend, x, contexts...)
372372
(; relstep, absstep, dir) = prep
373-
fc! = DI.with_contexts(f!, contexts...)
373+
fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...)
374374
jac = similar(y, length(y), length(x))
375375
finite_difference_jacobian!(jac, fc!, x, prep.cache; relstep, absstep, dir)
376376
return jac
@@ -387,7 +387,7 @@ function DI.jacobian!(
387387
) where {C}
388388
DI.check_prep(f!, y, prep, backend, x, contexts...)
389389
(; relstep, absstep, dir) = prep
390-
fc! = DI.with_contexts(f!, contexts...)
390+
fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...)
391391
finite_difference_jacobian!(jac, fc!, x, prep.cache; relstep, absstep, dir)
392392
return jac
393393
end

DifferentiationInterface/ext/DifferentiationInterfaceFiniteDifferencesExt/DifferentiationInterfaceFiniteDifferencesExt.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ function DI.pushforward(
3232
contexts::Vararg{DI.Context,C},
3333
) where {C}
3434
DI.check_prep(f, prep, backend, x, tx, contexts...)
35-
fc = DI.with_contexts(f, contexts...)
35+
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
3636
ty = map(tx) do dx
3737
jvp(backend.fdm, fc, (x, dx))
3838
end
@@ -75,7 +75,7 @@ function DI.pullback(
7575
contexts::Vararg{DI.Context,C},
7676
) where {C}
7777
DI.check_prep(f, prep, backend, x, ty, contexts...)
78-
fc = DI.with_contexts(f, contexts...)
78+
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
7979
tx = map(ty) do dy
8080
only(j′vp(backend.fdm, fc, dy, x))
8181
end
@@ -112,7 +112,7 @@ function DI.gradient(
112112
contexts::Vararg{DI.Context,C},
113113
) where {C}
114114
DI.check_prep(f, prep, backend, x, contexts...)
115-
fc = DI.with_contexts(f, contexts...)
115+
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
116116
return only(grad(backend.fdm, fc, x))
117117
end
118118

@@ -169,7 +169,7 @@ function DI.jacobian(
169169
contexts::Vararg{DI.Context,C},
170170
) where {C}
171171
DI.check_prep(f, prep, backend, x, contexts...)
172-
fc = DI.with_contexts(f, contexts...)
172+
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
173173
return only(jacobian(backend.fdm, fc, x))
174174
end
175175

0 commit comments

Comments
 (0)