Skip to content

Commit b82dfb4

Browse files
committed
Adding extension fix and allocate
1 parent 9dd121d commit b82dfb4

File tree

7 files changed

+55
-45
lines changed

7 files changed

+55
-45
lines changed

ext/CUDAEnzymeExt.jl

+5-3
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ module CUDAEnzymeExt
99
import KernelAbstractions: Kernel, StaticSize, launch_config, __groupsize, __groupindex, blocks, mkcontext, CompilerMetadata, CPU, GPU
1010
using CUDA
1111

12+
include("enzyme_utils.jl")
13+
1214
function EnzymeRules.augmented_primal(config::Config, func::Const{<:Kernel{CUDABackend}}, type::Type{Const{Nothing}}, args::Vararg{Any, N}; ndrange=nothing, workgroupsize=nothing) where N
1315
println("Custom rule GPU")
1416
kernel = func.val
@@ -52,8 +54,8 @@ module CUDAEnzymeExt
5254
TapeType = EnzymeCore.tape_type(job, ReverseSplitModified(ReverseSplitWithPrimal, ModifiedBetween), FT, Const, Const{ctxTy}, map(Core.Typeof, args2)...)
5355
@show TapeType
5456

55-
56-
subtape = Array{TapeType}(undef, size(blocks(iterspace)))
57+
subtape = allocate(CUDABackend(), TapeType, size(blocks(iterspace)))
58+
# subtape = Array{TapeType}(undef, size(blocks(iterspace)))
5759

5860
aug_kernel = similar(kernel, aug_fwd)
5961

@@ -67,4 +69,4 @@ module CUDAEnzymeExt
6769
return res
6870
end
6971

70-
end
72+
end

ext/EnzymeExt.jl

+2-36
Original file line numberDiff line numberDiff line change
@@ -8,25 +8,9 @@ module EnzymeExt
88
end
99
import KernelAbstractions: Kernel, StaticSize, launch_config, __groupsize, __groupindex, blocks, mkcontext, CompilerMetadata, CPU, GPU
1010

11-
EnzymeRules.inactive(::Type{StaticSize}, x...) = nothing
12-
13-
function fwd(ctx, f, args...)
14-
EnzymeCore.autodiff_deferred(Forward, Const(f), Const, Const(ctx), args...)
15-
return nothing
16-
end
11+
include("enzyme_utils.jl")
1712

18-
function aug_fwd(ctx, f::FT, ::Val{ModifiedBetween}, subtape, args...) where {ModifiedBetween, FT}
19-
forward, reverse = EnzymeCore.autodiff_deferred_thunk(ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), Const{Core.Typeof(f)}, Const, Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...)
20-
subtape[__groupindex(ctx)] = forward(Const(f), Const(ctx), args...)[1]
21-
return nothing
22-
end
23-
24-
function rev(ctx, f::FT, ::Val{ModifiedBetween}, subtape, args...) where {ModifiedBetween, FT}
25-
forward, reverse = EnzymeCore.autodiff_deferred_thunk(ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), Const{Core.Typeof(f)}, Const, Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...)
26-
tp = subtape[__groupindex(ctx)]
27-
reverse(Const(f), Const(ctx), args..., tp)
28-
return nothing
29-
end
13+
EnzymeRules.inactive(::Type{StaticSize}, x...) = nothing
3014

3115
function EnzymeRules.forward(func::Const{<:Kernel}, ::Type{Const{Nothing}}, args...; ndrange=nothing, workgroupsize=nothing)
3216
kernel = func.val
@@ -36,24 +20,6 @@ module EnzymeExt
3620
fwd_kernel(f, args...; ndrange, workgroupsize)
3721
end
3822

39-
@inline function make_active_byref(f::F, ::Val{ActiveTys}) where {F, ActiveTys}
40-
if !any(ActiveTys)
41-
return f
42-
end
43-
function inact(ctx, args2::Vararg{Any, N}) where N
44-
args3 = ntuple(Val(N)) do i
45-
Base.@_inline_meta
46-
if ActiveTys[i]
47-
args2[i][]
48-
else
49-
args2[i]
50-
end
51-
end
52-
f(ctx, args3...)
53-
end
54-
return inact
55-
end
56-
5723
function EnzymeRules.augmented_primal(config::Config, func::Const{<:Kernel{CPU}}, ::Type{Const{Nothing}}, args::Vararg{Any, N}; ndrange=nothing, workgroupsize=nothing) where N
5824
println("Custom rule CPU")
5925
kernel = func.val

ext/enzyme_utils.jl

+35
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
function fwd(ctx, f, args...)
2+
EnzymeCore.autodiff_deferred(Forward, Const(f), Const, Const(ctx), args...)
3+
return nothing
4+
end
5+
6+
function aug_fwd(ctx, f::FT, ::Val{ModifiedBetween}, subtape, args...) where {ModifiedBetween, FT}
7+
forward, reverse = EnzymeCore.autodiff_deferred_thunk(ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), Const{Core.Typeof(f)}, Const, Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...)
8+
subtape[__groupindex(ctx)] = forward(Const(f), Const(ctx), args...)[1]
9+
return nothing
10+
end
11+
12+
function rev(ctx, f::FT, ::Val{ModifiedBetween}, subtape, args...) where {ModifiedBetween, FT}
13+
forward, reverse = EnzymeCore.autodiff_deferred_thunk(ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), Const{Core.Typeof(f)}, Const, Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...)
14+
tp = subtape[__groupindex(ctx)]
15+
reverse(Const(f), Const(ctx), args..., tp)
16+
return nothing
17+
end
18+
19+
@inline function make_active_byref(f::F, ::Val{ActiveTys}) where {F, ActiveTys}
20+
if !any(ActiveTys)
21+
return f
22+
end
23+
function inact(ctx, args2::Vararg{Any, N}) where N
24+
args3 = ntuple(Val(N)) do i
25+
Base.@_inline_meta
26+
if ActiveTys[i]
27+
args2[i][]
28+
else
29+
args2[i]
30+
end
31+
end
32+
f(ctx, args3...)
33+
end
34+
return inact
35+
end

