Skip to content

feat: backend switching for Mooncake #768

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 31 commits into from
Jun 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
1a389a6
Handles backend switching for Mooncake using ChainRules
AstitvaAggarwal Apr 1, 2025
08b176a
Mooncake Wrapper for substitute backends
AstitvaAggarwal Apr 2, 2025
ba0c9e6
Merge branch 'JuliaDiff:main' into develop
AstitvaAggarwal Apr 10, 2025
1340d92
added rules
AstitvaAggarwal Apr 10, 2025
2ce1ee2
Merge branch 'develop' of https://github.com/AstitvaAggarwal/Differen…
AstitvaAggarwal Apr 10, 2025
08de6df
config
AstitvaAggarwal Apr 10, 2025
84f27c9
splatting for dy
AstitvaAggarwal Apr 10, 2025
2e95299
brackets
AstitvaAggarwal Apr 10, 2025
13233e5
too easy
AstitvaAggarwal Apr 11, 2025
1e8df98
changes from reviews, Docs
AstitvaAggarwal Apr 12, 2025
afdddd4
changes from reviews - 2
AstitvaAggarwal Apr 18, 2025
233c312
Merge branch 'JuliaDiff:main' into develop
AstitvaAggarwal Apr 18, 2025
7a07127
changes from reviews-1
AstitvaAggarwal May 16, 2025
f3e436d
conflicts
AstitvaAggarwal May 16, 2025
6a0d937
conflicts-2
AstitvaAggarwal May 16, 2025
e543958
Update differentiate_with.jl
AstitvaAggarwal May 16, 2025
2472ecc
Merge branch 'JuliaDiff:main' into develop
AstitvaAggarwal May 16, 2025
c63c956
typecheck for array rule.
AstitvaAggarwal May 18, 2025
36da036
assertion for array inputs
AstitvaAggarwal May 18, 2025
d2b5a8c
Merge branch 'JuliaDiff:main' into develop
AstitvaAggarwal May 29, 2025
c389a80
extensive tests, diffwith for tuples
AstitvaAggarwal May 29, 2025
b4fe0f8
tests.
AstitvaAggarwal May 29, 2025
ec4b75d
tests, inc primal handling
AstitvaAggarwal May 31, 2025
0f0b9fc
changes from reviews
AstitvaAggarwal Jun 6, 2025
3c5f99e
Merge branch 'main' into develop
yebai Jun 13, 2025
d94f146
Apply suggestions from code review
gdalle Jun 13, 2025
c982f46
Simplify Mooncake rule tests, add ChainRules rule tests
gdalle Jun 13, 2025
749fea5
Format
gdalle Jun 13, 2025
9e5ecfd
Update differentiate_with.jl
gdalle Jun 14, 2025
1e85f17
Restrict to array of numbers
gdalle Jun 14, 2025
ff5c4e2
Update DifferentiationInterface/ext/DifferentiationInterfaceMooncakeE…
gdalle Jun 18, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ In general, using a forward outer backend over a reverse inner backend will yiel
The wrapper [`DifferentiateWith`](@ref) allows you to switch between backends.
It takes a function `f` and specifies that `f` should be differentiated with the substitute backend of your choice, instead of whatever true backend the surrounding code is trying to use.
In other words, when someone tries to differentiate `dw = DifferentiateWith(f, substitute_backend)` with `true_backend`, then `substitute_backend` steps in and `true_backend` does not dive into the function `f` itself.
At the moment, `DifferentiateWith` only works when `true_backend` is either [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) or a [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl)-compatible backend.
At the moment, `DifferentiateWith` only works when `true_backend` is either [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl), [Mooncake.jl](https://github.com/chalk-lab/Mooncake.jl), or a [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl)-compatible backend (e.g., [Zygote.jl](https://github.com/FluxML/Zygote.jl)).

## Implementations

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,4 +111,5 @@ There are, however, translation utilities:
### Backend switch

Also note the existence of [`DifferentiationInterface.DifferentiateWith`](@ref), which allows the user to wrap a function that should be differentiated with a specific backend.
Right now it only targets ForwardDiff.jl and ChainRulesCore.jl, but PRs are welcome to define Enzyme.jl and Mooncake.jl rules for this object.

Right now, it only targets [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl), [Mooncake.jl](), [ChainRules.jl](https://juliadiff.org/ChainRulesCore.jl/stable/)-compatible backends (e.g., [Zygote.jl](https://github.com/FluxML/Zygote.jl)), but PRs are welcome to define Enzyme.jl rules for this object.
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
function ChainRulesCore.rrule(dw::DI.DifferentiateWith, x)
(; f, backend) = dw
y = f(x)
prep_same = DI.prepare_pullback_same_point_nokwarg(Val(true), f, backend, x, (y,))
prep_same = DI.prepare_pullback_same_point_nokwarg(Val(false), f, backend, x, (y,))
function pullbackfunc(dy)
tx = DI.pullback(f, prep_same, backend, x, (dy,))
return (NoTangent(), only(tx))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module DifferentiationInterfaceMooncakeExt
using ADTypes: ADTypes, AutoMooncake
import DifferentiationInterface as DI
using Mooncake:
Mooncake,
CoDual,
Config,
prepare_gradient_cache,
Expand All @@ -11,6 +12,16 @@ using Mooncake:
value_and_gradient!!,
value_and_pullback!!,
zero_tangent,
rdata_type,
fdata,
rdata,
tangent_type,
NoTangent,
@is_primitive,
zero_fcodual,
MinimalCtx,
NoRData,
primal,
_copy_output,
_copy_to_output!!

Expand All @@ -25,5 +36,6 @@ mycopy(x) = deepcopy(x)

include("onearg.jl")
include("twoarg.jl")
include("differentiate_with.jl")

end
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
@is_primitive MinimalCtx Tuple{DI.DifferentiateWith,<:Any}

struct MooncakeDifferentiateWithError <: Exception
F::Type
X::Type
Y::Type
function MooncakeDifferentiateWithError(::F, ::X, ::Y) where {F,X,Y}
return new(F, X, Y)
end
end

function Base.showerror(io::IO, e::MooncakeDifferentiateWithError)
return print(
io,
"MooncakeDifferentiateWithError: For the function type $(e.F) and input type $(e.X), the output type $(e.Y) is currently not supported.",
)
end

function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Number})
primal_func = primal(dw)
primal_x = primal(x)
(; f, backend) = primal_func
y = zero_fcodual(f(primal_x))

