Skip to content

fix: replace copy with deepcopy in Mooncake pullbacks #723

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Feb 9, 2025
Merged
2 changes: 1 addition & 1 deletion DifferentiationInterface/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ ForwardDiff = "0.10.36"
GTPSA = "1.4.0"
JuliaFormatter = "1"
LinearAlgebra = "<0.0.1,1"
Mooncake = "0.4.52"
Mooncake = "0.4.83"
PolyesterForwardDiff = "0.1.2"
ReverseDiff = "1.15.1"
SparseArrays = "<0.0.1,1"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
tangent,
tangent_type,
value_and_pullback!!,
value_and_gradient!!,
zero_tangent,
prepare_pullback_cache,
Mooncake
Expand All @@ -21,6 +22,10 @@
get_config(::AutoMooncake{Nothing}) = Config()
get_config(backend::AutoMooncake{<:Config}) = backend.config

# tangents need to be copied before returning, otherwise they are still aliased in the cache
mycopy(x::Union{Number,AbstractArray{<:Number}}) = copy(x)
mycopy(x) = deepcopy(x)

Check warning on line 27 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl#L27

Added line #L27 was not covered by tests

include("onearg.jl")
include("twoarg.jl")

Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
## Pullback

struct MooncakeOneArgPullbackPrep{Tcache,DY} <: DI.PullbackPrep
cache::Tcache
dy_righttype::DY
end

function DI.prepare_pullback(
f, backend::AutoMooncake, x, ty::NTuple, contexts::Vararg{DI.Context,C}
) where {C}
f::F, backend::AutoMooncake, x, ty::NTuple, contexts::Vararg{DI.Context,C}
) where {F,C}
config = get_config(backend)
cache = prepare_pullback_cache(
f, x, map(DI.unwrap, contexts)...; config.debug_mode, config.silence_debug_messages
Expand All @@ -27,34 +29,34 @@ function DI.value_and_pullback(
) where {F,Y,C}
dy = only(ty)
dy_righttype = dy isa tangent_type(Y) ? dy : copyto!!(prep.dy_righttype, dy)
new_y, (_, new_dx) = Mooncake.value_and_pullback!!(
new_y, (_, new_dx) = value_and_pullback!!(
prep.cache, dy_righttype, f, x, map(DI.unwrap, contexts)...
)
return new_y, (copy(new_dx),)
return new_y, (mycopy(new_dx),)
end

function DI.value_and_pullback!(
f,
f::F,
tx::NTuple{1},
prep::MooncakeOneArgPullbackPrep{Y},
backend::AutoMooncake,
x,
ty::NTuple{1},
contexts::Vararg{DI.Context,C},
) where {Y,C}
) where {F,Y,C}
y, (new_dx,) = DI.value_and_pullback(f, prep, backend, x, ty, contexts...)
copyto!(only(tx), new_dx)
return y, tx
end

function DI.value_and_pullback(
f,
f::F,
prep::MooncakeOneArgPullbackPrep,
backend::AutoMooncake,
x,
ty::NTuple,
contexts::Vararg{DI.Context,C},
) where {C}
) where {F,C}
ys_and_tx = map(ty) do dy
y, tx = DI.value_and_pullback(f, prep, backend, x, (dy,), contexts...)
y, only(tx)
Expand All @@ -65,14 +67,14 @@ function DI.value_and_pullback(
end

function DI.value_and_pullback!(
f,
f::F,
tx::NTuple,
prep::MooncakeOneArgPullbackPrep,
backend::AutoMooncake,
x,
ty::NTuple,
contexts::Vararg{DI.Context,C},
) where {C}
) where {F,C}
ys = map(tx, ty) do dx, dy
y, _ = DI.value_and_pullback!(f, (dx,), prep, backend, x, (dy,), contexts...)
y
Expand All @@ -82,24 +84,85 @@ function DI.value_and_pullback!(
end

function DI.pullback(
f,
f::F,
prep::MooncakeOneArgPullbackPrep,
backend::AutoMooncake,
x,
ty::NTuple,
contexts::Vararg{DI.Context,C},
) where {C}
) where {F,C}
return DI.value_and_pullback(f, prep, backend, x, ty, contexts...)[2]
end

function DI.pullback!(
f,
f::F,
tx::NTuple,
prep::MooncakeOneArgPullbackPrep,
backend::AutoMooncake,
x,
ty::NTuple,
contexts::Vararg{DI.Context,C},
) where {C}
) where {F,C}
return DI.value_and_pullback!(f, tx, prep, backend, x, ty, contexts...)[2]
end

## Gradient

struct MooncakeGradientPrep{Tcache} <: DI.GradientPrep
cache::Tcache
end

