Skip to content

Commit 8d4e7f8

Browse files
committed
Add Enzyme GPU support
Still missing synchronize
1 parent 6aee730 commit 8d4e7f8

File tree

8 files changed

+210
-77
lines changed

8 files changed

+210
-77
lines changed

Diff for: .buildkite/pipeline.yml

+1
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ steps:
3434
version:
3535
- "1.8"
3636
- "1.9"
37+
- "1.10"
3738
plugins:
3839
- JuliaCI/julia#v1:
3940
version: "{{matrix.version}}"

Diff for: ext/EnzymeExt.jl

+186-71
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,11 @@ module EnzymeExt
66
using ..EnzymeCore
77
using ..EnzymeCore.EnzymeRules
88
end
9-
import KernelAbstractions: Kernel, StaticSize, launch_config, __groupsize, __groupindex, blocks, mkcontext, CompilerMetadata, CPU
9+
using KernelAbstractions
10+
import KernelAbstractions: Kernel, StaticSize, launch_config, allocate,
11+
blocks, mkcontext, CompilerMetadata, CPU, GPU, argconvert,
12+
supports_enzyme, __fake_compiler_job, backend,
13+
__index_Group_Cartesian, __index_Global_Linear
1014

1115
EnzymeRules.inactive(::Type{StaticSize}, x...) = nothing
1216

@@ -15,55 +19,188 @@ module EnzymeExt
1519
return nothing
1620
end
1721

