diff --git a/Project.toml b/Project.toml index 82f8ccc9..13a1253d 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "10.1.1" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" +JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" LLVM = "929cbde3-209d-540e-8aea-75f648917ca0" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" diff --git a/lib/GPUArraysCore/Project.toml b/lib/GPUArraysCore/Project.toml index c842d718..2b160c6d 100644 --- a/lib/GPUArraysCore/Project.toml +++ b/lib/GPUArraysCore/Project.toml @@ -6,6 +6,16 @@ version = "0.1.6" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +[weakdeps] +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" + +[extensions] +EnzymeCoreExt = "EnzymeCore" + [compat] Adapt = "4.0" julia = "1.6" +EnzymeCore = "0.6, 0.7" + +[extras] +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" diff --git a/lib/GPUArraysCore/ext/EnzymeCoreExt.jl b/lib/GPUArraysCore/ext/EnzymeCoreExt.jl new file mode 100644 index 00000000..ff02ddb4 --- /dev/null +++ b/lib/GPUArraysCore/ext/EnzymeCoreExt.jl @@ -0,0 +1,27 @@ +# compatibility with EnzymeCore + +module EnzymeCoreExt + +using GPUArraysCore + +if isdefined(Base, :get_extension) + using EnzymeCore + using EnzymeCore.EnzymeRules +else + using ..EnzymeCore + using ..EnzymeCore.EnzymeRules +end + +function EnzymeCore.EnzymeRules.inactive_noinl(::typeof(GPUArraysCore.default_scalar_indexing), args...) + return nothing +end + +function EnzymeCore.EnzymeRules.inactive_noinl(::typeof(GPUArraysCore.assertscalar), args...) + return nothing +end + +function EnzymeCore.EnzymeRules.inactive_noinl(::typeof(GPUArraysCore.allowscalar), args...) + return nothing +end + +end # module diff --git a/lib/JLArrays/Project.toml b/lib/JLArrays/Project.toml index ce8959b7..5dac4485 100644 --- a/lib/JLArrays/Project.toml +++ b/lib/JLArrays/Project.toml @@ -1,15 +1,22 @@ name = "JLArrays" uuid = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" authors = ["Tim Besard "] -version = "0.1.4" +version = "0.1.5" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +[weakdeps] +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" + +[extensions] +EnzymeExt = "Enzyme" + [compat] Adapt = "2.0, 3.0, 4.0" +Enzyme = "0.12" GPUArrays = "10" julia = "1.8" -Random = "1" +Random = "1" \ No newline at end of file diff --git a/lib/JLArrays/ext/EnzymeExt.jl b/lib/JLArrays/ext/EnzymeExt.jl new file mode 100644 index 00000000..9789a300 --- /dev/null +++ b/lib/JLArrays/ext/EnzymeExt.jl @@ -0,0 +1,54 @@ +module EnzymeExt + +using JLArrays + +using GPUArrays + +if isdefined(Base, :get_extension) + using Enzyme +else + using ..Enzyme +end + + +# Override default type tree. This is because JLArray stores data as Vector{UInt8}, causing issues for +# type analysis not determining the proper element type (instead determining the memory is of type UInt8). +function Enzyme.typetree_inner(::Type{JLT}, ctx, dl, seen::Enzyme.TypeTreeTable) where {JLT<:JLArray} + if JLT isa UnionAll || JLT isa Union || JLT == Union{} || Base.isabstracttype(JLT) + return Enzyme.TypeTree() + end + + if !Base.isconcretetype(JLT) + return Enzyme.TypeTree(Enzyme.API.DT_Pointer, -1, ctx) + end + + elT = eltype(JLT) + + fieldTypes = [DataRef{Vector{elT}}, Int, Dims{ndims(JLT)}] + + tt = Enzyme.TypeTree() + for f in 1:fieldcount(JLT) + offset = fieldoffset(JLT, f) + subT = fieldTypes[f] + subtree = copy(Enzyme.typetree(subT, ctx, dl, seen)) + + if subT isa UnionAll || subT isa Union || subT == Union{} + # FIXME: Handle union + continue + end + + # Allocated inline so adjust first path + if Enzyme.allocatedinline(subT) + Enzyme.shift!(subtree, dl, 0, sizeof(subT), offset) + else + Enzyme.merge!(subtree, Enzyme.TypeTree(Enzyme.API.DT_Pointer, ctx)) + Enzyme.only!(subtree, offset) + end + + Enzyme.merge!(tt, subtree) + end + Enzyme.canonicalize!(tt, sizeof(JLT), dl) + return tt +end + +end # module diff --git a/test/Project.toml b/test/Project.toml index 76e1e22a..23060348 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,5 +1,8 @@ [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" +LLVM = "929cbde3-209d-540e-8aea-75f648917ca0" JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" diff --git a/test/runtests.jl b/test/runtests.jl index 4df72b2b..a9d5ae81 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,7 +1,25 @@ using GPUArrays, Test, Pkg + +@testset "Enzyme JLArray: TypeTree" begin + + using Enzyme + using JLArrays + using LLVM + + import Enzyme: typetree, TypeTree, API, make_zero + + ctx = LLVM.Context() + dl = string(LLVM.DataLayout(LLVM.JITTargetMachine())) + + tt(T) = string(typetree(T, ctx, dl)) + @test tt(JLArray{Float64, 1}) == "{[0]:Pointer, [0,0]:Pointer, [0,0,-1]:Pointer, [0,0,0,0]:Pointer, [0,0,0,0,-1]:Float@double, [0,0,0,8]:Integer, [0,0,0,9]:Integer, [0,0,0,10]:Integer, [0,0,0,11]:Integer, [0,0,0,12]:Integer, [0,0,0,13]:Integer, [0,0,0,14]:Integer, [0,0,0,15]:Integer, [0,0,0,16]:Integer, [0,0,0,17]:Integer, [0,0,0,18]:Integer, [0,0,0,19]:Integer, [0,0,0,20]:Integer, [0,0,0,21]:Integer, [0,0,0,22]:Integer, [0,0,0,23]:Integer, [0,0,0,24]:Integer, [0,0,0,25]:Integer, [0,0,0,26]:Integer, [0,0,0,27]:Integer, [0,0,0,28]:Integer, [0,0,0,29]:Integer, [0,0,0,30]:Integer, [0,0,0,31]:Integer, [0,0,0,32]:Integer, [0,0,0,33]:Integer, [0,0,0,34]:Integer, [0,0,0,35]:Integer, [0,0,0,36]:Integer, [0,0,0,37]:Integer, [0,0,0,38]:Integer, [0,0,0,39]:Integer, [0,0,16,-1]:Integer, [0,8]:Integer, [8]:Integer, [9]:Integer, [10]:Integer, [11]:Integer, [12]:Integer, [13]:Integer, [14]:Integer, [15]:Integer, [16]:Integer, [17]:Integer, [18]:Integer, [19]:Integer, [20]:Integer, [21]:Integer, [22]:Integer, [23]:Integer}" +end + + include("testsuite.jl") + @testset "JLArray" begin using JLArrays diff --git a/test/testsuite.jl b/test/testsuite.jl index e7c14646..8f367fad 100644 --- a/test/testsuite.jl +++ b/test/testsuite.jl @@ -96,6 +96,7 @@ include("testsuite/math.jl") include("testsuite/random.jl") include("testsuite/uniformscaling.jl") include("testsuite/statistics.jl") +include("testsuite/enzyme.jl") """ Runs the entire GPUArrays test suite on array type `AT` diff --git a/test/testsuite/enzyme.jl b/test/testsuite/enzyme.jl new file mode 100644 index 00000000..49b038f3 --- /dev/null +++ b/test/testsuite/enzyme.jl @@ -0,0 +1,17 @@ +using Enzyme + +function scalarfirst(x) + @allowscalar x[1] +end + +@testsuite "Enzyme" (AT, eltypes)->begin + for ET in eltypes + T = AT{ET} + @testset "Forward $ET" begin + x = T(ones(3)) + dx = T(3*ones(3)) + res = autodiff(Forward, scalarfirst, Duplicated(x, dx)) + @test approx(res, 3) + end + end +end