Skip to content

Commit 7d252ad

Browse files
authored
feat: support nested tuples of arrays as Caches (#748)
* fix: separe `prepare` from the hidden `prepare_nokwarg` * DOcs * Typing * Fix * Toggle fail fast * feat: recursive similar for caches * Recursive caches * Enzyme * Remove new tests * SCT fix * Nesting in test scens * More sophisticated testing * Fix * Coverage
1 parent e1d171f commit 7d252ad

File tree

25 files changed

+115
-52
lines changed

25 files changed

+115
-52
lines changed

DifferentiationInterface/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DifferentiationInterface"
22
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
33
authors = ["Guillaume Dalle", "Adrian Hill"]
4-
version = "0.6.45"
4+
version = "0.6.46"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ force_annotation(f::F) where {F} = Const(f)
5454
end
5555

5656
@inline function _translate(
57-
backend::AutoEnzyme, ::Mode, ::Val{B}, c::DI.GeneralizedCache
57+
backend::AutoEnzyme, ::Mode, ::Val{B}, c::Union{DI.Cache,DI.PrepContext}
5858
) where {B}
5959
if B == 1
6060
return Duplicated(DI.unwrap(c), make_zero(DI.unwrap(c)))

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ function _translate(
8989
end
9090
function _translate(::Type{D}, c::DI.Cache) where {D<:Dual}
9191
c0 = DI.unwrap(c)
92-
return similar(c0, D)
92+
return DI.recursive_similar(c0, D)
9393
end
9494

9595
function translate(::Type{D}, contexts::NTuple{C,DI.Context}) where {D<:Dual,C}
@@ -106,7 +106,7 @@ function _translate_toprep(
106106
end
107107
function _translate_toprep(::Type{D}, c::DI.Cache) where {D<:Dual}
108108
c0 = DI.unwrap(c)
109-
return similar(c0, D)
109+
return DI.recursive_similar(c0, D)
110110
end
111111

112112
function translate_toprep(::Type{D}, contexts::NTuple{C,DI.Context}) where {D<:Dual,C}

DifferentiationInterface/ext/DifferentiationInterfaceSparseConnectivityTracerExt/DifferentiationInterfaceSparseConnectivityTracerExt.jl

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,26 +5,23 @@ import DifferentiationInterface as DI
55
using SparseConnectivityTracer:
66
TracerSparsityDetector, TracerLocalSparsityDetector, jacobian_buffer, hessian_buffer
77

8-
@inline _jacobian_translate(detector, c::DI.Constant) = DI.unwrap(c)
9-
@inline function _jacobian_translate(detector, c::DI.Cache{<:AbstractArray})
10-
return jacobian_buffer(DI.unwrap(c), detector)
8+
@inline _translate(::Type, c::DI.Constant) = DI.unwrap(c)
9+
@inline function _translate(::Type{T}, c::DI.Cache) where {T}
10+
return DI.recursive_similar(DI.unwrap(c), T)
1111
end
1212

13-
function jacobian_translate(detector, contexts::Vararg{DI.Context,C}) where {C}
13+
function jacobian_translate(detector, x, contexts::Vararg{DI.Context,C}) where {C}
14+
T = eltype(jacobian_buffer(x, detector))
1415
new_contexts = map(contexts) do c
15-
_jacobian_translate(detector, c)
16+
_translate(T, c)
1617
end
1718
return new_contexts
1819
end
1920

20-
@inline _hessian_translate(detector, c::DI.Constant) = DI.unwrap(c)
21-
@inline function _hessian_translate(detector, c::DI.Cache{<:AbstractArray})
22-
return hessian_buffer(DI.unwrap(c), detector)
23-
end
24-
25-
function hessian_translate(detector, contexts::Vararg{DI.Context,C}) where {C}
21+
function hessian_translate(detector, x, contexts::Vararg{DI.Context,C}) where {C}
22+
T = eltype(hessian_buffer(x, detector))
2623
new_contexts = map(contexts) do c
27-
_hessian_translate(detector, c)
24+
_translate(T, c)
2825
end
2926
return new_contexts
3027
end
@@ -35,7 +32,7 @@ function DI.jacobian_sparsity_with_contexts(
3532
x,
3633
contexts::Vararg{DI.Context,C},
3734
) where {F,C}
38-
contexts_tracer = jacobian_translate(detector, contexts...)
35+
contexts_tracer = jacobian_translate(detector, x, contexts...)
3936
fc = DI.FixTail(f, contexts_tracer...)
4037
return jacobian_sparsity(fc, x, detector)
4138
end
@@ -47,7 +44,7 @@ function DI.jacobian_sparsity_with_contexts(
4744
x,
4845
contexts::Vararg{DI.Context,C},
4946
) where {F,C}
50-
contexts_tracer = jacobian_translate(detector, contexts...)
47+
contexts_tracer = jacobian_translate(detector, x, contexts...)
5148
fc! = DI.FixTail(f!, contexts_tracer...)
5249
return jacobian_sparsity(fc!, y, x, detector)
5350
end
@@ -58,7 +55,7 @@ function DI.hessian_sparsity_with_contexts(
5855
x,
5956
contexts::Vararg{DI.Context,C},
6057
) where {F,C}
61-
contexts_tracer = hessian_translate(detector, contexts...)
58+
contexts_tracer = hessian_translate(detector, x, contexts...)
6259
fc = DI.FixTail(f, contexts_tracer...)
6360
return hessian_sparsity(fc, x, detector)
6461
end

DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@ DI.check_available(::AutoZygote) = true
1717
DI.inplace_support(::AutoZygote) = DI.InPlaceNotSupported()
1818

1919
translate(c::DI.Context) = DI.unwrap(c)
20-
translate(c::DI.Cache) = Buffer(DI.unwrap(c))
20+
translate(c::DI.Cache{<:AbstractArray}) = Buffer(DI.unwrap(c))
21+
function translate(c::DI.Cache{<:Union{Tuple,NamedTuple}})
22+
return map(translate, map(DI.Cache, DI.unwrap(c)))
23+
end
2124

2225
## Pullback
2326

DifferentiationInterface/src/utils/context.jl

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ Abstract supertype for additional context arguments, which can be passed to diff
2323
abstract type Context end
2424

2525
abstract type GeneralizedConstant <: Context end
26-
abstract type GeneralizedCache <: Context end
2726

2827
unwrap(c::Context) = c.data
2928
Base.:(==)(c1::Context, c2::Context) = unwrap(c1) == unwrap(c2)
@@ -78,7 +77,7 @@ The initial values present inside the cache do not matter.
7877
For some backends, preparation allocates the required memory for `Cache` contexts with the right element type, similar to [PreallocationTools.jl](https://github.com/SciML/PreallocationTools.jl).
7978
8079
!!! warning
81-
Most backends require any `Cache` context to be an `AbstractArray`.
80+
Some backends require any `Cache` context to be an `AbstractArray`, others accept nested (named) tuples of `AbstractArray`s.
8281
8382
# Example
8483
@@ -97,7 +96,7 @@ julia> gradient(f, prep, AutoForwardDiff(), [3.0, 4.0], Cache(zeros(2)))
9796
1.0
9897
````
9998
"""
100-
struct Cache{T} <: GeneralizedCache
99+
struct Cache{T} <: Context
101100
data::T
102101
end
103102

@@ -114,12 +113,10 @@ struct BackendContext{T} <: GeneralizedConstant
114113
data::T
115114
end
116115

117-
struct PrepContext{T} <: GeneralizedCache
116+
struct PrepContext{T} <: Context
118117
data::T
119118
end
120119

121-
struct UnknownContext <: Context end
122-
123120
## Context manipulation
124121

125122
struct Rewrap{C,T}
@@ -146,4 +143,4 @@ function with_contexts(f::F, contexts::Vararg{Context,N}) where {F,N}
146143
end
147144

148145
adapt_eltype(c::Constant, ::Type) = c
149-
adapt_eltype(c::Cache, ::Type{T}) where {T} = Cache(similar(unwrap(c), T))
146+
adapt_eltype(c::Cache, ::Type{T}) where {T} = Cache(recursive_similar(unwrap(c), T))

DifferentiationInterface/src/utils/linalg.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,15 @@ At the moment, this only returns `false` for `StaticArrays.SArray`.
1010
"""
1111
ismutable_array(::Type) = true
1212
ismutable_array(x) = ismutable_array(typeof(x))
13+
14+
"""
15+
recursive_similar(x, T)
16+
17+
Apply `similar(_, T)` recursively to `x` or its components.
18+
19+
Works if `x` is an `AbstractArray` or a (nested) `NTuple` / `NamedTuple` of `AbstractArray`s.
20+
"""
21+
recursive_similar(x::AbstractArray, ::Type{T}) where {T} = similar(x, T)
22+
function recursive_similar(x::Union{Tuple,NamedTuple}, ::Type{T}) where {T}
23+
return map(xi -> recursive_similar(xi, T), x)
24+
end

DifferentiationInterface/test/Back/Enzyme/test.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ end;
5555

5656
test_differentiation(
5757
backends[2],
58-
default_scenarios(; include_normal=false, include_cachified=true);
58+
default_scenarios(; include_normal=false, include_cachified=true, use_tuples=true);
5959
excluded=SECOND_ORDER,
6060
logging=LOGGING,
6161
)

DifferentiationInterface/test/Back/FiniteDiff/test.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@ end
2222
@testset "Dense" begin
2323
test_differentiation(
2424
AutoFiniteDiff(),
25-
default_scenarios(; include_constantified=true, include_cachified=true);
25+
default_scenarios(;
26+
include_constantified=true, include_cachified=true, use_tuples=true
27+
);
2628
excluded=[:second_derivative, :hvp],
2729
logging=LOGGING,
2830
)

DifferentiationInterface/test/Back/FiniteDifferences/test.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@ end
1919

2020
test_differentiation(
2121
AutoFiniteDifferences(; fdm=FiniteDifferences.central_fdm(3, 1)),
22-
default_scenarios(; include_constantified=true, include_cachified=true);
22+
default_scenarios(;
23+
include_constantified=true, include_cachified=true, use_tuples=true
24+
);
2325
excluded=SECOND_ORDER,
2426
logging=LOGGING,
2527
);

DifferentiationInterface/test/Back/ForwardDiff/test.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,10 @@ end
3636
test_differentiation(
3737
AutoForwardDiff(),
3838
default_scenarios(;
39-
include_normal=false, include_batchified=false, include_cachified=true
39+
include_normal=false,
40+
include_batchified=false,
41+
include_cachified=true,
42+
use_tuples=true,
4043
);
4144
logging=LOGGING,
4245
)

DifferentiationInterface/test/Back/Mooncake/test.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@ end
1919

2020
test_differentiation(
2121
backends,
22-
default_scenarios(; include_constantified=true, include_cachified=true);
22+
default_scenarios(;
23+
include_constantified=true, include_cachified=true, use_tuples=true
24+
);
2325
excluded=SECOND_ORDER,
2426
logging=LOGGING,
2527
);

DifferentiationInterface/test/Back/PolyesterForwardDiff/test.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@ end
2828

2929
test_differentiation(
3030
backends,
31-
default_scenarios(; include_constantified=true, include_cachified=true);
31+
default_scenarios(;
32+
include_constantified=true, include_cachified=true, use_tuples=true
33+
);
3234
logging=LOGGING,
3335
);
3436

DifferentiationInterface/test/Back/SymbolicBackends/fastdifferentiation.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@ end
1717

1818
test_differentiation(
1919
AutoFastDifferentiation(),
20-
default_scenarios(; include_constantified=true, include_cachified=true);
20+
default_scenarios(;
21+
include_constantified=true, include_cachified=true, use_tuples=false
22+
);
2123
logging=LOGGING,
2224
);
2325

DifferentiationInterface/test/Back/SymbolicBackends/symbolics.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ test_differentiation(
2121

2222
test_differentiation(
2323
AutoSymbolics(),
24-
default_scenarios(; include_normal=false, include_cachified=true);
24+
default_scenarios(; include_normal=false, include_cachified=true, use_tuples=false);
2525
excluded=[:jacobian], # TODO: figure out why this fails
2626
logging=LOGGING,
2727
);

DifferentiationInterface/test/Back/Zygote/test.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@ end
2727
@testset "Dense" begin
2828
test_differentiation(
2929
backends,
30-
default_scenarios(; include_constantified=true, include_cachified=true);
30+
default_scenarios(;
31+
include_constantified=true, include_cachified=true, use_tuples=true
32+
);
3133
excluded=[:second_derivative],
3234
logging=LOGGING,
3335
)
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
using DifferentiationInterface: recursive_similar
2+
using Test
3+
4+
@test recursive_similar(ones(Int, 2), Float32) isa Vector{Float32}
5+
@test recursive_similar((ones(Int, 2), ones(Bool, 3, 4)), Float32) isa
6+
Tuple{Vector{Float32},Matrix{Float32}}
7+
@test recursive_similar((a=ones(Int, 2), b=(ones(Bool, 3, 4),)), Float32) isa
8+
@NamedTuple{a::Vector{Float32}, b::Tuple{Matrix{Float32}}}
9+
@test_throws MethodError recursive_similar(1, Float32)

DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,9 @@ end
8686
MyAutoSparse.(
8787
vcat(adaptive_backends, MixedMode(adaptive_backends[1], adaptive_backends[2]))
8888
),
89-
sparse_scenarios(; include_constantified=true, include_cachified=true);
89+
sparse_scenarios(;
90+
include_constantified=true, include_cachified=true, use_tuples=true
91+
);
9092
sparsity=true,
9193
logging=LOGGING,
9294
)

DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestJLArraysExt/DifferentiationInterfaceTestJLArraysExt.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ myjl(x::Number) = x
1818
myjl(x::AbstractArray) = jl(x)
1919
myjl(x::Tuple) = map(myjl, x)
2020
myjl(x::DI.Constant) = DI.Constant(myjl(DI.unwrap(x)))
21-
myjl(x::DI.Cache) = DI.Cache(myjl(DI.unwrap(x)))
21+
myjl(x::DI.Cache{<:AbstractArray}) = DI.Cache(myjl(DI.unwrap(x)))
22+
myjl(x::DI.Cache{<:Union{Tuple,NamedTuple}}) = map(myjl, map(DI.Cache, DI.unwrap(x)))
2223
myjl(::Nothing) = nothing
2324

2425
function myjl(scen::DIT.Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun}

DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestStaticArraysExt/DifferentiationInterfaceTestStaticArraysExt.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,10 @@ end
2929

3030
mystatic(x::Tuple) = map(mystatic, x)
3131
mystatic(x::DI.Constant) = DI.Constant(mystatic(DI.unwrap(x)))
32-
mystatic(x::DI.Cache) = DI.Cache(mymutablestatic(DI.unwrap(x)))
32+
mystatic(x::DI.Cache{<:AbstractArray}) = DI.Cache(mymutablestatic(DI.unwrap(x)))
33+
function mystatic(x::DI.Cache{<:Union{Tuple,NamedTuple}})
34+
return map(mystatic, map(DI.Cache, DI.unwrap(x)))
35+
end
3336
mystatic(::Nothing) = nothing
3437

3538
function mystatic(scen::DIT.Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun}

DifferentiationInterfaceTest/src/scenarios/default.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -559,6 +559,7 @@ function default_scenarios(;
559559
include_closurified=false,
560560
include_constantified=false,
561561
include_cachified=false,
562+
use_tuples=false,
562563
)
563564
x_ = 0.42
564565
dx_ = 3.14
@@ -635,7 +636,7 @@ function default_scenarios(;
635636
include_normal && append!(final_scens, scens)
636637
include_closurified && append!(final_scens, closurify(scens))
637638
include_constantified && append!(final_scens, constantify(scens))
638-
include_cachified && append!(final_scens, cachify(scens))
639+
include_cachified && append!(final_scens, cachify(scens; use_tuples=use_tuples))
639640

640641
return final_scens
641642
end

DifferentiationInterfaceTest/src/scenarios/modify.jl

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ end
147147
"""
148148
constantify(scen::Scenario)
149149
150-
Return a new `Scenario` identical to `scen` except for the function `f`, which is made to accept an additional constant argument `a` by which the output is multiplied.
150+
Return a new `Scenario` identical to `scen` except for the function `f`, which is made to accept an additional constant argument by which the output is multiplied.
151151
The output and result fields are updated accordingly.
152152
"""
153153
function constantify(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun}
@@ -178,6 +178,11 @@ end
178178

179179
Base.show(io::IO, f::StoreInCache) = print(io, "StoreInCache($(f.f))")
180180

181+
(sc::StoreInCache{:out})(x, y_cache::NamedTuple) = sc(x, y_cache.useful_cache)
182+
(sc::StoreInCache{:in})(y, x, y_cache::NamedTuple) = sc(y, x, y_cache.useful_cache)
183+
(sc::StoreInCache{:out})(x, y_cache::Tuple) = sc(x, first(y_cache))
184+
(sc::StoreInCache{:in})(y, x, y_cache::Tuple) = sc(y, x, first(y_cache))
185+
181186
function (sc::StoreInCache{:out})(x, y_cache)
182187
y = sc.f(x)
183188
if y isa Number
@@ -198,16 +203,26 @@ end
198203
"""
199204
cachify(scen::Scenario)
200205
201-
Return a new `Scenario` identical to `scen` except for the function `f`, which is made to accept an additional cache argument `a` to store the result before it is returned.
206+
Return a new `Scenario` identical to `scen` except for the function `f`, which is made to accept an additional cache argument to store the result before it is returned.
207+
208+
If `tup=true` the cache is a tuple of arrays, otherwise just an array.
202209
"""
203-
function cachify(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun}
210+
function cachify(scen::Scenario{op,pl_op,pl_fun}; use_tuples) where {op,pl_op,pl_fun}
204211
(; f,) = scen
205212
@assert isempty(scen.contexts)
206213
cache_f = StoreInCache{pl_fun}(f)
207-
y_cache = if scen.y isa Number
208-
[myzero(scen.y)]
214+
if use_tuples
215+
y_cache = if scen.y isa Number
216+
(; useful_cache=([myzero(scen.y)],), useless_cache=[myzero(scen.y)])
217+
else
218+
(; useful_cache=(mysimilar(scen.y),), useless_cache=mysimilar(scen.y))
219+
end
209220
else
210-
mysimilar(scen.y)
221+
y_cache = if scen.y isa Number
222+
[myzero(scen.y)]
223+
else
224+
mysimilar(scen.y)
225+
end
211226
end
212227
return Scenario{op,pl_op,pl_fun}(
213228
cache_f;
@@ -217,7 +232,7 @@ function cachify(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun}
217232
contexts=(Cache(y_cache),),
218233
res1=scen.res1,
219234
res2=scen.res2,
220-
smaller=isnothing(scen.smaller) ? nothing : cachify(scen.smaller),
235+
smaller=isnothing(scen.smaller) ? nothing : cachify(scen.smaller; use_tuples),
221236
name=isnothing(scen.name) ? nothing : scen.name * " [cachified]",
222237
)
223238
end
@@ -229,7 +244,7 @@ end
229244

230245
closurify(scens::AbstractVector{<:Scenario}) = closurify.(scens)
231246
constantify(scens::AbstractVector{<:Scenario}) = constantify.(scens)
232-
cachify(scens::AbstractVector{<:Scenario}) = cachify.(scens)
247+
cachify(scens::AbstractVector{<:Scenario}; use_tuples) = cachify.(scens; use_tuples)
233248

234249
function set_smaller(
235250
scen::Scenario{op,pl_op,pl_fun}, smaller::Scenario

0 commit comments

Comments
 (0)