@@ -11,18 +11,20 @@ module EnzymeExt
11
11
EnzymeRules. inactive (:: Type{StaticSize} , x... ) = nothing
12
12
13
13
function fwd (ctx, f, args... )
14
- EnzymeCore. autodiff_deferred (Forward, Const (f), Const, Const (ctx), args... )
14
+ EnzymeCore. autodiff_deferred (Forward, Const (f), Const{Nothing} , Const (ctx), args... )
15
15
return nothing
16
16
end
17
17
18
18
function aug_fwd (ctx, f:: FT , :: Val{ModifiedBetween} , subtape, args... ) where {ModifiedBetween, FT}
19
- forward, reverse = EnzymeCore. autodiff_deferred_thunk (ReverseSplitModified (ReverseSplitWithPrimal, Val (ModifiedBetween)), Const{Core. Typeof (f)}, Const, Const{Core. Typeof (ctx)}, map (Core. Typeof, args)... )
19
+ TapeType = EnzymeCore. tape_type (ReverseSplitModified (ReverseSplitWithPrimal, Val (ModifiedBetween)), Const{Core. Typeof (f)}, Const{Nothing}, Const{Core. Typeof (ctx)}, map (Core. Typeof, args)... )
20
+ forward, reverse = EnzymeCore. autodiff_deferred_thunk (ReverseSplitModified (ReverseSplitWithPrimal, Val (ModifiedBetween)), TapeType, Const{Core. Typeof (f)}, Const{Nothing}, Const{Core. Typeof (ctx)}, map (Core. Typeof, args)... )
20
21
subtape[__groupindex (ctx)] = forward (Const (f), Const (ctx), args... )[1 ]
21
22
return nothing
22
23
end
23
24
24
25
function rev (ctx, f:: FT , :: Val{ModifiedBetween} , subtape, args... ) where {ModifiedBetween, FT}
25
- forward, reverse = EnzymeCore. autodiff_deferred_thunk (ReverseSplitModified (ReverseSplitWithPrimal, Val (ModifiedBetween)), Const{Core. Typeof (f)}, Const, Const{Core. Typeof (ctx)}, map (Core. Typeof, args)... )
26
+ TapeType = EnzymeCore. tape_type (ReverseSplitModified (ReverseSplitWithPrimal, Val (ModifiedBetween)), Const{Core. Typeof (f)}, Const{Nothing}, Const{Core. Typeof (ctx)}, map (Core. Typeof, args)... )
27
+ forward, reverse = EnzymeCore. autodiff_deferred_thunk (ReverseSplitModified (ReverseSplitWithPrimal, Val (ModifiedBetween)), TapeType, Const{Core. Typeof (f)}, Const{Nothing}, Const{Core. Typeof (ctx)}, map (Core. Typeof, args)... )
26
28
tp = subtape[__groupindex (ctx)]
27
29
reverse (Const (f), Const (ctx), args... , tp)
28
30
return nothing
@@ -92,7 +94,7 @@ module EnzymeExt
92
94
end
93
95
94
96
# TODO in KA backends like CUDAKernels, etc have a version with a parent job type
95
- TapeType = EnzymeCore. tape_type (ReverseSplitModified (ReverseSplitWithPrimal, ModifiedBetween), FT, Const, Const{ctxTy}, map (Core. Typeof, args2)... )
97
+ TapeType = EnzymeCore. tape_type (ReverseSplitModified (ReverseSplitWithPrimal, ModifiedBetween), FT, Const{Nothing} , Const{ctxTy}, map (Core. Typeof, args2)... )
96
98
97
99
98
100
subtape = Array {TapeType} (undef, size (blocks (iterspace)))
0 commit comments