Skip to content

Commit 6f10796

Browse files
committed
Correctly adapt to ABI changes to EnzymeCore
1 parent d1bee3f commit 6f10796

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ UnsafeAtomicsLLVM = "d80eeb9a-aca5-4d75-85e5-170c8b632249"
2020
[compat]
2121
Adapt = "0.4, 1.0, 2.0, 3.0, 4"
2222
Atomix = "0.1"
23-
EnzymeCore = "0.6.4, 0.7"
23+
EnzymeCore = "0.7"
2424
InteractiveUtils = "1.6"
2525
LinearAlgebra = "1.6"
2626
MacroTools = "0.5"

ext/EnzymeExt.jl

+4-2
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,15 @@ module EnzymeExt
1616
end
1717

1818
function aug_fwd(ctx, f::FT, ::Val{ModifiedBetween}, subtape, args...) where {ModifiedBetween, FT}
19-
forward, reverse = EnzymeCore.autodiff_deferred_thunk(ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), Const{Core.Typeof(f)}, Const, Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...)
19+
TapeType = EnzymeCore.tape_type(ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), Const{Core.Typeof(f)}, Const, Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...)
20+
forward, reverse = EnzymeCore.autodiff_deferred_thunk(ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), TapeType, Const{Core.Typeof(f)}, Const, Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...)
2021
subtape[__groupindex(ctx)] = forward(Const(f), Const(ctx), args...)[1]
2122
return nothing
2223
end
2324

2425
function rev(ctx, f::FT, ::Val{ModifiedBetween}, subtape, args...) where {ModifiedBetween, FT}
25-
forward, reverse = EnzymeCore.autodiff_deferred_thunk(ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), Const{Core.Typeof(f)}, Const, Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...)
26+
TapeType = EnzymeCore.tape_type(ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), Const{Core.Typeof(f)}, Const, Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...)
27+
forward, reverse = EnzymeCore.autodiff_deferred_thunk(ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), TapeType, Const{Core.Typeof(f)}, Const, Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...)
2628
tp = subtape[__groupindex(ctx)]
2729
reverse(Const(f), Const(ctx), args..., tp)
2830
return nothing

0 commit comments

Comments
 (0)