# output is a vector, so we need to use the vector pullback
function pullback_array!!(dy::NoRData)
tx = DI.pullback(f, backend, primal_x, (y.dx,))
@assert rdata(only(tx)) isa rdata_type(tangent_type(typeof(primal_x)))
return NoRData(), rdata(only(tx))
end

# output is a scalar, so we can use the scalar pullback
function pullback_scalar!!(dy::Number)
tx = DI.pullback(f, backend, primal_x, (dy,))
@assert rdata(only(tx)) isa rdata_type(tangent_type(typeof(primal_x)))
return NoRData(), rdata(only(tx))
end

pullback = if primal(y) isa Number
pullback_scalar!!
elseif primal(y) isa AbstractArray
pullback_array!!
else
throw(MooncakeDifferentiateWithError(primal_func, primal_x, primal(y)))
end

return y, pullback
end

function Mooncake.rrule!!(
dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:AbstractArray{<:Number}}
)
primal_func = primal(dw)
primal_x = primal(x)
fdata_arg = x.dx
(; f, backend) = primal_func
y = zero_fcodual(f(primal_x))

# output is a vector, so we need to use the vector pullback
function pullback_array!!(dy::NoRData)
tx = DI.pullback(f, backend, primal_x, (y.dx,))
@assert rdata(first(only(tx))) isa rdata_type(tangent_type(typeof(first(primal_x))))
fdata_arg .+= only(tx)
return NoRData(), dy
end

# output is a scalar, so we can use the scalar pullback
function pullback_scalar!!(dy::Number)
tx = DI.pullback(f, backend, primal_x, (dy,))
@assert rdata(first(only(tx))) isa rdata_type(tangent_type(typeof(first(primal_x))))
fdata_arg .+= only(tx)
return NoRData(), NoRData()
end

pullback = if primal(y) isa Number
pullback_scalar!!
elseif primal(y) isa AbstractArray
pullback_array!!
else
throw(MooncakeDifferentiateWithError(primal_func, primal_x, primal(y)))
end

return y, pullback
end
6 changes: 5 additions & 1 deletion DifferentiationInterface/src/misc/differentiate_with.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,13 @@ Moreover, any larger algorithm `alg` that calls `f2` instead of `f` will also be

