Commit 8c234e6

AstitvaAggarwal, yebai, and gdalle authored
feat: backend switching for Mooncake (#768)
* Handles backend switching for Mooncake using ChainRules
* Mooncake Wrapper for substitute backends
* added rules
* config
* splatting for dy
* brackets
* too easy
* changes from reviews, Docs
* changes from reviews - 2
* changes from reviews-1
* conflicts
* conflicts-2
* Update differentiate_with.jl
* typecheck for array rule.
* assertion for array inputs
* extensive tests, diffwith for tuples
* tests.
* tests, inc primal handling
* changes from reviews
* Apply suggestions from code review
* Simplify Mooncake rule tests, add ChainRules rule tests
* Format
* Update differentiate_with.jl
* Restrict to array of numbers
* Update DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl

---------

Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com>
Co-authored-by: Guillaume Dalle <22795598+gdalle@users.noreply.github.com>
1 parent eeb4c86 commit 8c234e6

7 files changed, +188 -10 lines changed

DifferentiationInterface/docs/src/explanation/backends.md

Lines changed: 1 addition & 1 deletion
@@ -95,7 +95,7 @@ In general, using a forward outer backend over a reverse inner backend will yiel
 The wrapper [`DifferentiateWith`](@ref) allows you to switch between backends.
 It takes a function `f` and specifies that `f` should be differentiated with the substitute backend of your choice, instead of whatever true backend the surrounding code is trying to use.
 In other words, when someone tries to differentiate `dw = DifferentiateWith(f, substitute_backend)` with `true_backend`, then `substitute_backend` steps in and `true_backend` does not dive into the function `f` itself.
-At the moment, `DifferentiateWith` only works when `true_backend` is either [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) or a [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl)-compatible backend.
+At the moment, `DifferentiateWith` only works when `true_backend` is either [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl), [Mooncake.jl](https://github.com/chalk-lab/Mooncake.jl), or a [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl)-compatible backend (e.g., [Zygote.jl](https://github.com/FluxML/Zygote.jl)).

 ## Implementations

DifferentiationInterface/docs/src/faq/differentiability.md

Lines changed: 2 additions & 1 deletion
@@ -111,4 +111,5 @@ There are, however, translation utilities:
 ### Backend switch

 Also note the existence of [`DifferentiationInterface.DifferentiateWith`](@ref), which allows the user to wrap a function that should be differentiated with a specific backend.
-Right now it only targets ForwardDiff.jl and ChainRulesCore.jl, but PRs are welcome to define Enzyme.jl and Mooncake.jl rules for this object.
+
+Right now, it only targets [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl), [Mooncake.jl](https://github.com/chalk-lab/Mooncake.jl), and [ChainRules.jl](https://juliadiff.org/ChainRulesCore.jl/stable/)-compatible backends (e.g., [Zygote.jl](https://github.com/FluxML/Zygote.jl)), but PRs are welcome to define Enzyme.jl rules for this object.
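
As a rough illustration of the backend switch described in this FAQ entry, the sketch below wraps a function in `DifferentiateWith` so that FiniteDiff handles it while Mooncake differentiates the surrounding code. The names `inner` and `outer` and the input values are hypothetical, and the sketch assumes that DifferentiationInterface, FiniteDiff, and Mooncake are installed.

```julia
using DifferentiationInterface
import FiniteDiff, Mooncake

# Hypothetical function that we prefer to hand to a substitute backend.
inner(x::AbstractVector) = sum(abs2, x)

# FiniteDiff steps in whenever the true backend reaches `dw`.
const dw = DifferentiateWith(inner, AutoFiniteDiff())

# Surrounding code, differentiated by the true backend (Mooncake here).
outer(x) = 2 * dw(x)

x = [1.0, 2.0, 3.0]
grad = gradient(outer, AutoMooncake(; config=nothing), x)
# `grad` should be close to 4 .* x, up to finite-difference error.
```

Mooncake treats the call `dw(x)` as a primitive here, so it should never trace into `inner` itself.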

DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/differentiate_with.jl

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 function ChainRulesCore.rrule(dw::DI.DifferentiateWith, x)
     (; f, backend) = dw
     y = f(x)
-    prep_same = DI.prepare_pullback_same_point_nokwarg(Val(true), f, backend, x, (y,))
+    prep_same = DI.prepare_pullback_same_point_nokwarg(Val(false), f, backend, x, (y,))
     function pullbackfunc(dy)
         tx = DI.pullback(f, prep_same, backend, x, (dy,))
         return (NoTangent(), only(tx))

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl

Lines changed: 12 additions & 0 deletions
@@ -3,6 +3,7 @@ module DifferentiationInterfaceMooncakeExt
 using ADTypes: ADTypes, AutoMooncake
 import DifferentiationInterface as DI
 using Mooncake:
+    Mooncake,
     CoDual,
     Config,
     prepare_gradient_cache,
@@ -11,6 +12,16 @@ using Mooncake:
     value_and_gradient!!,
     value_and_pullback!!,
     zero_tangent,
+    rdata_type,
+    fdata,
+    rdata,
+    tangent_type,
+    NoTangent,
+    @is_primitive,
+    zero_fcodual,
+    MinimalCtx,
+    NoRData,
+    primal,
     _copy_output,
     _copy_to_output!!

@@ -25,5 +36,6 @@ mycopy(x) = deepcopy(x)

 include("onearg.jl")
 include("twoarg.jl")
+include("differentiate_with.jl")

 end
Lines changed: 84 additions & 0 deletions
@@ -0,0 +1,84 @@
+@is_primitive MinimalCtx Tuple{DI.DifferentiateWith,<:Any}
+
+struct MooncakeDifferentiateWithError <: Exception
+    F::Type
+    X::Type
+    Y::Type
+    function MooncakeDifferentiateWithError(::F, ::X, ::Y) where {F,X,Y}
+        return new(F, X, Y)
+    end
+end
+
+function Base.showerror(io::IO, e::MooncakeDifferentiateWithError)
+    return print(
+        io,
+        "MooncakeDifferentiateWithError: For the function type $(e.F) and input type $(e.X), the output type $(e.Y) is currently not supported.",
+    )
+end
+
+function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Number})
+    primal_func = primal(dw)
+    primal_x = primal(x)
+    (; f, backend) = primal_func
+    y = zero_fcodual(f(primal_x))
+
+    # output is a vector, so we need to use the vector pullback
+    function pullback_array!!(dy::NoRData)
+        tx = DI.pullback(f, backend, primal_x, (y.dx,))
+        @assert rdata(only(tx)) isa rdata_type(tangent_type(typeof(primal_x)))
+        return NoRData(), rdata(only(tx))
+    end
+
+    # output is a scalar, so we can use the scalar pullback
+    function pullback_scalar!!(dy::Number)
+        tx = DI.pullback(f, backend, primal_x, (dy,))
+        @assert rdata(only(tx)) isa rdata_type(tangent_type(typeof(primal_x)))
+        return NoRData(), rdata(only(tx))
+    end
+
+    pullback = if primal(y) isa Number
+        pullback_scalar!!
+    elseif primal(y) isa AbstractArray
+        pullback_array!!
+    else
+        throw(MooncakeDifferentiateWithError(primal_func, primal_x, primal(y)))
+    end
+
+    return y, pullback
+end
+
+function Mooncake.rrule!!(
+    dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:AbstractArray{<:Number}}
+)
+    primal_func = primal(dw)
+    primal_x = primal(x)
+    fdata_arg = x.dx
+    (; f, backend) = primal_func
+    y = zero_fcodual(f(primal_x))
+
+    # output is a vector, so we need to use the vector pullback
+    function pullback_array!!(dy::NoRData)
+        tx = DI.pullback(f, backend, primal_x, (y.dx,))
+        @assert rdata(first(only(tx))) isa rdata_type(tangent_type(typeof(first(primal_x))))
+        fdata_arg .+= only(tx)
+        return NoRData(), dy
+    end
+
+    # output is a scalar, so we can use the scalar pullback
+    function pullback_scalar!!(dy::Number)
+        tx = DI.pullback(f, backend, primal_x, (dy,))
+        @assert rdata(first(only(tx))) isa rdata_type(tangent_type(typeof(first(primal_x))))
+        fdata_arg .+= only(tx)
+        return NoRData(), NoRData()
+    end
+
+    pullback = if primal(y) isa Number
+        pullback_scalar!!
+    elseif primal(y) isa AbstractArray
+        pullback_array!!
+    else
+        throw(MooncakeDifferentiateWithError(primal_func, primal_x, primal(y)))
+    end
+
+    return y, pullback
+end
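
To see which branch of the rules above a given call might take, here is a small sketch with hypothetical functions `f_num2vec` and `f_num2tup`. It assumes the Mooncake extension is loaded through `import Mooncake` and uses FiniteDiff as the substitute backend. The tuple case mirrors the error test added in this commit.

```julia
using DifferentiationInterface
import FiniteDiff, Mooncake

backend = AutoMooncake(; config=nothing)

# Scalar input, array output: expected to go through the `pullback_array!!` branch.
f_num2vec(x::Number) = [x, 2x]
dw_vec = DifferentiateWith(f_num2vec, AutoFiniteDiff())
tx = pullback(dw_vec, backend, 1.0, ([1.0, 1.0],))
# `only(tx)` should be close to 1 * 1.0 + 2 * 1.0 = 3.0.

# Scalar input, tuple output: neither a Number nor an AbstractArray,
# so the rule is expected to throw MooncakeDifferentiateWithError.
f_num2tup(x::Number) = (x,)
dw_tup = DifferentiateWith(f_num2tup, AutoFiniteDiff())
threw = try
    pullback(dw_tup, backend, 1.0, ((2.0,),))
    false
catch
    true
end
```

For array inputs the rule accumulates the result into the input's tangent buffer (`fdata_arg .+= only(tx)`) instead of returning it, which is why both array-input pullbacks return `NoRData()` for `x`.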

DifferentiationInterface/src/misc/differentiate_with.jl

Lines changed: 5 additions & 1 deletion
@@ -13,9 +13,13 @@ Moreover, any larger algorithm `alg` that calls `f2` instead of `f` will also be

 !!! warning
     `DifferentiateWith` only supports out-of-place functions `y = f(x)` without additional context arguments.
-    It only makes these functions differentiable if the true backend is either [ForwardDiff](https://github.com/JuliaDiff/ForwardDiff.jl) or automatically importing rules from [ChainRules](https://github.com/JuliaDiff/ChainRules.jl) (e.g. [Zygote](https://github.com/FluxML/Zygote.jl)). Some backends are also able to [manually import rules](https://juliadiff.org/ChainRulesCore.jl/stable/#Packages-supporting-importing-rules-from-ChainRules.) from ChainRules.
+    It only makes these functions differentiable if the true backend is either [ForwardDiff](https://github.com/JuliaDiff/ForwardDiff.jl), [Mooncake](https://github.com/chalk-lab/Mooncake.jl) or automatically importing rules from [ChainRules](https://github.com/JuliaDiff/ChainRules.jl) (e.g. [Zygote](https://github.com/FluxML/Zygote.jl)). Some backends are also able to [manually import rules](https://juliadiff.org/ChainRulesCore.jl/stable/#Packages-supporting-importing-rules-from-ChainRules.) from ChainRules.
     For any other true backend, the differentiation behavior is not altered by `DifferentiateWith` (it becomes a transparent wrapper).

+!!! warning
+    When using `DifferentiateWith(f, AutoSomething())`, the function `f` must not close over any active data.
+    As of now, we cannot differentiate with respect to parameters stored inside `f`.
+
 # Fields

 - `f`: the function in question, with signature `f(x)`
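
The second warning suggests keeping every quantity you want derivatives for inside the input `x`. One possible workaround, sketched here as an assumption and not taken from the source, is to pack parameters and data into a single vector. The function `model` and the layout of `p` are hypothetical.

```julia
using DifferentiationInterface
import FiniteDiff, ForwardDiff

# Hypothetical layout: p[1] is a parameter `a`, p[2:end] is the data.
# Nothing active is stored inside the function itself.
model(p::AbstractVector) = p[1] * sum(abs2, @view p[2:end])

dw = DifferentiateWith(model, AutoFiniteDiff())

p = [2.0, 1.0, 3.0]
grad = gradient(dw, AutoForwardDiff(), p)
# Analytically: d/da = 1^2 + 3^2 = 10 and d/dx = 2a .* x = [4, 12],
# so `grad` should be close to [10, 4, 12] up to finite-difference error.
```

Because `a` travels through `x`, the true backend sees it as an ordinary input and the substitute backend returns its derivative along with the others.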
Lines changed: 83 additions & 6 deletions
@@ -1,29 +1,106 @@
 using Pkg
-Pkg.add(["FiniteDiff", "ForwardDiff", "Zygote"])
+Pkg.add(["ChainRulesTestUtils", "FiniteDiff", "ForwardDiff", "Zygote", "Mooncake"])

+using ChainRulesTestUtils: ChainRulesTestUtils
 using DifferentiationInterface, DifferentiationInterfaceTest
 import DifferentiationInterfaceTest as DIT
 using FiniteDiff: FiniteDiff
 using ForwardDiff: ForwardDiff
 using Zygote: Zygote
+using Mooncake: Mooncake
+using StableRNGs
 using Test

 LOGGING = get(ENV, "CI", "false") == "false"

+struct ADBreaker{F}
+    f::F
+end
+
+function (adb::ADBreaker)(x::Number)
+    copyto!(Float64[0], x)  # break ForwardDiff and Zygote
+    return adb.f(x)
+end
+
+function (adb::ADBreaker)(x::AbstractArray)
+    copyto!(similar(x, Float64), x)  # break ForwardDiff and Zygote
+    return adb.f(x)
+end
+
 function differentiatewith_scenarios()
-    bad_scens =  # these closurified scenarios have mutation and type constraints
-        filter(default_scenarios(; include_normal=false, include_closurified=true)) do scen
-            DIT.function_place(scen) == :out
-        end
+    outofplace_scens = filter(DIT.default_scenarios()) do scen
+        DIT.function_place(scen) == :out
+    end
+    # with bad_scens, everything would break
+    bad_scens = map(outofplace_scens) do scen
+        DIT.change_function(scen, ADBreaker(scen.f))
+    end
+    # with good_scens, everything is fixed
     good_scens = map(bad_scens) do scen
         DIT.change_function(scen, DifferentiateWith(scen.f, AutoFiniteDiff()))
     end
     return good_scens
 end

 test_differentiation(
-    [AutoForwardDiff(), AutoZygote()],
+    [AutoForwardDiff(), AutoZygote(), AutoMooncake(; config=nothing)],
     differentiatewith_scenarios();
     excluded=SECOND_ORDER,
     logging=LOGGING,
+    testset_name="DI tests",
 )
+
+@testset "ChainRules tests" begin
+    @testset for scen in filter(differentiatewith_scenarios()) do scen
+        DIT.operator(scen) == :pullback
+    end
+        ChainRulesTestUtils.test_rrule(scen.f, scen.x; rtol=1e-4)
+    end
+end;
+
+@testset "Mooncake tests" begin
+    @testset for scen in filter(differentiatewith_scenarios()) do scen
+        DIT.operator(scen) == :pullback
+    end
+        Mooncake.TestUtils.test_rule(StableRNG(0), scen.f, scen.x; is_primitive=true)
+    end
+end;
+
+@testset "Mooncake errors" begin
+    MooncakeDifferentiateWithError =
+        Base.get_extension(DifferentiationInterface, :DifferentiationInterfaceMooncakeExt).MooncakeDifferentiateWithError
+
+    e = MooncakeDifferentiateWithError(identity, 1.0, 2.0)
+    @test sprint(showerror, e) ==
+        "MooncakeDifferentiateWithError: For the function type typeof(identity) and input type Float64, the output type Float64 is currently not supported."
+
+    f_num2tup(x::Number) = (x,)
+    f_vec2tup(x::Vector) = (first(x),)
+    f_tup2num(x::Tuple{<:Number}) = only(x)
+    f_tup2vec(x::Tuple{<:Number}) = [only(x)]
+
+    @test_throws MooncakeDifferentiateWithError pullback(
+        DifferentiateWith(f_num2tup, AutoFiniteDiff()),
+        AutoMooncake(; config=nothing),
+        1.0,
+        ((2.0,),),
+    )
+    @test_throws MooncakeDifferentiateWithError pullback(
+        DifferentiateWith(f_vec2tup, AutoFiniteDiff()),
+        AutoMooncake(; config=nothing),
+        [1.0],
+        ((2.0,),),
+    )
+    @test_throws MethodError pullback(
+        DifferentiateWith(f_tup2num, AutoFiniteDiff()),
+        AutoMooncake(; config=nothing),
+        (1.0,),
+        (2.0,),
+    )
+    @test_throws MethodError pullback(
+        DifferentiateWith(f_tup2vec, AutoFiniteDiff()),
+        AutoMooncake(; config=nothing),
+        (1.0,),
+        ([2.0],),
+    )
+end
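
The `ADBreaker` wrapper in the test file relies on the fact that copying dual numbers into a `Float64` buffer fails. A minimal sketch of the same idea follows, with a hypothetical function `breaker` standing in for the test scenarios. It assumes ForwardDiff and FiniteDiff are installed.

```julia
using DifferentiationInterface
import FiniteDiff, ForwardDiff

function breaker(x::AbstractArray)
    copyto!(similar(x, Float64), x)  # fails when `x` holds ForwardDiff dual numbers
    return sum(abs2, x)
end

x = [1.0, 2.0]

# Differentiating `breaker` directly with ForwardDiff is expected to throw.
direct_failed = try
    gradient(breaker, AutoForwardDiff(), x)
    false
catch
    true
end

# Wrapping it lets FiniteDiff take over, so ForwardDiff never enters `breaker`.
fixed = DifferentiateWith(breaker, AutoFiniteDiff())
grad = gradient(fixed, AutoForwardDiff(), x)
# `grad` should be close to 2 .* x, up to finite-difference error.
```

This is the pattern the test scenarios check at scale: a function that breaks the true backend becomes differentiable again once the substitute backend handles it.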
