Skip to content

Commit 07d52b7

Browse files
committed
Compiler crashes
1 parent b194b6d commit 07d52b7

File tree

2 files changed

+64
-3
lines changed

2 files changed

+64
-3
lines changed

Diff for: ext/EnzymeExt.jl

+62-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@ module EnzymeExt
66
using ..EnzymeCore
77
using ..EnzymeCore.EnzymeRules
88
end
9-
import KernelAbstractions: Kernel, StaticSize, launch_config, __groupsize, __groupindex, blocks, mkcontext, CompilerMetadata, CPU
9+
import KernelAbstractions: Kernel, StaticSize, launch_config, __groupsize, __groupindex, blocks, mkcontext, CompilerMetadata, CPU, GPU
10+
using CUDA
1011

1112
EnzymeRules.inactive(::Type{StaticSize}, x...) = nothing
1213

@@ -55,6 +56,7 @@ module EnzymeExt
5556
end
5657

5758
function EnzymeRules.augmented_primal(config::Config, func::Const{<:Kernel{CPU}}, ::Type{Const{Nothing}}, args::Vararg{Any, N}; ndrange=nothing, workgroupsize=nothing) where N
59+
println("Custom rule CPU")
5860
kernel = func.val
5961
f = kernel.f
6062

@@ -93,6 +95,65 @@ module EnzymeExt
9395

9496
# TODO in KA backends like CUDAKernels, etc have a version with a parent job type
9597
TapeType = EnzymeCore.tape_type(ReverseSplitModified(ReverseSplitWithPrimal, ModifiedBetween), FT, Const, Const{ctxTy}, map(Core.Typeof, args2)...)
98+
@show TapeType
99+
100+
101+
subtape = Array{TapeType}(undef, size(blocks(iterspace)))
102+
103+
aug_kernel = similar(kernel, aug_fwd)
104+
105+
aug_kernel(f, ModifiedBetween, subtape, args2...; ndrange, workgroupsize)
106+
107+
# TODO the fact that ctxTy is type unstable means this is all type unstable.
108+
# Since custom rules require a fixed return type, explicitly cast to Any, rather
109+
# than returning a AugmentedReturn{Nothing, Nothing, T} where T.
110+
111+
res = AugmentedReturn{Nothing, Nothing, Tuple{Array, typeof(arg_refs)}}(nothing, nothing, (subtape, arg_refs))
112+
return res
113+
end
114+
115+
function EnzymeRules.augmented_primal(config::Config, func::Const{<:Kernel{CUDABackend}}, type::Type{Const{Nothing}}, args::Vararg{Any, N}; ndrange=nothing, workgroupsize=nothing) where N
116+
println("Custom rule GPU")
117+
kernel = func.val
118+
f = kernel.f
119+
mi = CUDA.methodinstance(typeof(()->return), Tuple{})
120+
job = CUDA.CompilerJob(mi, CUDA.compiler_config(device()))
121+
122+
ndrange, workgroupsize, iterspace, dynamic = launch_config(kernel, ndrange, workgroupsize)
123+
block = first(blocks(iterspace))
124+
ctx = mkcontext(kernel, ndrange, iterspace)
125+
ctxTy = Core.Typeof(ctx) # CompilerMetadata{ndrange(kernel), Core.Typeof(dynamic)}
126+
127+
# TODO autodiff_deferred on the func.val
128+
ModifiedBetween = Val((overwritten(config)[1], false, overwritten(config)[2:end]...))
129+
130+
tup = Val(ntuple(Val(N)) do i
131+
Base.@_inline_meta
132+
args[i] isa Active
133+
end)
134+
f = make_active_byref(f, tup)
135+
FT = Const{Core.Typeof(f)}
136+
137+
arg_refs = ntuple(Val(N)) do i
138+
Base.@_inline_meta
139+
if args[i] isa Active
140+
Ref(EnzymeCore.make_zero(args[i].val))
141+
else
142+
nothing
143+
end
144+
end
145+
args2 = ntuple(Val(N)) do i
146+
Base.@_inline_meta
147+
if args[i] isa Active
148+
Duplicated(Ref(args[i].val), arg_refs[i])
149+
else
150+
args[i]
151+
end
152+
end
153+
154+
# TODO in KA backends like CUDAKernels, etc have a version with a parent job type
155+
TapeType = EnzymeCore.tape_type(job, ReverseSplitModified(ReverseSplitWithPrimal, ModifiedBetween), FT, Const, Const{ctxTy}, map(Core.Typeof, args2)...)
156+
@show TapeType
96157

97158

98159
subtape = Array{TapeType}(undef, size(blocks(iterspace)))

Diff for: reverse_gpu.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,6 @@ function enzyme_testsuite(backend, ArrayT, supports_reverse=true)
5959
end
6060
end
6161

62-
enzyme_testsuite(CPU, Array, true)
62+
# enzyme_testsuite(CPU, Array, true)
6363
# enzyme_testsuite(CUDABackend, CuArray, false)
64-
enzyme_testsuite(CUDABackend, CuArray, true)
64+
enzyme_testsuite(CUDABackend, CuArray, true)

0 commit comments

Comments
 (0)