-
Notifications
You must be signed in to change notification settings - Fork 25
feat: backend switching for Mooncake #768
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
31 commits
Select commit
Hold shift + click to select a range
1a389a6
Handles backend switching for Mooncake using ChainRules
AstitvaAggarwal 08b176a
Mooncake Wrapper for substitute backends
AstitvaAggarwal ba0c9e6
Merge branch 'JuliaDiff:main' into develop
AstitvaAggarwal 1340d92
added rules
AstitvaAggarwal 2ce1ee2
Merge branch 'develop' of https://github.com/AstitvaAggarwal/Differen…
AstitvaAggarwal 08de6df
config
AstitvaAggarwal 84f27c9
splatting for dy
AstitvaAggarwal 2e95299
brackets
AstitvaAggarwal 13233e5
too easy
AstitvaAggarwal 1e8df98
changes from reviews, Docs
AstitvaAggarwal afdddd4
changes from reviews - 2
AstitvaAggarwal 233c312
Merge branch 'JuliaDiff:main' into develop
AstitvaAggarwal 7a07127
changes from reviews-1
AstitvaAggarwal f3e436d
conflicts
AstitvaAggarwal 6a0d937
conflicts-2
AstitvaAggarwal e543958
Update differentiate_with.jl
AstitvaAggarwal 2472ecc
Merge branch 'JuliaDiff:main' into develop
AstitvaAggarwal c63c956
typecheck for array rule.
AstitvaAggarwal 36da036
assertion for array inputs
AstitvaAggarwal d2b5a8c
Merge branch 'JuliaDiff:main' into develop
AstitvaAggarwal c389a80
extensive tests, diffwith for tuples
AstitvaAggarwal b4fe0f8
tests.
AstitvaAggarwal ec4b75d
tests, inc primal handling
AstitvaAggarwal 0f0b9fc
changes from reviews
AstitvaAggarwal 3c5f99e
Merge branch 'main' into develop
yebai d94f146
Apply suggestions from code review
gdalle c982f46
Simplify Mooncake rule tests, add ChainRules rule tests
gdalle 749fea5
Format
gdalle 9e5ecfd
Update differentiate_with.jl
gdalle 1e85f17
Restrict to array of numbers
gdalle ff5c4e2
Update DifferentiationInterface/ext/DifferentiationInterfaceMooncakeE…
gdalle File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
2 changes: 1 addition & 1 deletion
2
DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/differentiate_with.jl
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
84 changes: 84 additions & 0 deletions
84
DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
# Declare every call of a `DI.DifferentiateWith` object on one argument as a Mooncake
# primitive in `MinimalCtx` — presumably so that Mooncake dispatches to the `rrule!!`
# methods defined in this file instead of differentiating through the wrapped function.
@is_primitive MinimalCtx Tuple{DI.DifferentiateWith,<:Any} | ||
|
||
"""
    MooncakeDifferentiateWithError(f, x, y)

Exception raised by the Mooncake rules for `DI.DifferentiateWith` when the output `y`
of `f(x)` is neither a `Number` nor an `AbstractArray`.

Only the *types* of the three arguments are stored, in the fields `F`, `X` and `Y`.
"""
struct MooncakeDifferentiateWithError <: Exception
    F::Type
    X::Type
    Y::Type
    # Capture the argument types through the method's static parameters; the values
    # themselves are discarded.
    MooncakeDifferentiateWithError(::F, ::X, ::Y) where {F,X,Y} = new(F, X, Y)
end
|
||
# Print a one-line description of the unsupported (function, input, output) type
# combination recorded in `e`.
function Base.showerror(io::IO, e::MooncakeDifferentiateWithError)
    msg = "MooncakeDifferentiateWithError: For the function type $(e.F) and input type $(e.X), the output type $(e.Y) is currently not supported."
    return print(io, msg)