18-
function aug_fwd(ctx, f::FT, ::Val{ModifiedBetween}, subtape, args...) where {ModifiedBetween, FT}
19-
forward, reverse = EnzymeCore.autodiff_deferred_thunk(ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), Const{Core.Typeof(f)}, Const, Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...)
20-
subtape[__groupindex(ctx)] = forward(Const(f), Const(ctx), args...)[1]
22+
function EnzymeRules.forward(func::Const{<:Kernel}, ::Type{Const{Nothing}}, args...; ndrange=nothing, workgroupsize=nothing)
23+
kernel = func.val
24+
f = kernel.f
25+
fwd_kernel = similar(kernel, fwd)
26+
27+
fwd_kernel(f, args...; ndrange, workgroupsize)
28+
end
29+
30+
function _enzyme_mkcontext(kernel::Kernel{CPU}, ndrange, iterspace, dynamic)
31+
block = first(blocks(iterspace))
32+
return mkcontext(kernel, block, ndrange, iterspace, dynamic)
33+
end
34+
35+
function _enzyme_mkcontext(kernel::Kernel{<:GPU}, ndrange, iterspace, dynamic)
36+
return mkcontext(kernel, ndrange, iterspace)
37+
end
38+
39+
function _augmented_return(::Kernel{CPU}, subtape, arg_refs, tape_type)
40+
return AugmentedReturn{Nothing, Nothing, Tuple{Array, typeof(arg_refs), typeof(tape_type)}}(
41+
nothing, nothing, (subtape, arg_refs, tape_type)
42+
)
43+
end
44+
45+
function _augmented_return(::Kernel{<:GPU}, subtape, arg_refs, tape_type)
46+
return AugmentedReturn{Nothing, Nothing, Any}(
47+
nothing, nothing, (subtape, arg_refs, tape_type)
48+
)
49+
end
50+
51+
function _create_tape_kernel(
52+
kernel::Kernel{CPU}, ModifiedBetween,
53+
FT, ctxTy, ndrange, iterspace, args2...
54+
)
55+
TapeType = EnzymeCore.tape_type(
56+
ReverseSplitModified(ReverseSplitWithPrimal, ModifiedBetween),
57+
FT, Const, Const{ctxTy}, map(Core.Typeof, args2)...
58+
)
59+
subtape = Array{TapeType}(undef, size(blocks(iterspace)))
60+
aug_kernel = similar(kernel, cpu_aug_fwd)
61+
return TapeType, subtape, aug_kernel
62+
end
63+
64+
function _create_tape_kernel(
65+
kernel::Kernel{<:GPU}, ModifiedBetween,
66+
FT, ctxTy, ndrange, iterspace, args2...
67+
)
68+
# For peeking at the TapeType we need to first construct a correct compilation job
69+
# this requires the use of the device side representation of arguments.
70+
# So we convert the arguments here, this is a bit wasteful since the `aug_kernel` call
71+
# will later do the same.
72+
dev_args2 = ((argconvert(kernel, a) for a in args2)...,)
73+
dev_TT = map(Core.Typeof, dev_args2)
74+
75+
job = __fake_compiler_job(backend(kernel))
76+
TapeType = EnzymeCore.tape_type(
77+
job, ReverseSplitModified(ReverseSplitWithPrimal, ModifiedBetween),
78+
FT, Const, Const{ctxTy}, dev_TT...
79+
)
80+
81+
# Allocate per thread
82+
subtape = allocate(backend(kernel), TapeType, prod(ndrange))
83+
84+
aug_kernel = similar(kernel, gpu_aug_fwd)
85+
return TapeType, subtape, aug_kernel
86+
end
87+
88+
_create_rev_kernel(kernel::Kernel{CPU}) = similar(kernel, cpu_rev)
89+
_create_rev_kernel(kernel::Kernel{<:GPU}) = similar(kernel, gpu_rev)
90+
91+
function cpu_aug_fwd(
92+
ctx, f::FT, ::Val{ModifiedBetween}, subtape, ::Val{TapeType}, args...
93+
) where {ModifiedBetween, FT, TapeType}
94+
# A2 = Const{Nothing} -- since f->Nothing
95+
forward, _ = EnzymeCore.autodiff_deferred_thunk(
96+
ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), TapeType,
97+
Const{Core.Typeof(f)}, Const, Const{Nothing},
98+
Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...
99+
)
100+
101+
# On the CPU: F is a per block function
102+
# On the CPU: subtape::Vector{Vector}
103+
I = __index_Group_Cartesian(ctx, #=fake=#CartesianIndex(1,1))
104+
subtape[I] = forward(Const(f), Const(ctx), args...)[1]
21105
return nothing
22106
end
23107

24-
function rev(ctx, f::FT, ::Val{ModifiedBetween}, subtape, args...) where {ModifiedBetween, FT}
25-
forward, reverse = EnzymeCore.autodiff_deferred_thunk(ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), Const{Core.Typeof(f)}, Const, Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...)
26-
tp = subtape[__groupindex(ctx)]
108+
function cpu_rev(
109+
ctx, f::FT, ::Val{ModifiedBetween}, subtape, ::Val{TapeType}, args...
110+
) where {ModifiedBetween, FT, TapeType}
111+
_, reverse = EnzymeCore.autodiff_deferred_thunk(
112+
ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), TapeType,
113+
Const{Core.Typeof(f)}, Const, Const{Nothing},
114+
Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...
115+
)
116+
I = __index_Group_Cartesian(ctx, #=fake=#CartesianIndex(1,1))
117+
tp = subtape[I]
27118
reverse(Const(f), Const(ctx), args..., tp)
28119
return nothing
29120
end
30121

31-
function EnzymeRules.forward(func::Const{<:Kernel}, ::Type{Const{Nothing}}, args...; ndrange=nothing, workgroupsize=nothing)
122+
function EnzymeRules.reverse(config::Config, func::Const{<:Kernel}, ::Type{<:EnzymeCore.Annotation}, tape, args::Vararg{Any, N}; ndrange=nothing, workgroupsize=nothing) where N
123+
subtape, arg_refs, tape_type = tape
124+
125+
args2 = ntuple(Val(N)) do i
126+
Base.@_inline_meta
127+
if args[i] isa Active
128+
Duplicated(Ref(args[i].val), arg_refs[i])
129+
else
130+
args[i]
131+
end
132+
end
133+
32134
kernel = func.val
33135
f = kernel.f
34-
fwd_kernel = similar(kernel, fwd)
35136

36-
fwd_kernel(f, args...; ndrange, workgroupsize)
37-
end
137+
tup = Val(ntuple(Val(N)) do i
138+
Base.@_inline_meta
139+
args[i] isa Active
140+
end)
141+
f = make_active_byref(f, tup)
38142

39-
@inline function make_active_byref(f::F, ::Val{ActiveTys}) where {F, ActiveTys}
40-
if !any(ActiveTys)
41-
return f
42-
end
43-
function inact(ctx, args2::Vararg{Any, N}) where N
44-
args3 = ntuple(Val(N)) do i
45-
Base.@_inline_meta
46-
if ActiveTys[i]
47-
args2[i][]
48-
else
49-
args2[i]
50-
end
143+
ModifiedBetween = Val((overwritten(config)[1], false, overwritten(config)[2:end]...))
144+
145+
rev_kernel = _create_rev_kernel(kernel)
146+
rev_kernel(f, ModifiedBetween, subtape, Val(tape_type), args2...; ndrange, workgroupsize)
147+
res = ntuple(Val(N)) do i
148+
Base.@_inline_meta
149+
if args[i] isa Active
150+
arg_refs[i][]
151+
else
152+
nothing
51153
end
52-
f(ctx, args3...)
53154
end
54-
return inact
155+
return res
55156
end
56157

57-
function EnzymeRules.augmented_primal(config::Config, func::Const{<:Kernel{CPU}}, ::Type{Const{Nothing}}, args::Vararg{Any, N}; ndrange=nothing, workgroupsize=nothing) where N
158+
# GPU support
159+
function gpu_aug_fwd(
160+
ctx, f::FT, ::Val{ModifiedBetween}, subtape, ::Val{TapeType}, args...
161+
) where {ModifiedBetween, FT, TapeType}
162+
# A2 = Const{Nothing} -- since f->Nothing
163+
forward, _ = EnzymeCore.autodiff_deferred_thunk(
164+
ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), TapeType,
165+
Const{Core.Typeof(f)}, Const, Const{Nothing},
166+
Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...
167+
)
168+
169+
# On the GPU: F is a per thread function
170+
# On the GPU: subtape::Vector
171+
I = __index_Global_Linear(ctx)
172+
subtape[I] = forward(Const(f), Const(ctx), args...)[1]
173+
return nothing
174+
end
175+
176+
function gpu_rev(
177+
ctx, f::FT, ::Val{ModifiedBetween}, subtape, ::Val{TapeType}, args...
178+
) where {ModifiedBetween, FT, TapeType}
179+
# XXX: TapeType and A2 as args to autodiff_deferred_thunk
180+
_, reverse = EnzymeCore.autodiff_deferred_thunk(
181+
ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), TapeType,
182+
Const{Core.Typeof(f)}, Const, Const{Nothing},
183+
Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...
184+
)
185+
I = __index_Global_Linear(ctx)
186+
tp = subtape[I]
187+
reverse(Const(f), Const(ctx), args..., tp)
188+
return nothing
189+
end
190+
191+
function EnzymeRules.augmented_primal(
192+
config::Config, func::Const{<:Kernel},
193+
::Type{Const{Nothing}}, args::Vararg{Any, N}; ndrange=nothing, workgroupsize=nothing
194+
) where N
58195
kernel = func.val
196+
if !supports_enzyme(backend(kernel))
197+
error("KernelAbstractions backend does not support Enzyme")
198+
end
59199
f = kernel.f
60200

