@@ -17,14 +17,16 @@ module EnzymeExt
17
17
18
18
function aug_fwd (ctx, f:: FT , :: Val{ModifiedBetween} , subtape, args... ) where {ModifiedBetween, FT}
19
19
TapeType = EnzymeCore. tape_type (ReverseSplitModified (ReverseSplitWithPrimal, Val (ModifiedBetween)), Const{Core. Typeof (f)}, Const, Const{Core. Typeof (ctx)}, map (Core. Typeof, args)... )
20
- forward, reverse = EnzymeCore. autodiff_deferred_thunk (ReverseSplitModified (ReverseSplitWithPrimal, Val (ModifiedBetween)), TapeType, Const{Core. Typeof (f)}, Const, Const{Core. Typeof (ctx)}, map (Core. Typeof, args)... )
20
+ forward, reverse = EnzymeCore. autodiff_thunk (ReverseSplitModified (ReverseSplitWithPrimal, Val (ModifiedBetween)), Const{Core. Typeof (f)}, Const, Const{Core. Typeof (ctx)}, map (Core. Typeof, args)... )
21
+ # forward, reverse = EnzymeCore.autodiff_deferred_thunk(ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), TapeType, Const{Core.Typeof(f)}, Const, Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...)
21
22
subtape[__groupindex (ctx)] = forward (Const (f), Const (ctx), args... )[1 ]
22
23
return nothing
23
24
end
24
25
25
26
function rev (ctx, f:: FT , :: Val{ModifiedBetween} , subtape, args... ) where {ModifiedBetween, FT}
26
27
TapeType = EnzymeCore. tape_type (ReverseSplitModified (ReverseSplitWithPrimal, Val (ModifiedBetween)), Const{Core. Typeof (f)}, Const, Const{Core. Typeof (ctx)}, map (Core. Typeof, args)... )
27
- forward, reverse = EnzymeCore. autodiff_deferred_thunk (ReverseSplitModified (ReverseSplitWithPrimal, Val (ModifiedBetween)), TapeType, Const{Core. Typeof (f)}, Const, Const{Core. Typeof (ctx)}, map (Core. Typeof, args)... )
28
+ forward, reverse = EnzymeCore. autodiff_thunk (ReverseSplitModified (ReverseSplitWithPrimal, Val (ModifiedBetween)), Const{Core. Typeof (f)}, Const, Const{Core. Typeof (ctx)}, map (Core. Typeof, args)... )
29
+ # forward, reverse = EnzymeCore.autodiff_deferred_thunk(ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), TapeType, Const{Core.Typeof(f)}, Const, Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...)
28
30
tp = subtape[__groupindex (ctx)]
29
31
reverse (Const (f), Const (ctx), args... , tp)
30
32
return nothing
0 commit comments