Skip to content

Commit 8221520

Browse files
committed
Add CUDAEnzyme extension
1 parent e0b64a5 commit 8221520

File tree

4 files changed

+72
-61
lines changed

4 files changed

+72
-61
lines changed

Diff for: Project.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ version = "0.9.15"
66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
88
Atomix = "a9b6321e-bd34-4604-b9c9-b65b8de01458"
9-
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
109
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
1110
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1211
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
@@ -20,9 +19,11 @@ UnsafeAtomicsLLVM = "d80eeb9a-aca5-4d75-85e5-170c8b632249"
2019

2120
[weakdeps]
2221
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
22+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
2323

2424
[extensions]
2525
EnzymeExt = "EnzymeCore"
26+
CUDAEnzymeExt = ["CUDA", "EnzymeCore"]
2627

2728
[compat]
2829
Adapt = "0.4, 1.0, 2.0, 3.0, 4"

Diff for: ext/CUDAEnzymeExt.jl

+70
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
module CUDAEnzymeExt
2+
if isdefined(Base, :get_extension)
3+
using EnzymeCore
4+
using EnzymeCore.EnzymeRules
5+
else
6+
using ..EnzymeCore
7+
using ..EnzymeCore.EnzymeRules
8+
end
9+
import KernelAbstractions: Kernel, StaticSize, launch_config, __groupsize, __groupindex, blocks, mkcontext, CompilerMetadata, CPU, GPU
10+
using CUDA
11+
12+
function EnzymeRules.augmented_primal(config::Config, func::Const{<:Kernel{CUDABackend}}, type::Type{Const{Nothing}}, args::Vararg{Any, N}; ndrange=nothing, workgroupsize=nothing) where N
13+
println("Custom rule GPU")
14+
kernel = func.val
15+
f = kernel.f
16+
mi = CUDA.methodinstance(typeof(()->return), Tuple{})
17+
job = CUDA.CompilerJob(mi, CUDA.compiler_config(device()))
18+
19+
ndrange, workgroupsize, iterspace, dynamic = launch_config(kernel, ndrange, workgroupsize)
20+
block = first(blocks(iterspace))
21+
ctx = mkcontext(kernel, ndrange, iterspace)
22+
ctxTy = Core.Typeof(ctx) # CompilerMetadata{ndrange(kernel), Core.Typeof(dynamic)}
23+
24+
# TODO autodiff_deferred on the func.val
25+
ModifiedBetween = Val((overwritten(config)[1], false, overwritten(config)[2:end]...))
26+
27+
tup = Val(ntuple(Val(N)) do i
28+
Base.@_inline_meta
29+
args[i] isa Active
30+
end)
31+
f = make_active_byref(f, tup)
32+
FT = Const{Core.Typeof(f)}
33+
34+
arg_refs = ntuple(Val(N)) do i
35+
Base.@_inline_meta
36+
if args[i] isa Active
37+
Ref(EnzymeCore.make_zero(args[i].val))
38+
else
39+
nothing
40+
end
41+
end
42+
args2 = ntuple(Val(N)) do i
43+
Base.@_inline_meta
44+
if args[i] isa Active
45+
Duplicated(Ref(args[i].val), arg_refs[i])
46+
else
47+
args[i]
48+
end
49+
end
50+
51+
# TODO in KA backends like CUDAKernels, etc have a version with a parent job type
52+
TapeType = EnzymeCore.tape_type(job, ReverseSplitModified(ReverseSplitWithPrimal, ModifiedBetween), FT, Const, Const{ctxTy}, map(Core.Typeof, args2)...)
53+
@show TapeType
54+
55+
56+
subtape = Array{TapeType}(undef, size(blocks(iterspace)))
57+
58+
aug_kernel = similar(kernel, aug_fwd)
59+
60+
aug_kernel(f, ModifiedBetween, subtape, args2...; ndrange, workgroupsize)
61+
62+
# TODO the fact that ctxTy is type unstable means this is all type unstable.
63+
# Since custom rules require a fixed return type, explicitly cast to Any, rather
64+
# than returning a AugmentedReturn{Nothing, Nothing, T} where T.
65+
66+
res = AugmentedReturn{Nothing, Nothing, Tuple{Array, typeof(arg_refs)}}(nothing, nothing, (subtape, arg_refs))
67+
return res
68+
end
69+
70+
end