61201
ndrange, workgroupsize, iterspace, dynamic = launch_config(kernel, ndrange, workgroupsize)
62-
block = first(blocks(iterspace))
63-
64-
ctx = mkcontext(kernel, block, ndrange, iterspace, dynamic)
202+
ctx = _enzyme_mkcontext(kernel, ndrange, iterspace, dynamic)
65203
ctxTy = Core.Typeof(ctx) # CompilerMetadata{ndrange(kernel), Core.Typeof(dynamic)}
66-
67204
# TODO autodiff_deferred on the func.val
68205
ModifiedBetween = Val((overwritten(config)[1], false, overwritten(config)[2:end]...))
69206

@@ -91,56 +228,34 @@ module EnzymeExt
91228
end
92229
end
93230

94-
# TODO in KA backends like CUDAKernels, etc have a version with a parent job type
95-
TapeType = EnzymeCore.tape_type(ReverseSplitModified(ReverseSplitWithPrimal, ModifiedBetween), FT, Const, Const{ctxTy}, map(Core.Typeof, args2)...)
96-
97-
98-
subtape = Array{TapeType}(undef, size(blocks(iterspace)))
99-
100-
aug_kernel = similar(kernel, aug_fwd)
101-
102-
aug_kernel(f, ModifiedBetween, subtape, args2...; ndrange, workgroupsize)
231+
TapeType, subtape, aug_kernel = _create_tape_kernel(
232+
kernel, ModifiedBetween, FT, ctxTy, ndrange, iterspace, args2...
233+
)
234+
aug_kernel(f, ModifiedBetween, subtape, Val(TapeType), args2...; ndrange, workgroupsize)
235+
KernelAbstractions.synchronize(backend(kernel))
103236

104237
# TODO the fact that ctxTy is type unstable means this is all type unstable.
105238
# Since custom rules require a fixed return type, explicitly cast to Any, rather
106239
# than returning an AugmentedReturn{Nothing, Nothing, T} where T.
107-
108-
res = AugmentedReturn{Nothing, Nothing, Tuple{Array, typeof(arg_refs)}}(nothing, nothing, (subtape, arg_refs))
109-
return res
240+
return _augmented_return(kernel, subtape, arg_refs, TapeType)
110241
end
111242

