Skip to content

Commit 5b8a551

Browse files
wsmosesvchuravy
andauthored
Correctly adapt to ABI changes to EnzymeCore (#476)
Co-authored-by: Valentin Churavy <v.churavy@gmail.com>
1 parent d1bee3f commit 5b8a551

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ UnsafeAtomicsLLVM = "d80eeb9a-aca5-4d75-85e5-170c8b632249"
2020
[compat]
2121
Adapt = "0.4, 1.0, 2.0, 3.0, 4"
2222
Atomix = "0.1"
23-
EnzymeCore = "0.6.4, 0.7"
23+
EnzymeCore = "0.7"
2424
InteractiveUtils = "1.6"
2525
LinearAlgebra = "1.6"
2626
MacroTools = "0.5"

ext/EnzymeExt.jl

+6-4
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,20 @@ module EnzymeExt
1111
EnzymeRules.inactive(::Type{StaticSize}, x...) = nothing
1212

1313
function fwd(ctx, f, args...)
14-
EnzymeCore.autodiff_deferred(Forward, Const(f), Const, Const(ctx), args...)
14+
EnzymeCore.autodiff_deferred(Forward, Const(f), Const{Nothing}, Const(ctx), args...)
1515
return nothing
1616
end
1717

1818
function aug_fwd(ctx, f::FT, ::Val{ModifiedBetween}, subtape, args...) where {ModifiedBetween, FT}
19-
forward, reverse = EnzymeCore.autodiff_deferred_thunk(ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), Const{Core.Typeof(f)}, Const, Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...)
19+
TapeType = EnzymeCore.tape_type(ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), Const{Core.Typeof(f)}, Const{Nothing}, Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...)
20+
forward, reverse = EnzymeCore.autodiff_deferred_thunk(ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), TapeType, Const{Core.Typeof(f)}, Const{Nothing}, Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...)
2021
subtape[__groupindex(ctx)] = forward(Const(f), Const(ctx), args...)[1]
2122
return nothing
2223
end
2324

2425
function rev(ctx, f::FT, ::Val{ModifiedBetween}, subtape, args...) where {ModifiedBetween, FT}
25-
forward, reverse = EnzymeCore.autodiff_deferred_thunk(ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), Const{Core.Typeof(f)}, Const, Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...)
26+
TapeType = EnzymeCore.tape_type(ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), Const{Core.Typeof(f)}, Const{Nothing}, Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...)
27+
forward, reverse = EnzymeCore.autodiff_deferred_thunk(ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), TapeType, Const{Core.Typeof(f)}, Const{Nothing}, Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...)
2628
tp = subtape[__groupindex(ctx)]
2729
reverse(Const(f), Const(ctx), args..., tp)
2830
return nothing
@@ -92,7 +94,7 @@ module EnzymeExt
9294
end
9395

9496
# TODO in KA backends like CUDAKernels, etc have a version with a parent job type
95-
TapeType = EnzymeCore.tape_type(ReverseSplitModified(ReverseSplitWithPrimal, ModifiedBetween), FT, Const, Const{ctxTy}, map(Core.Typeof, args2)...)
97+
TapeType = EnzymeCore.tape_type(ReverseSplitModified(ReverseSplitWithPrimal, ModifiedBetween), FT, Const{Nothing}, Const{ctxTy}, map(Core.Typeof, args2)...)
9698

9799

98100
subtape = Array{TapeType}(undef, size(blocks(iterspace)))

0 commit comments

Comments
 (0)