@@ -6,7 +6,11 @@ module EnzymeExt
6
6
using .. EnzymeCore
7
7
using .. EnzymeCore. EnzymeRules
8
8
end
9
- import KernelAbstractions: Kernel, StaticSize, launch_config, __groupsize, __groupindex, blocks, mkcontext, CompilerMetadata, CPU
9
+ using KernelAbstractions
10
+ import KernelAbstractions: Kernel, StaticSize, launch_config, allocate,
11
+ blocks, mkcontext, CompilerMetadata, CPU, GPU, argconvert,
12
+ supports_enzyme, __fake_compiler_job, backend,
13
+ __index_Group_Cartesian, __index_Global_Linear
10
14
11
15
EnzymeRules. inactive (:: Type{StaticSize} , x... ) = nothing
12
16
@@ -15,55 +19,188 @@ module EnzymeExt
15
19
return nothing
16
20
end
17
21
18
- function aug_fwd (ctx, f:: FT , :: Val{ModifiedBetween} , subtape, args... ) where {ModifiedBetween, FT}
19
- forward, reverse = EnzymeCore. autodiff_deferred_thunk (ReverseSplitModified (ReverseSplitWithPrimal, Val (ModifiedBetween)), Const{Core. Typeof (f)}, Const, Const{Core. Typeof (ctx)}, map (Core. Typeof, args)... )
20
- subtape[__groupindex (ctx)] = forward (Const (f), Const (ctx), args... )[1 ]
22
+ function EnzymeRules. forward (func:: Const{<:Kernel} , :: Type{Const{Nothing}} , args... ; ndrange= nothing , workgroupsize= nothing )
23
+ kernel = func. val
24
+ f = kernel. f
25
+ fwd_kernel = similar (kernel, fwd)
26
+
27
+ fwd_kernel (f, args... ; ndrange, workgroupsize)
28
+ end
29
+
30
+ function _enzyme_mkcontext (kernel:: Kernel{CPU} , ndrange, iterspace, dynamic)
31
+ block = first (blocks (iterspace))
32
+ return mkcontext (kernel, block, ndrange, iterspace, dynamic)
33
+ end
34
+
35
+ function _enzyme_mkcontext (kernel:: Kernel{<:GPU} , ndrange, iterspace, dynamic)
36
+ return mkcontext (kernel, ndrange, iterspace)
37
+ end
38
+
39
+ function _augmented_return (:: Kernel{CPU} , subtape, arg_refs, tape_type)
40
+ return AugmentedReturn {Nothing, Nothing, Tuple{Array, typeof(arg_refs), typeof(tape_type)}} (
41
+ nothing , nothing , (subtape, arg_refs, tape_type)
42
+ )
43
+ end
44
+
45
+ function _augmented_return (:: Kernel{<:GPU} , subtape, arg_refs, tape_type)
46
+ return AugmentedReturn {Nothing, Nothing, Any} (
47
+ nothing , nothing , (subtape, arg_refs, tape_type)
48
+ )
49
+ end
50
+
51
+ function _create_tape_kernel (
52
+ kernel:: Kernel{CPU} , ModifiedBetween,
53
+ FT, ctxTy, ndrange, iterspace, args2...
54
+ )
55
+ TapeType = EnzymeCore. tape_type (
56
+ ReverseSplitModified (ReverseSplitWithPrimal, ModifiedBetween),
57
+ FT, Const, Const{ctxTy}, map (Core. Typeof, args2)...
58
+ )
59
+ subtape = Array {TapeType} (undef, size (blocks (iterspace)))
60
+ aug_kernel = similar (kernel, cpu_aug_fwd)
61
+ return TapeType, subtape, aug_kernel
62
+ end
63
+
64
+ function _create_tape_kernel (
65
+ kernel:: Kernel{<:GPU} , ModifiedBetween,
66
+ FT, ctxTy, ndrange, iterspace, args2...
67
+ )
68
+ # For peeking at the TapeType we need to first construct a correct compilation job
69
+ # this requires the use of the device side representation of arguments.
70
+ # So we convert the arguments here, this is a bit wasteful since the `aug_kernel` call
71
+ # will later do the same.
72
+ dev_args2 = ((argconvert (kernel, a) for a in args2). .. ,)
73
+ dev_TT = map (Core. Typeof, dev_args2)
74
+
75
+ job = __fake_compiler_job (backend (kernel))
76
+ TapeType = EnzymeCore. tape_type (
77
+ job, ReverseSplitModified (ReverseSplitWithPrimal, ModifiedBetween),
78
+ FT, Const, Const{ctxTy}, dev_TT...
79
+ )
80
+
81
+ # Allocate per thread
82
+ subtape = allocate (backend (kernel), TapeType, prod (ndrange))
83
+
84
+ aug_kernel = similar (kernel, gpu_aug_fwd)
85
+ return TapeType, subtape, aug_kernel
86
+ end
87
+
88
+ _create_rev_kernel (kernel:: Kernel{CPU} ) = similar (kernel, cpu_rev)
89
+ _create_rev_kernel (kernel:: Kernel{<:GPU} ) = similar (kernel, gpu_rev)
90
+
91
+ function cpu_aug_fwd (
92
+ ctx, f:: FT , :: Val{ModifiedBetween} , subtape, :: Val{TapeType} , args...
93
+ ) where {ModifiedBetween, FT, TapeType}
94
+ # A2 = Const{Nothing} -- since f->Nothing
95
+ forward, _ = EnzymeCore. autodiff_deferred_thunk (
96
+ ReverseSplitModified (ReverseSplitWithPrimal, Val (ModifiedBetween)), TapeType,
97
+ Const{Core. Typeof (f)}, Const, Const{Nothing},
98
+ Const{Core. Typeof (ctx)}, map (Core. Typeof, args)...
99
+ )
100
+
101
+ # On the CPU: F is a per block function
102
+ # On the CPU: subtape::Vector{Vector}
103
+ I = __index_Group_Cartesian (ctx, #= fake=# CartesianIndex (1 ,1 ))
104
+ subtape[I] = forward (Const (f), Const (ctx), args... )[1 ]
21
105
return nothing
22
106
end
23
107
24
- function rev (ctx, f:: FT , :: Val{ModifiedBetween} , subtape, args... ) where {ModifiedBetween, FT}
25
- forward, reverse = EnzymeCore. autodiff_deferred_thunk (ReverseSplitModified (ReverseSplitWithPrimal, Val (ModifiedBetween)), Const{Core. Typeof (f)}, Const, Const{Core. Typeof (ctx)}, map (Core. Typeof, args)... )
26
- tp = subtape[__groupindex (ctx)]
108
+ function cpu_rev (
109
+ ctx, f:: FT , :: Val{ModifiedBetween} , subtape, :: Val{TapeType} , args...
110
+ ) where {ModifiedBetween, FT, TapeType}
111
+ _, reverse = EnzymeCore. autodiff_deferred_thunk (
112
+ ReverseSplitModified (ReverseSplitWithPrimal, Val (ModifiedBetween)), TapeType,
113
+ Const{Core. Typeof (f)}, Const, Const{Nothing},
114
+ Const{Core. Typeof (ctx)}, map (Core. Typeof, args)...
115
+ )
116
+ I = __index_Group_Cartesian (ctx, #= fake=# CartesianIndex (1 ,1 ))
117
+ tp = subtape[I]
27
118
reverse (Const (f), Const (ctx), args... , tp)
28
119
return nothing
29
120
end
30
121
31
- function EnzymeRules. forward (func:: Const{<:Kernel} , :: Type{Const{Nothing}} , args... ; ndrange= nothing , workgroupsize= nothing )
122
+ function EnzymeRules. reverse (config:: Config , func:: Const{<:Kernel} , :: Type{<:EnzymeCore.Annotation} , tape, args:: Vararg{Any, N} ; ndrange= nothing , workgroupsize= nothing ) where N
123
+ subtape, arg_refs, tape_type = tape
124
+
125
+ args2 = ntuple (Val (N)) do i
126
+ Base. @_inline_meta
127
+ if args[i] isa Active
128
+ Duplicated (Ref (args[i]. val), arg_refs[i])
129
+ else
130
+ args[i]
131
+ end
132
+ end
133
+
32
134
kernel = func. val
33
135
f = kernel. f
34
- fwd_kernel = similar (kernel, fwd)
35
136
36
- fwd_kernel (f, args... ; ndrange, workgroupsize)
37
- end
137
+ tup = Val (ntuple (Val (N)) do i
138
+ Base. @_inline_meta
139
+ args[i] isa Active
140
+ end )
141
+ f = make_active_byref (f, tup)
38
142
39
- @inline function make_active_byref (f:: F , :: Val{ActiveTys} ) where {F, ActiveTys}
40
- if ! any (ActiveTys)
41
- return f
42
- end
43
- function inact (ctx, args2:: Vararg{Any, N} ) where N
44
- args3 = ntuple (Val (N)) do i
45
- Base. @_inline_meta
46
- if ActiveTys[i]
47
- args2[i][]
48
- else
49
- args2[i]
50
- end
143
+ ModifiedBetween = Val ((overwritten (config)[1 ], false , overwritten (config)[2 : end ]. .. ))
144
+
145
+ rev_kernel = _create_rev_kernel (kernel)
146
+ rev_kernel (f, ModifiedBetween, subtape, Val (tape_type), args2... ; ndrange, workgroupsize)
147
+ res = ntuple (Val (N)) do i
148
+ Base. @_inline_meta
149
+ if args[i] isa Active
150
+ arg_refs[i][]
151
+ else
152
+ nothing
51
153
end
52
- f (ctx, args3... )
53
154
end
54
- return inact
155
+ return res
55
156
end
56
157
57
- function EnzymeRules. augmented_primal (config:: Config , func:: Const{<:Kernel{CPU}} , :: Type{Const{Nothing}} , args:: Vararg{Any, N} ; ndrange= nothing , workgroupsize= nothing ) where N
158
+ # GPU support
159
+ function gpu_aug_fwd (
160
+ ctx, f:: FT , :: Val{ModifiedBetween} , subtape, :: Val{TapeType} , args...
161
+ ) where {ModifiedBetween, FT, TapeType}
162
+ # A2 = Const{Nothing} -- since f->Nothing
163
+ forward, _ = EnzymeCore. autodiff_deferred_thunk (
164
+ ReverseSplitModified (ReverseSplitWithPrimal, Val (ModifiedBetween)), TapeType,
165
+ Const{Core. Typeof (f)}, Const, Const{Nothing},
166
+ Const{Core. Typeof (ctx)}, map (Core. Typeof, args)...
167
+ )
168
+
169
+ # On the GPU: F is a per thread function
170
+ # On the GPU: subtape::Vector
171
+ I = __index_Global_Linear (ctx)
172
+ subtape[I] = forward (Const (f), Const (ctx), args... )[1 ]
173
+ return nothing
174
+ end
175
+
176
+ function gpu_rev (
177
+ ctx, f:: FT , :: Val{ModifiedBetween} , subtape, :: Val{TapeType} , args...
178
+ ) where {ModifiedBetween, FT, TapeType}
179
+ # XXX : TapeType and A2 as args to autodiff_deferred_thunk
180
+ _, reverse = EnzymeCore. autodiff_deferred_thunk (
181
+ ReverseSplitModified (ReverseSplitWithPrimal, Val (ModifiedBetween)), TapeType,
182
+ Const{Core. Typeof (f)}, Const, Const{Nothing},
183
+ Const{Core. Typeof (ctx)}, map (Core. Typeof, args)...
184
+ )
185
+ I = __index_Global_Linear (ctx)
186
+ tp = subtape[I]
187
+ reverse (Const (f), Const (ctx), args... , tp)
188
+ return nothing
189
+ end
190
+
191
+ function EnzymeRules. augmented_primal (
192
+ config:: Config , func:: Const{<:Kernel} ,
193
+ :: Type{Const{Nothing}} , args:: Vararg{Any, N} ; ndrange= nothing , workgroupsize= nothing
194
+ ) where N
58
195
kernel = func. val
196
+ if ! supports_enzyme (backend (kernel))
197
+ error (" KernelAbstractions backend does not support Enzyme" )
198
+ end
59
199
f = kernel. f
60
200
61
201
ndrange, workgroupsize, iterspace, dynamic = launch_config (kernel, ndrange, workgroupsize)
62
- block = first (blocks (iterspace))
63
-
64
- ctx = mkcontext (kernel, block, ndrange, iterspace, dynamic)
202
+ ctx = _enzyme_mkcontext (kernel, ndrange, iterspace, dynamic)
65
203
ctxTy = Core. Typeof (ctx) # CompilerMetadata{ndrange(kernel), Core.Typeof(dynamic)}
66
-
67
204
# TODO autodiff_deferred on the func.val
68
205
ModifiedBetween = Val ((overwritten (config)[1 ], false , overwritten (config)[2 : end ]. .. ))
69
206
@@ -91,56 +228,34 @@ module EnzymeExt
91
228
end
92
229
end
93
230
94
- # TODO in KA backends like CUDAKernels, etc have a version with a parent job type
95
- TapeType = EnzymeCore. tape_type (ReverseSplitModified (ReverseSplitWithPrimal, ModifiedBetween), FT, Const, Const{ctxTy}, map (Core. Typeof, args2)... )
96
-
97
-
98
- subtape = Array {TapeType} (undef, size (blocks (iterspace)))
99
-
100
- aug_kernel = similar (kernel, aug_fwd)
101
-
102
- aug_kernel (f, ModifiedBetween, subtape, args2... ; ndrange, workgroupsize)
231
+ TapeType, subtape, aug_kernel = _create_tape_kernel (
232
+ kernel, ModifiedBetween, FT, ctxTy, ndrange, iterspace, args2...
233
+ )
234
+ aug_kernel (f, ModifiedBetween, subtape, Val (TapeType), args2... ; ndrange, workgroupsize)
235
+ KernelAbstractions. synchronize (backend (kernel))
103
236
104
237
# TODO the fact that ctxTy is type unstable means this is all type unstable.
105
238
# Since custom rules require a fixed return type, explicitly cast to Any, rather
106
239
# than returning a AugmentedReturn{Nothing, Nothing, T} where T.
107
-
108
- res = AugmentedReturn {Nothing, Nothing, Tuple{Array, typeof(arg_refs)}} (nothing , nothing , (subtape, arg_refs))
109
- return res
240
+ return _augmented_return (kernel, subtape, arg_refs, TapeType)
110
241
end
111
242
112
- function EnzymeRules. reverse (config:: Config , func:: Const{<:Kernel} , :: Type{<:EnzymeCore.Annotation} , tape, args:: Vararg{Any, N} ; ndrange= nothing , workgroupsize= nothing ) where N
113
- subtape, arg_refs = tape
114
-
115
- args2 = ntuple (Val (N)) do i
116
- Base. @_inline_meta
117
- if args[i] isa Active
118
- Duplicated (Ref (args[i]. val), arg_refs[i])
119
- else
120
- args[i]
121
- end
122
- end
123
-
124
- kernel = func. val
125
- f = kernel. f
126
-
127
- tup = Val (ntuple (Val (N)) do i
128
- Base. @_inline_meta
129
- args[i] isa Active
130
- end )
131
- f = make_active_byref (f, tup)
132
-
133
- ModifiedBetween = Val ((overwritten (config)[1 ], false , overwritten (config)[2 : end ]. .. ))
134
-
135
- rev_kernel = similar (func. val, rev)
136
- rev_kernel (f, ModifiedBetween, subtape, args2... ; ndrange, workgroupsize)
137
- return ntuple (Val (N)) do i
243
+ @inline function make_active_byref (f:: F , :: Val{ActiveTys} ) where {F, ActiveTys}
244
+ if ! any (ActiveTys)
245
+ return f
246
+ end
247
+ function inact (ctx, args2:: Vararg{Any, N} ) where N
248
+ args3 = ntuple (Val (N)) do i
138
249
Base. @_inline_meta
139
- if args [i] isa Active
140
- arg_refs [i][]
250
+ if ActiveTys [i]
251
+ args2 [i][]
141
252
else
142
- nothing
253
+ args2[i]
143
254
end
144
255
end
256
+ f (ctx, args3... )
145
257
end
258
+ return inact
259
+ end
260
+
146
261
end
0 commit comments