Diff for: ext/EnzymeExt.jl

-59
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ module EnzymeExt
77
using ..EnzymeCore.EnzymeRules
88
end
99
import KernelAbstractions: Kernel, StaticSize, launch_config, __groupsize, __groupindex, blocks, mkcontext, CompilerMetadata, CPU, GPU
10-
using CUDA
1110

1211
EnzymeRules.inactive(::Type{StaticSize}, x...) = nothing
1312

@@ -112,64 +111,6 @@ module EnzymeExt
112111
return res
113112
end
114113

115-
function EnzymeRules.augmented_primal(config::Config, func::Const{<:Kernel{CUDABackend}}, type::Type{Const{Nothing}}, args::Vararg{Any, N}; ndrange=nothing, workgroupsize=nothing) where N
116-
println("Custom rule GPU")
117-
kernel = func.val
118-
f = kernel.f
119-
mi = CUDA.methodinstance(typeof(()->return), Tuple{})
120-
job = CUDA.CompilerJob(mi, CUDA.compiler_config(device()))
121-
122-
ndrange, workgroupsize, iterspace, dynamic = launch_config(kernel, ndrange, workgroupsize)
123-
block = first(blocks(iterspace))
124-
ctx = mkcontext(kernel, ndrange, iterspace)
125-
ctxTy = Core.Typeof(ctx) # CompilerMetadata{ndrange(kernel), Core.Typeof(dynamic)}
126-
127-
# TODO autodiff_deferred on the func.val
128-
ModifiedBetween = Val((overwritten(config)[1], false, overwritten(config)[2:end]...))
129-
130-
tup = Val(ntuple(Val(N)) do i
131-
Base.@_inline_meta
132-
args[i] isa Active
133-
end)
134-
f = make_active_byref(f, tup)
135-
FT = Const{Core.Typeof(f)}
136-
137-
arg_refs = ntuple(Val(N)) do i
138-
Base.@_inline_meta
139-
if args[i] isa Active
140-
Ref(EnzymeCore.make_zero(args[i].val))
141-
else
142-
nothing
143-
end
144-
end
145-
args2 = ntuple(Val(N)) do i
146-
Base.@_inline_meta
147-
if args[i] isa Active
148-
Duplicated(Ref(args[i].val), arg_refs[i])
149-
else
150-
args[i]
151-
end
152-
end
153-
154-
# TODO in KA backends like CUDAKernels, etc have a version with a parent job type
155-
TapeType = EnzymeCore.tape_type(job, ReverseSplitModified(ReverseSplitWithPrimal, ModifiedBetween), FT, Const, Const{ctxTy}, map(Core.Typeof, args2)...)
156-
@show TapeType
157-
158-
159-
subtape = Array{TapeType}(undef, size(blocks(iterspace)))
160-
161-
aug_kernel = similar(kernel, aug_fwd)
162-
163-
aug_kernel(f, ModifiedBetween, subtape, args2...; ndrange, workgroupsize)
164-
165-
# TODO the fact that ctxTy is type unstable means this is all type unstable.
166-
# Since custom rules require a fixed return type, explicitly cast to Any, rather
167-
# than returning a AugmentedReturn{Nothing, Nothing, T} where T.
168-
169-
res = AugmentedReturn{Nothing, Nothing, Tuple{Array, typeof(arg_refs)}}(nothing, nothing, (subtape, arg_refs))
170-
return res
171-
end
172-
173114
function EnzymeRules.reverse(config::Config, func::Const{<:Kernel}, ::Type{<:EnzymeCore.Annotation}, tape, args::Vararg{Any, N}; ndrange=nothing, workgroupsize=nothing) where N
174115
subtape, arg_refs = tape
175116

Diff for: test/reverse_gpu.jl

-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
using Test
2-
using Enzyme_jll
32
using Enzyme
43
using KernelAbstractions
54
using CUDA

0 commit comments

Comments
 (0)