function DI.prepare_gradient(
f::F, backend::AutoMooncake, x, contexts::Vararg{DI.Context,C}
) where {F,C}
config = get_config(backend)
cache = prepare_pullback_cache(
f, x, map(DI.unwrap, contexts)...; config.debug_mode, config.silence_debug_messages
)
prep = MooncakeGradientPrep(cache)
DI.value_and_gradient(f, prep, backend, x, contexts...)
return prep
end

function DI.value_and_gradient(
f::F, prep::MooncakeGradientPrep, ::AutoMooncake, x, contexts::Vararg{DI.Context,C}
) where {F,C}
y, (_, new_grad) = value_and_gradient!!(prep.cache, f, x, map(DI.unwrap, contexts)...)
return y, mycopy(new_grad)
end

function DI.value_and_gradient!(
f::F,
grad,
prep::MooncakeGradientPrep,
::AutoMooncake,
x,
contexts::Vararg{DI.Context,C},
) where {F,C}
y, (_, new_grad) = value_and_gradient!!(prep.cache, f, x, map(DI.unwrap, contexts)...)
copyto!(grad, new_grad)
return y, grad
end

function DI.gradient(
f::F,
prep::MooncakeGradientPrep,
backend::AutoMooncake,
x,
contexts::Vararg{DI.Context,C},
) where {F,C}
_, grad = DI.value_and_gradient(f, prep, backend, x, contexts...)
return grad
end

function DI.gradient!(
f::F,
grad,
prep::MooncakeGradientPrep,
backend::AutoMooncake,
x,
contexts::Vararg{DI.Context,C},
) where {F,C}
DI.value_and_gradient!(f, grad, prep, backend, x, contexts...)
return grad
end
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ struct MooncakeTwoArgPullbackPrep{Tcache,DY,F} <: DI.PullbackPrep
end

function DI.prepare_pullback(
f!, y, backend::AutoMooncake, x, ty::NTuple, contexts::Vararg{DI.Context,C}
) where {C}
f!::F, y, backend::AutoMooncake, x, ty::NTuple, contexts::Vararg{DI.Context,C}
) where {F,C}
target_function = function (f!, y, x, contexts...)
f!(y, x, contexts...)
return y
Expand All @@ -26,36 +26,36 @@ function DI.prepare_pullback(
end

function DI.value_and_pullback(
f!,
f!::F,
y,
prep::MooncakeTwoArgPullbackPrep,
::AutoMooncake,
x,
ty::NTuple{1},
contexts::Vararg{DI.Context,C},
) where {C}
) where {F,C}
# Prepare cotangent to add after the forward pass.
dy = only(ty)
dy_righttype_after = copyto!(prep.dy_righttype, dy)

# Run the reverse-pass and return the results.
contexts = map(DI.unwrap, contexts)
y_after, (_, _, _, dx) = Mooncake.value_and_pullback!!(
y_after, (_, _, _, dx) = value_and_pullback!!(
prep.cache, dy_righttype_after, prep.target_function, f!, y, x, contexts...
)
copyto!(y, y_after)
return y, (copy(dx),) # TODO: remove this allocation in `value_and_pullback!`
return y, (mycopy(dx),)
end

function DI.value_and_pullback(
f!,
f!::F,
y,
prep::MooncakeTwoArgPullbackPrep,
backend::AutoMooncake,
x,
ty::NTuple,
contexts::Vararg{DI.Context,C},
) where {C}
) where {F,C}
tx = map(ty) do dy
_, tx = DI.value_and_pullback(f!, y, prep, backend, x, (dy,), contexts...)
only(tx)
Expand All @@ -64,41 +64,41 @@ function DI.value_and_pullback(
end

function DI.value_and_pullback!(
f!,
f!::F,
y,
tx::NTuple,
prep::MooncakeTwoArgPullbackPrep,
backend::AutoMooncake,
x,
ty::NTuple,
contexts::Vararg{DI.Context,C},
) where {C}
) where {F,C}
_, new_tx = DI.value_and_pullback(f!, y, prep, backend, x, ty, contexts...)
foreach(copyto!, tx, new_tx)
return y, tx
end

function DI.pullback(
f!,
f!::F,
y,
prep::MooncakeTwoArgPullbackPrep,
backend::AutoMooncake,
x,
ty::NTuple,
contexts::Vararg{DI.Context,C},
) where {C}
) where {F,C}
return DI.value_and_pullback(f!, y, prep, backend, x, ty, contexts...)[2]
end

function DI.pullback!(
f!,
f!::F,
y,
tx::NTuple,
prep::MooncakeTwoArgPullbackPrep,
backend::AutoMooncake,
x,
ty::NTuple,
contexts::Vararg{DI.Context,C},
) where {C}
) where {F,C}
return DI.value_and_pullback!(f!, y, tx, prep, backend, x, ty, contexts...)[2]
end