@@ -16,13 +16,15 @@ module EnzymeExt
16
16
end
17
17
18
18
function aug_fwd (ctx, f:: FT , :: Val{ModifiedBetween} , subtape, args... ) where {ModifiedBetween, FT}
19
- forward, reverse = EnzymeCore. autodiff_deferred_thunk (ReverseSplitModified (ReverseSplitWithPrimal, Val (ModifiedBetween)), Const{Core. Typeof (f)}, Const, Const{Core. Typeof (ctx)}, map (Core. Typeof, args)... )
19
+ TapeType = EnzymeCore. tape_type (ReverseSplitModified (ReverseSplitWithPrimal, Val (ModifiedBetween)), Const{Core. Typeof (f)}, Const, Const{Core. Typeof (ctx)}, map (Core. Typeof, args)... )
20
+ forward, reverse = EnzymeCore. autodiff_deferred_thunk (ReverseSplitModified (ReverseSplitWithPrimal, Val (ModifiedBetween)), TapeType, Const{Core. Typeof (f)}, Const, Const{Core. Typeof (ctx)}, map (Core. Typeof, args)... )
20
21
subtape[__groupindex (ctx)] = forward (Const (f), Const (ctx), args... )[1 ]
21
22
return nothing
22
23
end
23
24
24
25
function rev (ctx, f:: FT , :: Val{ModifiedBetween} , subtape, args... ) where {ModifiedBetween, FT}
25
- forward, reverse = EnzymeCore. autodiff_deferred_thunk (ReverseSplitModified (ReverseSplitWithPrimal, Val (ModifiedBetween)), Const{Core. Typeof (f)}, Const, Const{Core. Typeof (ctx)}, map (Core. Typeof, args)... )
26
+ TapeType = EnzymeCore. tape_type (ReverseSplitModified (ReverseSplitWithPrimal, Val (ModifiedBetween)), Const{Core. Typeof (f)}, Const, Const{Core. Typeof (ctx)}, map (Core. Typeof, args)... )
27
+ forward, reverse = EnzymeCore. autodiff_deferred_thunk (ReverseSplitModified (ReverseSplitWithPrimal, Val (ModifiedBetween)), TapeType, Const{Core. Typeof (f)}, Const, Const{Core. Typeof (ctx)}, map (Core. Typeof, args)... )
26
28
tp = subtape[__groupindex (ctx)]
27
29
reverse (Const (f), Const (ctx), args... , tp)
28
30
return nothing
0 commit comments