You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
function EnzymeRules.augmented_primal(config::Config, func::Const{<:Kernel{CPU}}, ::Type{Const{Nothing}}, args::Vararg{Any, N}; ndrange=nothing, workgroupsize=nothing) where N
59
+
println("Custom rule CPU")
58
60
kernel = func.val
59
61
f = kernel.f
60
62
@@ -93,6 +95,65 @@ module EnzymeExt
93
95
94
96
#TODO in KA backends like CUDAKernels, etc have a version with a parent job type
#TODO the fact that ctxTy is type unstable means this is all type unstable.
108
+
# Since custom rules require a fixed return type, explicitly cast to Any, rather
109
+
# than returning a AugmentedReturn{Nothing, Nothing, T} where T.
110
+
111
+
res =AugmentedReturn{Nothing, Nothing, Tuple{Array, typeof(arg_refs)}}(nothing, nothing, (subtape, arg_refs))
112
+
return res
113
+
end
114
+
115
+
function EnzymeRules.augmented_primal(config::Config, func::Const{<:Kernel{CUDABackend}}, type::Type{Const{Nothing}}, args::Vararg{Any, N}; ndrange=nothing, workgroupsize=nothing) where N
116
+
println("Custom rule GPU")
117
+
kernel = func.val
118
+
f = kernel.f
119
+
mi = CUDA.methodinstance(typeof(()->return), Tuple{})
0 commit comments