Skip to content

Commit 4e694bb

Browse files
authored
fix: replace copy with deepcopy in Mooncake pullbacks (#723)
1 parent e515499 commit 4e694bb

File tree

4 files changed

+97
-29
lines changed

4 files changed

+97
-29
lines changed

DifferentiationInterface/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ ForwardDiff = "0.10.36"
6262
GTPSA = "1.4.0"
6363
JuliaFormatter = "1"
6464
LinearAlgebra = "<0.0.1,1"
65-
Mooncake = "0.4.52"
65+
Mooncake = "0.4.83"
6666
PolyesterForwardDiff = "0.1.2"
6767
ReverseDiff = "1.15.1"
6868
SparseArrays = "<0.0.1,1"

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ using Mooncake:
99
tangent,
1010
tangent_type,
1111
value_and_pullback!!,
12+
value_and_gradient!!,
1213
zero_tangent,
1314
prepare_pullback_cache,
1415
Mooncake
@@ -21,6 +22,10 @@ copyto!!(dst, src) = DI.ismutable_array(dst) ? copyto!(dst, src) : convert(typeo
2122
get_config(::AutoMooncake{Nothing}) = Config()
2223
get_config(backend::AutoMooncake{<:Config}) = backend.config
2324

25+
# tangents need to be copied before returning, otherwise they are still aliased in the cache
26+
mycopy(x::Union{Number,AbstractArray{<:Number}}) = copy(x)
27+
mycopy(x) = deepcopy(x)
28+
2429
include("onearg.jl")
2530
include("twoarg.jl")
2631

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl

Lines changed: 77 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
1+
## Pullback
2+
13
struct MooncakeOneArgPullbackPrep{Tcache,DY} <: DI.PullbackPrep
24
cache::Tcache
35
dy_righttype::DY
46
end
57

68
function DI.prepare_pullback(
7-
f, backend::AutoMooncake, x, ty::NTuple, contexts::Vararg{DI.Context,C}
8-
) where {C}
9+
f::F, backend::AutoMooncake, x, ty::NTuple, contexts::Vararg{DI.Context,C}
10+
) where {F,C}
911
config = get_config(backend)
1012
cache = prepare_pullback_cache(
1113
f, x, map(DI.unwrap, contexts)...; config.debug_mode, config.silence_debug_messages
@@ -27,34 +29,34 @@ function DI.value_and_pullback(
2729
) where {F,Y,C}
2830
dy = only(ty)
2931
dy_righttype = dy isa tangent_type(Y) ? dy : copyto!!(prep.dy_righttype, dy)
30-
new_y, (_, new_dx) = Mooncake.value_and_pullback!!(
32+
new_y, (_, new_dx) = value_and_pullback!!(
3133
prep.cache, dy_righttype, f, x, map(DI.unwrap, contexts)...
3234
)
33-
return new_y, (copy(new_dx),)
35+
return new_y, (mycopy(new_dx),)
3436
end
3537

3638
function DI.value_and_pullback!(
37-
f,
39+
f::F,
3840
tx::NTuple{1},
3941
prep::MooncakeOneArgPullbackPrep{Y},
4042
backend::AutoMooncake,
4143
x,
4244
ty::NTuple{1},
4345
contexts::Vararg{DI.Context,C},
44-
) where {Y,C}
46+
) where {F,Y,C}
4547
y, (new_dx,) = DI.value_and_pullback(f, prep, backend, x, ty, contexts...)
4648
copyto!(only(tx), new_dx)
4749
return y, tx
4850
end
4951

5052
function DI.value_and_pullback(
51-
f,
53+
f::F,
5254
prep::MooncakeOneArgPullbackPrep,
5355
backend::AutoMooncake,
5456
x,
5557
ty::NTuple,
5658
contexts::Vararg{DI.Context,C},
57-
) where {C}
59+
) where {F,C}
5860
ys_and_tx = map(ty) do dy
5961
y, tx = DI.value_and_pullback(f, prep, backend, x, (dy,), contexts...)
6062
y, only(tx)
@@ -65,14 +67,14 @@ function DI.value_and_pullback(
6567
end
6668

6769
function DI.value_and_pullback!(
68-
f,
70+
f::F,
6971
tx::NTuple,
7072
prep::MooncakeOneArgPullbackPrep,
7173
backend::AutoMooncake,
7274
x,
7375
ty::NTuple,
7476
contexts::Vararg{DI.Context,C},
75-
) where {C}
77+
) where {F,C}
7678
ys = map(tx, ty) do dx, dy
7779
y, _ = DI.value_and_pullback!(f, (dx,), prep, backend, x, (dy,), contexts...)
7880
y
@@ -82,24 +84,85 @@ function DI.value_and_pullback!(
8284
end
8385

8486
function DI.pullback(
85-
f,
87+
f::F,
8688
prep::MooncakeOneArgPullbackPrep,
8789
backend::AutoMooncake,
8890
x,
8991
ty::NTuple,
9092
contexts::Vararg{DI.Context,C},
91-
) where {C}
93+
) where {F,C}
9294
return DI.value_and_pullback(f, prep, backend, x, ty, contexts...)[2]
9395
end
9496

9597
function DI.pullback!(
96-
f,
98+
f::F,
9799
tx::NTuple,
98100
prep::MooncakeOneArgPullbackPrep,
99101
backend::AutoMooncake,
100102
x,
101103
ty::NTuple,
102104
contexts::Vararg{DI.Context,C},
103-
) where {C}
105+
) where {F,C}
104106
return DI.value_and_pullback!(f, tx, prep, backend, x, ty, contexts...)[2]
105107
end
108+
109+
## Gradient
110+
111+
struct MooncakeGradientPrep{Tcache} <: DI.GradientPrep
112+
cache::Tcache
113+
end
114+
115+
function DI.prepare_gradient(
116+
f::F, backend::AutoMooncake, x, contexts::Vararg{DI.Context,C}
117+
) where {F,C}
118+
config = get_config(backend)
119+
cache = prepare_pullback_cache(
120+
f, x, map(DI.unwrap, contexts)...; config.debug_mode, config.silence_debug_messages
121+
)
122+
prep = MooncakeGradientPrep(cache)
123+
DI.value_and_gradient(f, prep, backend, x, contexts...)
124+
return prep
125+
end
126+
127+
function DI.value_and_gradient(
128+
f::F, prep::MooncakeGradientPrep, ::AutoMooncake, x, contexts::Vararg{DI.Context,C}
129+
) where {F,C}
130+
y, (_, new_grad) = value_and_gradient!!(prep.cache, f, x, map(DI.unwrap, contexts)...)
131+
return y, mycopy(new_grad)
132+
end
133+
134+
function DI.value_and_gradient!(
135+
f::F,
136+
grad,
137+
prep::MooncakeGradientPrep,
138+
::AutoMooncake,
139+
x,
140+
contexts::Vararg{DI.Context,C},
141+
) where {F,C}
142+
y, (_, new_grad) = value_and_gradient!!(prep.cache, f, x, map(DI.unwrap, contexts)...)
143+
copyto!(grad, new_grad)
144+
return y, grad
145+
end
146+
147+
function DI.gradient(
148+
f::F,
149+
prep::MooncakeGradientPrep,
150+
backend::AutoMooncake,
151+
x,
152+
contexts::Vararg{DI.Context,C},
153+
) where {F,C}
154+
_, grad = DI.value_and_gradient(f, prep, backend, x, contexts...)
155+
return grad
156+
end
157+
158+
function DI.gradient!(
159+
f::F,
160+
grad,
161+
prep::MooncakeGradientPrep,
162+
backend::AutoMooncake,
163+
x,
164+
contexts::Vararg{DI.Context,C},
165+
) where {F,C}
166+
DI.value_and_gradient!(f, grad, prep, backend, x, contexts...)
167+
return grad
168+
end

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ struct MooncakeTwoArgPullbackPrep{Tcache,DY,F} <: DI.PullbackPrep
55
end
66

