-
Notifications
You must be signed in to change notification settings - Fork 25
feat: backend switching for Mooncake #768
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #768 +/- ##
==========================================
- Coverage 97.88% 97.87% -0.01%
==========================================
Files 128 129 +1
Lines 7694 7754 +60
==========================================
+ Hits 7531 7589 +58
- Misses 163 165 +2
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
removed the code that piggybacks off the Chainrules wrapper. This is specifically now a Mooncake generic rule which handles backend switching. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for this first draft!
I think there are some changes necessary, and most importantly you need to test it, first locally and then during CI (try not to run CI before having tested your changes locally, the process is very expensive since it tests a dozen different backends for like half an hour each).
For the testing, start with manual tests, and then once your code works you can add AutoMooncake()
to this line
...tionInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl
Outdated
Show resolved
Hide resolved
DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
Outdated
Show resolved
Hide resolved
...tionInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl
Outdated
Show resolved
Hide resolved
DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
Outdated
Show resolved
Hide resolved
DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
Outdated
Show resolved
Hide resolved
DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
Outdated
Show resolved
Hide resolved
DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
Outdated
Show resolved
Hide resolved
DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
Outdated
Show resolved
Hide resolved
DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
Outdated
Show resolved
Hide resolved
sorry i got preoccupied with some other work, hence the incomplete PR. This would be on route now. |
Please keep in mind that every commit costs around 6 hours of CI budget. I suggest you make as many modifications as possible locally and add tests first before pushing |
DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
Outdated
Show resolved
Hide resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, we're getting closer!
Unfortunately I think my existing tests are not enough to capture everything that can go wrong in a Mooncake rule. Perhaps the Mooncake test utilities should be brought in, or more sophisticated tests should be written.
DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
Show resolved
Hide resolved
DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
Show resolved
Hide resolved
DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
Show resolved
Hide resolved
DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
Outdated
Show resolved
Hide resolved
DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
Outdated
Show resolved
Hide resolved
DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
Outdated
Show resolved
Hide resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good work, @AstitvaAggarwal. We are getting there. I have just a few more quality-related comments on the documents, type stability, etc.
DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
Outdated
Show resolved
Hide resolved
DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
Outdated
Show resolved
Hide resolved
DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
Outdated
Show resolved
Hide resolved
DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
Outdated
Show resolved
Hide resolved
DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
Outdated
Show resolved
Hide resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @AstitvaAggarwal for the nice work!
EDIT: I am happy with this PR, but still need @gdalle's blessing before merging.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very good work @AstitvaAggarwal, thank you so much for taking the time! My main remaining questions are about:
- support for tuples: do we keep it, do we remove it? It seems a bit arbitrary to have just tuples in addition to the "officially" supported
Number
andAbstractArray
, especially if these can only be tuples of numbers - testing code, which should not touch on so many Mooncake internals
To lighten the load on you, I'll make the suggested modifications myself, and you can tell me what you think after! Thanks again for your patience.
# nested vectors (eg. [[1.0]]), Tuples (eg. ((1.0,),)) or similar (eg. [(1.0,)]) primal types are not supported by DI yet ! | ||
# This is because basis construction (DI.basis) does not have overloads for these types. | ||
# For details, refer commented out test cases to see where the pullback creation fails. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The issue is that we're testing DifferentiateWith(f, substitute_backend)
with substitute_backend = AutoFiniteDiff()
, aka a forward-mode backend. I think it should work with DifferentiateWith(f, AutoEnzyme())
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Resolved by removing tuples for the time being
@@ -0,0 +1,268 @@ | |||
@is_primitive MinimalCtx Tuple{DI.DifferentiateWith,<:Union{Number,AbstractArray,Tuple}} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I find it a bit weird to have this union of Number, AbstractArray
(two types which are theoretically supported for x
inputs in DI) and then just Tuple
(which is not officially part of the supported inputs). Why not also NamedTuple
for instance? Is it better if we just say Any
? Or restrict to Number
and AbstractArray
for the time being?
DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
Outdated
Show resolved
Hide resolved
DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
Outdated
Show resolved
Hide resolved
DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
Outdated
Show resolved
Hide resolved
DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
Outdated
Show resolved
Hide resolved
DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
Outdated
Show resolved
Hide resolved
DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
Outdated
Show resolved
Hide resolved
DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
Outdated
Show resolved
Hide resolved
…xt/differentiate_with.jl
Define
Mooncake.rrule!!
forDI.DifferentiateWith
.