end
|
||
# Reverse rule for calling a `DI.DifferentiateWith` object on a number.
#
# The primal output is obtained by calling the wrapped `f` directly. The returned
# pullback delegates to `DI.pullback` with the substitute `backend` and yields
# `(NoRData(), rdata_for_x)`. Outputs that are neither a `Number` nor an
# `AbstractArray` raise a `MooncakeDifferentiateWithError`.
function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Number})
    dw_primal = primal(dw)
    primal_x = primal(x)
    (; f, backend) = dw_primal
    y = zero_fcodual(f(primal_x))
    primal_y = primal(y)

    # Shared reverse pass: pull `cotangent` back through `f` with the substitute
    # backend and check that the result has the rdata type expected for `primal_x`.
    function pull_input(cotangent)
        tx = DI.pullback(f, backend, primal_x, (cotangent,))
        @assert rdata(only(tx)) isa rdata_type(tangent_type(typeof(primal_x)))
        return NoRData(), rdata(only(tx))
    end

    # output is a scalar: the cotangent is the pullback argument itself
    pullback_scalar!!(dy::Number) = pull_input(dy)

    # output is an array: the pullback gets no rdata, the cotangent is read from `y.dx`
    pullback_array!!(dy::NoRData) = pull_input(y.dx)

    if primal_y isa Number
        return y, pullback_scalar!!
    elseif primal_y isa AbstractArray
        return y, pullback_array!!
    end
    throw(MooncakeDifferentiateWithError(dw_primal, primal_x, primal_y))
end
|
||
# Reverse rule for calling a `DI.DifferentiateWith` object on an array of numbers.
#
# The primal output is obtained by calling the wrapped `f` directly. The returned
# pullback delegates to `DI.pullback` with the substitute `backend` and adds the
# result in place to `x.dx`; it returns no rdata for either argument. Outputs that
# are neither a `Number` nor an `AbstractArray` raise a
# `MooncakeDifferentiateWithError`.
function Mooncake.rrule!!(
    dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:AbstractArray{<:Number}}
)
    dw_primal = primal(dw)
    primal_x = primal(x)
    fdata_arg = x.dx
    (; f, backend) = dw_primal
    y = zero_fcodual(f(primal_x))
    primal_y = primal(y)

    # Shared reverse pass: pull `cotangent` back through `f` with the substitute
    # backend, check the element rdata type, then accumulate into `fdata_arg`.
    function accumulate_input!(cotangent)
        tx = DI.pullback(f, backend, primal_x, (cotangent,))
        @assert rdata(first(only(tx))) isa rdata_type(tangent_type(typeof(first(primal_x))))
        fdata_arg .+= only(tx)
        return nothing
    end

    # output is a scalar: the cotangent is the pullback argument itself
    function pullback_scalar!!(dy::Number)
        accumulate_input!(dy)
        return NoRData(), NoRData()
    end

    # output is an array: the pullback gets no rdata, the cotangent is read from `y.dx`
    function pullback_array!!(dy::NoRData)
        accumulate_input!(y.dx)
        return NoRData(), dy
    end

    if primal_y isa Number
        return y, pullback_scalar!!
    elseif primal_y isa AbstractArray
        return y, pullback_array!!
    end
    throw(MooncakeDifferentiateWithError(dw_primal, primal_x, primal_y))
end
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
89 changes: 83 additions & 6 deletions
89
DifferentiationInterface/test/Back/DifferentiateWith/test.jl
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,29 +1,106 @@ | ||
using Pkg | ||
Pkg.add(["FiniteDiff", "ForwardDiff", "Zygote"]) | ||
Pkg.add(["ChainRulesTestUtils", "FiniteDiff", "ForwardDiff", "Zygote", "Mooncake"]) | ||
|
||
using ChainRulesTestUtils: ChainRulesTestUtils | ||
using DifferentiationInterface, DifferentiationInterfaceTest | ||
import DifferentiationInterfaceTest as DIT | ||
using FiniteDiff: FiniteDiff | ||
using ForwardDiff: ForwardDiff | ||
using Zygote: Zygote | ||
using Mooncake: Mooncake | ||
AstitvaAggarwal marked this conversation as resolved.
Show resolved
Hide resolved
|
||
using StableRNGs | ||
using Test | ||
|
||
# `true` when the `CI` environment variable is unset or equal to "false"; forwarded
# as the `logging` keyword of `test_differentiation` below.
LOGGING = get(ENV, "CI", "false") == "false" | ||
|
||
"""
    ADBreaker(f)

Callable wrapper around `f`. Before forwarding to `f`, each call copies its input
into a `Float64` buffer, which is meant to break ForwardDiff and Zygote.
The result of the call is `f(x)`.
"""
struct ADBreaker{F}
    f::F
end