test/Project.toml

+2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
[deps]
22
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
3+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
34
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
45
Enzyme_jll = "7cc45869-7501-5eee-bdea-0790c847d4ef"
56
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
7+
# KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
68
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
79
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
810
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

test/extensions/enzyme.jl

+3-2
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ end
1010
function square_caller(A, backend)
1111
kernel = square!(backend)
1212
kernel(A, ndrange=size(A))
13-
KernelAbstractions.synchronize(backend)
1413
end
1514

1615

@@ -22,7 +21,6 @@ end
2221
function mul_caller(A, B, backend)
2322
kernel = mul!(backend)
2423
kernel(A, B, ndrange=size(A))
25-
KernelAbstractions.synchronize(backend)
2624
end
2725

2826
function enzyme_testsuite(backend, ArrayT, supports_reverse=true)
@@ -36,13 +34,15 @@ function enzyme_testsuite(backend, ArrayT, supports_reverse=true)
3634
dA .= 1
3735

3836
Enzyme.autodiff(Reverse, square_caller, Duplicated(A, dA), Const(backend()))
37+
KernelAbstractions.synchronize(backend())
3938
@test all(dA .≈ (2:2:128))
4039

4140

4241
A .= (1:1:64)
4342
dA .= 1
4443

4544
_, dB, _ = Enzyme.autodiff(Reverse, mul_caller, Duplicated(A, dA), Active(1.2), Const(backend()))[1]
45+
KernelAbstractions.synchronize(backend())
4646

4747
@test all(dA .≈ 1.2)
4848
@test dB sum(1:1:64)
@@ -52,6 +52,7 @@ function enzyme_testsuite(backend, ArrayT, supports_reverse=true)
5252
dA .= 1
5353

5454
Enzyme.autodiff(Forward, square_caller, Duplicated(A, dA), Const(backend()))
55+
KernelAbstractions.synchronize(backend())
5556
@test all(dA .≈ 2:2:128)
5657

5758
end

test/reverse_gpu.jl

+5-4
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ end
1111
function square_caller(A, backend)
1212
kernel = square!(backend)
1313
kernel(A, ndrange=size(A))
14-
KernelAbstractions.synchronize(backend)
1514
end
1615

1716

@@ -23,7 +22,6 @@ end
2322
function mul_caller(A, B, backend)
2423
kernel = mul!(backend)
2524
kernel(A, B, ndrange=size(A))
26-
KernelAbstractions.synchronize(backend)
2725
end
2826

2927
function enzyme_testsuite(backend, ArrayT, supports_reverse=true)
@@ -37,13 +35,15 @@ function enzyme_testsuite(backend, ArrayT, supports_reverse=true)
3735
dA .= 1
3836

3937
Enzyme.autodiff(Reverse, square_caller, Duplicated(A, dA), Const(backend()))
38+
KernelAbstractions.synchronize(backend())
4039
@test all(dA .≈ (2:2:128))
4140

4241

4342
A .= (1:1:64)
4443
dA .= 1
4544

4645
_, dB, _ = Enzyme.autodiff(Reverse, mul_caller, Duplicated(A, dA), Active(1.2), Const(backend()))[1]
46+
KernelAbstractions.synchronize(backend())
4747

4848
@test all(dA .≈ 1.2)
4949
@test dB sum(1:1:64)
@@ -53,11 +53,12 @@ function enzyme_testsuite(backend, ArrayT, supports_reverse=true)
5353
dA .= 1
5454

5555
Enzyme.autodiff(Forward, square_caller, Duplicated(A, dA), Const(backend()))
56+
KernelAbstractions.synchronize(backend())
5657
@test all(dA .≈ 2:2:128)
5758

5859
end
5960
end
6061

6162
# enzyme_testsuite(CPU, Array, true)
62-
enzyme_testsuite(CUDABackend, CuArray, false)
63-
# enzyme_testsuite(CUDABackend, CuArray, true)
63+
# enzyme_testsuite(CUDABackend, CuArray, false)
64+
enzyme_testsuite(CUDABackend, CuArray, true)

test/runtests.jl

+3
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
using CUDA
12
using KernelAbstractions
23
using Test
34

@@ -74,5 +75,7 @@ include("extensions/enzyme.jl")
7475
@static if VERSION >= v"1.7.0"
7576
@testset "Enzyme" begin
7677
enzyme_testsuite(CPU, Array)
78+
enzyme_testsuite(CUDABackend, CuArray, false)
79+
# enzyme_testsuite(CUDABackend, CuArray, true)
7780
end
7881
end

0 commit comments

Comments
 (0)