112-
function EnzymeRules.reverse(config::Config, func::Const{<:Kernel}, ::Type{<:EnzymeCore.Annotation}, tape, args::Vararg{Any, N}; ndrange=nothing, workgroupsize=nothing) where N
113-
subtape, arg_refs = tape
114-
115-
args2 = ntuple(Val(N)) do i
116-
Base.@_inline_meta
117-
if args[i] isa Active
118-
Duplicated(Ref(args[i].val), arg_refs[i])
119-
else
120-
args[i]
121-
end
122-
end
123-
124-
kernel = func.val
125-
f = kernel.f
126-
127-
tup = Val(ntuple(Val(N)) do i
128-
Base.@_inline_meta
129-
args[i] isa Active
130-
end)
131-
f = make_active_byref(f, tup)
132-
133-
ModifiedBetween = Val((overwritten(config)[1], false, overwritten(config)[2:end]...))
134-
135-
rev_kernel = similar(func.val, rev)
136-
rev_kernel(f, ModifiedBetween, subtape, args2...; ndrange, workgroupsize)
137-
return ntuple(Val(N)) do i
243+
@inline function make_active_byref(f::F, ::Val{ActiveTys}) where {F, ActiveTys}
244+
if !any(ActiveTys)
245+
return f
246+
end
247+
function inact(ctx, args2::Vararg{Any, N}) where N
248+
args3 = ntuple(Val(N)) do i
138249
Base.@_inline_meta
139-
if args[i] isa Active
140-
arg_refs[i][]
250+
if ActiveTys[i]
251+
args2[i][]
141252
else
142-
nothing
253+
args2[i]
143254
end
144255
end
256+
f(ctx, args3...)
145257
end
258+
return inact
259+
end
260+
146261
end

Diff for: src/KernelAbstractions.jl

+12
Original file line numberDiff line numberDiff line change
@@ -698,6 +698,18 @@ end
698698
__size(args::Tuple) = Tuple{args...}
699699
__size(i::Int) = Tuple{i}
700700

701+
"""
702+
argconvert(::Kernel, arg)
703+
704+
Convert arguments to the device side representation.
705+
"""
706+
argconvert(k::Kernel{T}, arg) where T =
707+
error("Don't know how to convert arguments for Kernel{$T}")
708+
709+
# Enzyme support
710+
supports_enzyme(::Backend) = false
711+
function __fake_compiler_job end
712+
701713
###
702714
# Extras
703715
# - LoopInfo

Diff for: src/cpu.jl

+3-1
Original file line numberDiff line numberDiff line change
@@ -191,4 +191,6 @@ end
191191
end
192192

193193
# Argument conversion
194-
KernelAbstractions.argconvert(k::Kernel{CPU}, arg) = arg
194+
argconvert(k::Kernel{CPU}, arg) = arg
195+
196+
supports_enzyme(::CPU) = true

Diff for: src/reflection.jl

-3
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
import InteractiveUtils
22
export @ka_code_typed, @ka_code_llvm
33

4-
argconvert(k::Kernel{T}, arg) where T =
5-
error("Don't know how to convert arguments for Kernel{$T}")
6-
74
using UUIDs
85
const Cthulhu = Base.PkgId(UUID("f68482b8-f384-11e8-15f7-abe071a5a75f"), "Cthulhu")
96

Diff for: test/Project.toml

+1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
[deps]
22
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
3+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
34
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
45
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
56
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

Diff for: test/extensions/enzyme.jl

+3-2
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ end
1010
function square_caller(A, backend)
1111
kernel = square!(backend)
1212
kernel(A, ndrange=size(A))
13-
KernelAbstractions.synchronize(backend)
1413
end
1514

1615

@@ -22,7 +21,6 @@ end
2221
function mul_caller(A, B, backend)
2322
kernel = mul!(backend)
2423
kernel(A, B, ndrange=size(A))
25-
KernelAbstractions.synchronize(backend)
2624
end
2725

2826
function enzyme_testsuite(backend, ArrayT, supports_reverse=true)
@@ -36,13 +34,15 @@ function enzyme_testsuite(backend, ArrayT, supports_reverse=true)
3634
dA .= 1
3735

3836
Enzyme.autodiff(Reverse, square_caller, Duplicated(A, dA), Const(backend()))
37+
KernelAbstractions.synchronize(backend())
3938
@test all(dA .≈ (2:2:128))
4039

4140

4241
A .= (1:1:64)
4342
dA .= 1
4443

4544
_, dB, _ = Enzyme.autodiff(Reverse, mul_caller, Duplicated(A, dA), Active(1.2), Const(backend()))[1]
45+
KernelAbstractions.synchronize(backend())
4646

4747
@test all(dA .≈ 1.2)
4848
@test dB ≈ sum(1:1:64)
@@ -52,6 +52,7 @@ function enzyme_testsuite(backend, ArrayT, supports_reverse=true)
5252
dA .= 1
5353

5454
Enzyme.autodiff(Forward, square_caller, Duplicated(A, dA), Const(backend()))
55+
KernelAbstractions.synchronize(backend())
5556
@test all(dA .≈ 2:2:128)
5657

5758
end

0 commit comments

Comments
 (0)