function (breaker::ADBreaker)(x::Number)
    sink = Float64[0]
    copyto!(sink, x)  # break ForwardDiff and Zygote
    return breaker.f(x)
end

function (breaker::ADBreaker)(x::AbstractArray)
    sink = similar(x, Float64)
    copyto!(sink, x)  # break ForwardDiff and Zygote
    return breaker.f(x)
end
|
||
# Build the scenarios used to test `DifferentiateWith`.
#
# Starting from the out-of-place default scenarios of DifferentiationInterfaceTest,
# each function is first wrapped in `ADBreaker` and then in
# `DifferentiateWith(..., AutoFiniteDiff())`. Returns the resulting collection of
# scenarios.
function differentiatewith_scenarios()
    outofplace_scens = filter(DIT.default_scenarios()) do scen
        DIT.function_place(scen) == :out
    end
    # with bad_scens, everything would break
    bad_scens = map(outofplace_scens) do scen
        DIT.change_function(scen, ADBreaker(scen.f))
    end
    # with good_scens, everything is fixed
    good_scens = map(bad_scens) do scen
        DIT.change_function(scen, DifferentiateWith(scen.f, AutoFiniteDiff()))
    end
    return good_scens
end
|
||
# Run the DifferentiationInterfaceTest suite on the `DifferentiateWith` scenarios with
# each outer backend. A single backend vector is passed, so the scenarios stay in the
# second positional slot; operators in `SECOND_ORDER` are excluded.
test_differentiation(
    [AutoForwardDiff(), AutoZygote(), AutoMooncake(; config=nothing)],
    differentiatewith_scenarios();
    excluded=SECOND_ORDER,
    logging=LOGGING,
    testset_name="DI tests",
)
|
||
# Check the ChainRules `rrule` of every pullback scenario with ChainRulesTestUtils.
@testset "ChainRules tests" begin
    pullback_scens = filter(
        scen -> DIT.operator(scen) == :pullback, differentiatewith_scenarios()
    )
    @testset for scen in pullback_scens
        ChainRulesTestUtils.test_rrule(scen.f, scen.x; rtol=1e-4)
    end
end;
|
||
# Check the Mooncake rule of every pullback scenario with Mooncake's own test utility,
# using a fixed-seed RNG.
@testset "Mooncake tests" begin
    pullback_scens = filter(
        scen -> DIT.operator(scen) == :pullback, differentiatewith_scenarios()
    )
    @testset for scen in pullback_scens
        Mooncake.TestUtils.test_rule(StableRNG(0), scen.f, scen.x; is_primitive=true)
    end
end;
|
||
# Error paths of the Mooncake rules: the custom error message, unsupported tuple
# outputs (custom error) and unsupported tuple inputs (`MethodError`).
@testset "Mooncake errors" begin
    mooncake_ext = Base.get_extension(
        DifferentiationInterface, :DifferentiationInterfaceMooncakeExt
    )
    MooncakeDifferentiateWithError = mooncake_ext.MooncakeDifferentiateWithError

    e = MooncakeDifferentiateWithError(identity, 1.0, 2.0)
    @test sprint(showerror, e) ==
        "MooncakeDifferentiateWithError: For the function type typeof(identity) and input type Float64, the output type Float64 is currently not supported."

    f_num2tup(x::Number) = (x,)
    f_vec2tup(x::Vector) = (first(x),)
    f_tup2num(x::Tuple{<:Number}) = only(x)
    f_tup2vec(x::Tuple{<:Number}) = [only(x)]

    inner_backend = AutoFiniteDiff()
    outer_backend = AutoMooncake(; config=nothing)

    # tuple outputs are rejected by the rule itself
    @test_throws MooncakeDifferentiateWithError pullback(
        DifferentiateWith(f_num2tup, inner_backend), outer_backend, 1.0, ((2.0,),)
    )
    @test_throws MooncakeDifferentiateWithError pullback(
        DifferentiateWith(f_vec2tup, inner_backend), outer_backend, [1.0], ((2.0,),)
    )

    # tuple inputs are expected to fail with a `MethodError`
    @test_throws MethodError pullback(
        DifferentiateWith(f_tup2num, inner_backend), outer_backend, (1.0,), (2.0,)
    )
    @test_throws MethodError pullback(
        DifferentiateWith(f_tup2vec, inner_backend), outer_backend, (1.0,), ([2.0],)
    )
end
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.