!!! warning
`DifferentiateWith` only supports out-of-place functions `y = f(x)` without additional context arguments.
It only makes these functions differentiable if the true backend is either [ForwardDiff](https://github.com/JuliaDiff/ForwardDiff.jl) or automatically importing rules from [ChainRules](https://github.com/JuliaDiff/ChainRules.jl) (e.g. [Zygote](https://github.com/FluxML/Zygote.jl)). Some backends are also able to [manually import rules](https://juliadiff.org/ChainRulesCore.jl/stable/#Packages-supporting-importing-rules-from-ChainRules.) from ChainRules.
It only makes these functions differentiable if the true backend is either [ForwardDiff](https://github.com/JuliaDiff/ForwardDiff.jl), [Mooncake](https://github.com/chalk-lab/Mooncake.jl) or automatically importing rules from [ChainRules](https://github.com/JuliaDiff/ChainRules.jl) (e.g. [Zygote](https://github.com/FluxML/Zygote.jl)). Some backends are also able to [manually import rules](https://juliadiff.org/ChainRulesCore.jl/stable/#Packages-supporting-importing-rules-from-ChainRules.) from ChainRules.
For any other true backend, the differentiation behavior is not altered by `DifferentiateWith` (it becomes a transparent wrapper).

!!! warning
When using `DifferentiateWith(f, AutoSomething())`, the function `f` must not close over any active data.
As of now, we cannot differentiate with respect to parameters stored inside `f`.

# Fields

- `f`: the function in question, with signature `f(x)`
Expand Down
89 changes: 83 additions & 6 deletions DifferentiationInterface/test/Back/DifferentiateWith/test.jl
Original file line number Diff line number Diff line change
@@ -1,29 +1,106 @@
using Pkg
Pkg.add(["FiniteDiff", "ForwardDiff", "Zygote"])
Pkg.add(["ChainRulesTestUtils", "FiniteDiff", "ForwardDiff", "Zygote", "Mooncake"])

using ChainRulesTestUtils: ChainRulesTestUtils
using DifferentiationInterface, DifferentiationInterfaceTest
import DifferentiationInterfaceTest as DIT
using FiniteDiff: FiniteDiff
using ForwardDiff: ForwardDiff
using Zygote: Zygote
using Mooncake: Mooncake
using StableRNGs
using Test

LOGGING = get(ENV, "CI", "false") == "false"

struct ADBreaker{F}
f::F
end

function (adb::ADBreaker)(x::Number)
copyto!(Float64[0], x) # break ForwardDiff and Zygote
return adb.f(x)
end

function (adb::ADBreaker)(x::AbstractArray)
copyto!(similar(x, Float64), x) # break ForwardDiff and Zygote
return adb.f(x)
end

function differentiatewith_scenarios()
bad_scens = # these closurified scenarios have mutation and type constraints
filter(default_scenarios(; include_normal=false, include_closurified=true)) do scen
DIT.function_place(scen) == :out
end
outofplace_scens = filter(DIT.default_scenarios()) do scen
DIT.function_place(scen) == :out
end
# with bad_scens, everything would break
bad_scens = map(outofplace_scens) do scen
DIT.change_function(scen, ADBreaker(scen.f))
end
# with good_scens, everything is fixed
good_scens = map(bad_scens) do scen
DIT.change_function(scen, DifferentiateWith(scen.f, AutoFiniteDiff()))
end
return good_scens
end

test_differentiation(
[AutoForwardDiff(), AutoZygote()],
[AutoForwardDiff(), AutoZygote(), AutoMooncake(; config=nothing)],
differentiatewith_scenarios();
excluded=SECOND_ORDER,
logging=LOGGING,
testset_name="DI tests",
)

@testset "ChainRules tests" begin
@testset for scen in filter(differentiatewith_scenarios()) do scen
DIT.operator(scen) == :pullback
end
ChainRulesTestUtils.test_rrule(scen.f, scen.x; rtol=1e-4)
end
end;

@testset "Mooncake tests" begin
@testset for scen in filter(differentiatewith_scenarios()) do scen
DIT.operator(scen) == :pullback
end
Mooncake.TestUtils.test_rule(StableRNG(0), scen.f, scen.x; is_primitive=true)
end
end;

@testset "Mooncake errors" begin
MooncakeDifferentiateWithError =
Base.get_extension(DifferentiationInterface, :DifferentiationInterfaceMooncakeExt).MooncakeDifferentiateWithError

e = MooncakeDifferentiateWithError(identity, 1.0, 2.0)
@test sprint(showerror, e) ==
"MooncakeDifferentiateWithError: For the function type typeof(identity) and input type Float64, the output type Float64 is currently not supported."

f_num2tup(x::Number) = (x,)
f_vec2tup(x::Vector) = (first(x),)
f_tup2num(x::Tuple{<:Number}) = only(x)
f_tup2vec(x::Tuple{<:Number}) = [only(x)]

@test_throws MooncakeDifferentiateWithError pullback(
DifferentiateWith(f_num2tup, AutoFiniteDiff()),
AutoMooncake(; config=nothing),
1.0,
((2.0,),),
)
@test_throws MooncakeDifferentiateWithError pullback(
DifferentiateWith(f_vec2tup, AutoFiniteDiff()),
AutoMooncake(; config=nothing),
[1.0],
((2.0,),),
)
@test_throws MethodError pullback(
DifferentiateWith(f_tup2num, AutoFiniteDiff()),
AutoMooncake(; config=nothing),
(1.0,),
(2.0,),
)
@test_throws MethodError pullback(
DifferentiateWith(f_tup2vec, AutoFiniteDiff()),
AutoMooncake(; config=nothing),
(1.0,),
([2.0],),
)
end