77
function DI.prepare_pullback(
8-
f!, y, backend::AutoMooncake, x, ty::NTuple, contexts::Vararg{DI.Context,C}
9-
) where {C}
8+
f!::F, y, backend::AutoMooncake, x, ty::NTuple, contexts::Vararg{DI.Context,C}
9+
) where {F,C}
1010
target_function = function (f!, y, x, contexts...)
1111
f!(y, x, contexts...)
1212
return y
@@ -26,36 +26,36 @@ function DI.prepare_pullback(
2626
end
2727

2828
function DI.value_and_pullback(
29-
f!,
29+
f!::F,
3030
y,
3131
prep::MooncakeTwoArgPullbackPrep,
3232
::AutoMooncake,
3333
x,
3434
ty::NTuple{1},
3535
contexts::Vararg{DI.Context,C},
36-
) where {C}
36+
) where {F,C}
3737
# Prepare cotangent to add after the forward pass.
3838
dy = only(ty)
3939
dy_righttype_after = copyto!(prep.dy_righttype, dy)
4040

4141
# Run the reverse-pass and return the results.
4242
contexts = map(DI.unwrap, contexts)
43-
y_after, (_, _, _, dx) = Mooncake.value_and_pullback!!(
43+
y_after, (_, _, _, dx) = value_and_pullback!!(
4444
prep.cache, dy_righttype_after, prep.target_function, f!, y, x, contexts...
4545
)
4646
copyto!(y, y_after)
47-
return y, (copy(dx),) # TODO: remove this allocation in `value_and_pullback!`
47+
return y, (mycopy(dx),)
4848
end
4949

5050
function DI.value_and_pullback(
51-
f!,
51+
f!::F,
5252
y,
5353
prep::MooncakeTwoArgPullbackPrep,
5454
backend::AutoMooncake,
5555
x,
5656
ty::NTuple,
5757
contexts::Vararg{DI.Context,C},
58-
) where {C}
58+
) where {F,C}
5959
tx = map(ty) do dy
6060
_, tx = DI.value_and_pullback(f!, y, prep, backend, x, (dy,), contexts...)
6161
only(tx)
@@ -64,41 +64,41 @@ function DI.value_and_pullback(
6464
end
6565

6666
function DI.value_and_pullback!(
67-
f!,
67+
f!::F,
6868
y,
6969
tx::NTuple,
7070
prep::MooncakeTwoArgPullbackPrep,
7171
backend::AutoMooncake,
7272
x,
7373
ty::NTuple,
7474
contexts::Vararg{DI.Context,C},
75-
) where {C}
75+
) where {F,C}
7676
_, new_tx = DI.value_and_pullback(f!, y, prep, backend, x, ty, contexts...)
7777
foreach(copyto!, tx, new_tx)
7878
return y, tx
7979
end
8080

8181
function DI.pullback(
82-
f!,
82+
f!::F,
8383
y,
8484
prep::MooncakeTwoArgPullbackPrep,
8585
backend::AutoMooncake,
8686
x,
8787
ty::NTuple,
8888
contexts::Vararg{DI.Context,C},
89-
) where {C}
89+
) where {F,C}
9090
return DI.value_and_pullback(f!, y, prep, backend, x, ty, contexts...)[2]
9191
end
9292

9393
function DI.pullback!(
94-
f!,
94+
f!::F,
9595
y,
9696
tx::NTuple,
9797
prep::MooncakeTwoArgPullbackPrep,
9898
backend::AutoMooncake,
9999
x,
100100
ty::NTuple,
101101
contexts::Vararg{DI.Context,C},
102-
) where {C}
102+
) where {F,C}
103103
return DI.value_and_pullback!(f!, y, tx, prep, backend, x, ty, contexts...)[2]
104104
end

0 commit comments

Comments
 (0)