From 76727ed0fb19e2013642c6957433b1cee6f4ae1b Mon Sep 17 00:00:00 2001 From: fda-tome Date: Wed, 5 Jun 2024 16:55:05 -0300 Subject: [PATCH 01/34] DArray: Trapezoidal and Triangular wrappers --- src/Dagger.jl | 8 ++- src/array/trapezoidal.jl | 142 +++++++++++++++++++++++++++++++++++++++ src/array/triangular.jl | 83 +++++++++++++++++++++++ 3 files changed, 231 insertions(+), 2 deletions(-) create mode 100644 src/array/trapezoidal.jl create mode 100644 src/array/triangular.jl diff --git a/src/Dagger.jl b/src/Dagger.jl index 83e2058e6..b52a96e23 100644 --- a/src/Dagger.jl +++ b/src/Dagger.jl @@ -7,12 +7,12 @@ import SparseArrays: sprand, SparseMatrixCSC import MemPool import MemPool: DRef, FileRef, poolget, poolset -import Base: collect, reduce +import Base: collect, reduce, require_one_based_indexing import Distributed import Distributed: Future, RemoteChannel, myid, workers, nworkers, procs, remotecall, remotecall_wait, remotecall_fetch import LinearAlgebra -import LinearAlgebra: Adjoint, BLAS, Diagonal, LAPACK, LowerTriangular, PosDefException, Transpose, UpperTriangular, diagind, ishermitian, issymmetric +import LinearAlgebra: Adjoint, BLAS, Diagonal, LAPACK, LowerTriangular, PosDefException, Transpose, UpperTriangular, diagind, ishermitian, issymmetric, chkstride1 import UUIDs: UUID, uuid4 @@ -77,8 +77,12 @@ include("array/setindex.jl") include("array/matrix.jl") include("array/sparse_partition.jl") include("array/sort.jl") + +# Linear algebra include("array/linalg.jl") include("array/mul.jl") +include("array/trapezoidal.jl") +include("array/triangular.jl") include("array/cholesky.jl") # Visualization diff --git a/src/array/trapezoidal.jl b/src/array/trapezoidal.jl new file mode 100644 index 000000000..3497fa26a --- /dev/null +++ b/src/array/trapezoidal.jl @@ -0,0 +1,142 @@ +export LowerTrapezoidal, UnitLowerTrapezoidal, UpperTrapezoidal, UnitUpperTrapezoidal, trau!, trau, tral!, tral +import LinearAlgebra: triu!, tril!, triu, tril +abstract type AbstractTrapezoidal{T} <: AbstractMatrix{T} end + +# First loop through all methods that don't need special care for upper/lower and unit diagonal +for t in (:LowerTrapezoidal, :UnitLowerTrapezoidal, :UpperTrapezoidal, :UnitUpperTrapezoidal) + @eval begin + struct $t{T,S<:AbstractMatrix{T}} <: AbstractTrapezoidal{T} + data::S + + function $t{T,S}(data) where {T,S<:AbstractMatrix{T}} + Base.require_one_based_indexing(data) + new{T,S}(data) + end + end + $t(A::$t) = A + $t{T}(A::$t{T}) where {T} = A + $t(A::AbstractMatrix) = $t{eltype(A), typeof(A)}(A) + $t{T}(A::AbstractMatrix) where {T} = $t(convert(AbstractMatrix{T}, A)) + $t{T}(A::$t) where {T} = $t(convert(AbstractMatrix{T}, A.data)) + + AbstractMatrix{T}(A::$t) where {T} = $t{T}(A) + AbstractMatrix{T}(A::$t{T}) where {T} = copy(A) + + Base.size(A::$t) = size(A.data) + Base.axes(A::$t) = axes(A.data) + + Base.similar(A::$t, ::Type{T}) where {T} = $t(similar(parent(A), T)) + Base.similar(A::$t, ::Type{T}, dims::Dims{N}) where {T,N} = similar(parent(A), T, dims) + Base.parent(A::$t) = A.data + + Base.copy(A::$t) = $t(copy(A.data)) + + Base.real(A::$t{<:Real}) = A + Base.real(A::$t{<:Complex}) = (B = real(A.data); $t(B)) + end +end + +Base.getindex(A::UnitLowerTrapezoidal{T}, i::Integer, j::Integer) where {T} = + i > j ? ifelse(A.data[i,j] == nothing, zero(eltype(A.data)), A.data[i,j]) : ifelse(i == j, oneunit(T), zero(T)) +Base.getindex(A::LowerTrapezoidal, i::Integer, j::Integer) = +i >= j ? ifelse(A.data[i,j] == nothing, zero(eltype(A.data)), A.data[i,j]) : zero(eltype(A.data)) +Base.getindex(A::UnitUpperTrapezoidal{T}, i::Integer, j::Integer) where {T} = + i < j ? ifelse(A.data[i,j] == nothing, zero(eltype(A.data)), A.data[i,j]) : ifelse(i == j, oneunit(T), zero(T)) +Base.getindex(A::UpperTrapezoidal, i::Integer, j::Integer) = +i <= j ? ifelse(A.data[i,j] == nothing, zero(eltype(A.data)), A.data[i,j]) : zero(eltype(A.data)) + +function _DiagBuild(blockdom::Tuple, alloc::AbstractMatrix{T}, diag::Vector{Tuple{Int,Int}}, transform::Function) where {T} + diagind = findfirst(x-> x[1] in blockdom[1] && x[2] in blockdom[2], diag) + blockind = (diag[diagind][1] - first(blockdom[1]) + 1, diag[diagind][2] - first(blockdom[2]) + 1) + return Dagger.@spawn transform(alloc, blockind[2] - blockind[1]) +end + +function _GenericTrapezoidal(f::Union{Function, UndefInitializer}, p::Blocks{2}, eltype::Type, wrap, k::Integer, dims) + d = ArrayDomain(map(x->1:x, dims)) + dc = partition(p, d) + if f isa UndefInitializer + f = (eltype, x...) -> Array{eltype}(undef, x...) + end + m, n = dims + if k < 0 + diag = [(i-k, i) for i in 1:min(m, n)] + else + diag = [(i, i+k) for i in 1:min(m, n)] + end + thunks = [] + alloc(sz) = f(eltype, sz) + transform = (wrap == LowerTrapezoidal) ? tril! : triu! + compar = (wrap == LowerTrapezoidal) ? (>) : (<) + for c in dc + sz = size(c) + if any(x -> x[1] in c.indexes[1] && x[2] in c.indexes[2], diag) + push!(thunks, _DiagBuild(c.indexes, alloc(sz), diag, transform)) + else + mt, nt = k<0 ? (first(c.indexes[1]), first(c.indexes[2])-k) : (first(c.indexes[1])+k, first(c.indexes[2])) + if compar(mt, nt) + push!(thunks, Dagger.@spawn alloc(sz)) + else + push!(thunks, Dagger.@spawn zeros(eltype, sz)) + end + end + end + thunks = reshape(thunks, size(dc)) + return wrap(Dagger.DArray(eltype, d, dc, thunks, p)) +end + +LowerTrapezoidal(f::Union{Function, UndefInitializer}, p::Blocks{2}, dims::Integer...) = _GenericTrapezoidal(f, p, eltype, LowerTrapezoidal, 0, dims) +LowerTrapezoidal(f::Union{Function, UndefInitializer}, p::Blocks{2}, dims::Tuple) = _GenericTrapezoidal(f, p, eltype, LowerTrapezoidal, 0, dims) +LowerTrapezoidal(f::Union{Function, UndefInitializer}, p::Blocks{2}, eltype::Type, dims::Integer...) = _GenericTrapezoidal(f, p, eltype, LowerTrapezoidal, 0, dims) +LowerTrapezoidal(f::Union{Function, UndefInitializer}, p::Blocks{2}, eltype::Type, dims::Tuple) = _GenericTrapezoidal(f, p, eltype, LowerTrapezoidal, 0, dims) +LowerTrapezoidal(f::Union{Function, UndefInitializer}, p::Blocks{2}, k::Integer, dims::Integer...) = _GenericTrapezoidal(f, p, eltype, LowerTrapezoidal, k, dims) +LowerTrapezoidal(f::Union{Function, UndefInitializer}, p::Blocks{2}, k::Integer, dims::Tuple) = _GenericTrapezoidal(f, p, eltype, LowerTrapezoidal, k, dims) +LowerTrapezoidal(f::Union{Function, UndefInitializer}, p::Blocks{2}, eltype::Type, k::Integer, dims::Integer...) = _GenericTrapezoidal(f, p, eltype, LowerTrapezoidal, k, dims) +LowerTrapezoidal(f::Union{Function, UndefInitializer}, p::Blocks{2}, eltype::Type, k::Integer, dims::Tuple) = _GenericTrapezoidal(f, p, eltype, LowerTrapezoidal, k, dims) + +UpperTrapezoidal(f::Union{Function, UndefInitializer}, p::Blocks{2}, dims::Integer...) = _GenericTrapezoidal(f, p, eltype, UpperTrapezoidal, 0, dims) +UpperTrapezoidal(f::Union{Function, UndefInitializer}, p::Blocks{2}, dims::Tuple) = _GenericTrapezoidal(f, p, eltype, UpperTrapezoidal, 0, dims) +UpperTrapezoidal(f::Union{Function, UndefInitializer}, p::Blocks{2}, eltype::Type, dims::Integer...) = _GenericTrapezoidal(f, p, eltype, UpperTrapezoidal, 0, dims) +UpperTrapezoidal(f::Union{Function, UndefInitializer}, p::Blocks{2}, eltype::Type, dims::Tuple) = _GenericTrapezoidal(f, p, eltype, UpperTrapezoidal, 0, dims) +UpperTrapezoidal(f::Union{Function, UndefInitializer}, p::Blocks{2}, k::Integer, dims::Integer...) = _GenericTrapezoidal(f, p, eltype, UpperTrapezoidal, k, dims) +UpperTrapezoidal(f::Union{Function, UndefInitializer}, p::Blocks{2}, k::Integer, dims::Tuple) = _GenericTrapezoidal(f, p, eltype, UpperTrapezoidal, k, dims) +UpperTrapezoidal(f::Union{Function, UndefInitializer}, p::Blocks{2}, eltype::Type, k::Integer, dims::Integer...) = _GenericTrapezoidal(f, p, eltype, UpperTrapezoidal, k, dims) +UpperTrapezoidal(f::Union{Function, UndefInitializer}, p::Blocks{2}, eltype::Type, k::Integer, dims::Tuple) = _GenericTrapezoidal(f, p, eltype, UpperTrapezoidal, k, dims) + +function _GenericTra!(A::Dagger.DArray{T, 2}, wrap::Function, k::Integer) where {T} + d = A.domain + dc = A.subdomains + Ac = A.chunks + m, n = size(A) + if k < 0 + diag = [(i-k, i) for i in 1:min(m, n)] + else + diag = [(i, i+k) for i in 1:min(m, n)] + end + compar = (wrap == tril!) ? (≤) : (≥) + for ind in CartesianIndices(dc) + sz = size(dc[ind]) + if any(x -> x[1] in dc[ind].indexes[1] && x[2] in dc[ind].indexes[2], diag) + Ac[ind] = _DiagBuild(dc[ind].indexes, fetch(Ac[ind]), diag, wrap) + else + mt, nt = k<0 ? (first(dc[ind].indexes[1]), first(dc[ind].indexes[2])-k) : (first(dc[ind].indexes[1])+k, first(dc[ind].indexes[2])) + if compar(mt, nt) + Ac[ind] = Dagger.@spawn zeros(T, sz) + end + end + end + return A +end + + + +trau!(A::Dagger.DArray{T,2}, k::Integer) where {T} = _GenericTra!(A, triu!, k) +trau!(A::Dagger.DArray{T,2}) where {T} = _GenericTra!(A, triu!, 0) +trau(A::Dagger.DArray{T,2}) where {T} = trau!(copy(A)) +trau(A::Dagger.DArray{T,2}, k::Integer) where {T} = trau!(copy(A), k) + +tral!(A::Dagger.DArray{T,2}, k::Integer) where {T} = _GenericTra!(A, tril!, k) +tral!(A::Dagger.DArray{T,2}) where {T} = _GenericTra!(A, tril!, 0) +tral(A::Dagger.DArray{T,2}) where {T} = tral!(copy(A)) +tral(A::Dagger.DArray{T,2}, k::Integer) where {T} = tral!(copy(A), k) + +#TODO: map, reduce, sum, mean, prod, reducedim, collect, distribute diff --git a/src/array/triangular.jl b/src/array/triangular.jl new file mode 100644 index 000000000..47613df2d --- /dev/null +++ b/src/array/triangular.jl @@ -0,0 +1,83 @@ +export LowerTriangular, UpperTriangular, tril, triu, tril!, triu! + +function _GenericTriangular(f::Union{Function, UndefInitializer}, p::Blocks{2}, eltype::Type, compar::Function, wrap, dims) + @assert dims[1] == dims[2] "matrix is not square: dimensions are $dims (try using trapezoidal)" + d = ArrayDomain(map(x->1:x, dims)) + dc = partition(p, d) + if f isa UndefInitializer + f = (eltype, x...) -> Array{eltype}(undef, x...) + end + diag = [(i,i) for i in 1:min(size(d)...)] + thunks = [] + alloc(sz) = f(eltype, sz) + transform = (wrap == LowerTrapezoidal) ? tril! : triu! + compar = (wrap == LowerTrapezoidal) ? (>) : (<) + for c in dc + sz = size(c) + if any(x -> x[1] in c.indexes[1] && x[2] in c.indexes[2], diag) + push!(thunks, _DiagBuild(c.indexes, alloc(sz), diag, transform)) + else + if compar(first(c.indexes[1]), first(c.indexes[2])) + push!(thunks, Dagger.@spawn alloc(sz)) + else + push!(thunks, Dagger.@spawn zeros(eltype, sz)) + end + end + end + thunks = reshape(thunks, size(dc)) + return wrap(Dagger.DArray(eltype, d, dc, thunks, p)) +end + +LinearAlgebra.LowerTriangular(f::Union{Function, UndefInitializer}, p::Blocks{2}, dims::Integer...) = _GenericTriangular(f, p, eltype, LowerTriangular, dims) +LinearAlgebra.LowerTriangular(f::Union{Function, UndefInitializer}, p::Blocks{2}, dims::Tuple) = _GenericTriangular(f, p, eltype, LowerTriangular, dims) +LinearAlgebra.LowerTriangular(f::Union{Function, UndefInitializer}, p::Blocks{2}, eltype::Type, dims::Integer...) = _GenericTriangular(f, p, eltype, LowerTriangular, dims) +LinearAlgebra.LowerTriangular(f::Union{Function, UndefInitializer}, p::Blocks{2}, eltype::Type, dims::Tuple) = _GenericTriangular(f, p, eltype, LowerTriangular, dims) + +LinearAlgebra.UpperTriangular(f::Union{Function, UndefInitializer}, p::Blocks{2}, dims::Integer...) = _GenericTriangular(f, p, eltype, UpperTriangular, dims) +LinearAlgebra.UpperTriangular(f::Union{Function, UndefInitializer}, p::Blocks{2}, dims::Tuple) = _GenericTriangular(f, p, eltype, UpperTriangular, dims) +LinearAlgebra.UpperTriangular(f::Union{Function, UndefInitializer}, p::Blocks{2}, eltype::Type, dims::Integer...) = _GenericTriangular(f, p, eltype, UpperTriangular, dims) +LinearAlgebra.UpperTriangular(f::Union{Function, UndefInitializer}, p::Blocks{2}, eltype::Type, dims::Tuple) = _GenericTriangular(f, p, eltype, UpperTriangular, dims) + +function _GenericTri!(A::Dagger.DArray{T, 2}, wrap) where {T} + LinearAlgebra.checksquare(A) + d = A.domain + dc = A.subdomains + Ac = A.chunks + diag = [(i,i) for i in 1:min(size(d)...)] + compar = (wrap == tril!) ? (>) : (<) + for ind in CartesianIndices(dc) + sz = size(dc[ind]) + if any(x -> x[1] in dc[ind].indexes[1] && x[2] in dc[ind].indexes[2], diag) + Ac[ind] = _DiagBuild(dc[ind].indexes, fetch(Ac[ind]), diag, wrap) + else + if compar(first(dc[ind].indexes[2]), first(dc[ind].indexes[1])) + Ac[ind] = Dagger.@spawn zeros(T, sz) + end + end + end + return A +end + +function LinearAlgebra.triu!(A::Dagger.DArray{T,2}) where {T} + if size(A, 1) != size(A, 2) + trau!(A) + else + _GenericTri!(A, triu!) + end +end +LinearAlgebra.triu(A::Dagger.DArray{T,2}) where {T} = triu!(copy(A)) + +function LinearAlgebra.tril!(A::Dagger.DArray{T,2}) where {T} + if size(A, 1) != size(A, 2) + tral!(A) + else + _GenericTri!(A, tril!) + end +end +LinearAlgebra.tril(A::Dagger.DArray{T,2}) where {T} = tril!(copy(A)) + +#TODO: matmul + + + + From 8d2908a2a96e1ca377f9a6ed3d60c8c432b0dc77 Mon Sep 17 00:00:00 2001 From: fda-tome Date: Wed, 5 Jun 2024 17:04:19 -0300 Subject: [PATCH 02/34] DArray: UndefInitializer --- src/array/alloc.jl | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/array/alloc.jl b/src/array/alloc.jl index 5b881bf0d..3457e9ac7 100644 --- a/src/array/alloc.jl +++ b/src/array/alloc.jl @@ -31,6 +31,22 @@ end const BlocksOrAuto = Union{Blocks{N} where N, AutoBlocks} +function DArray{T}(::UndefInitializer, p::Blocks, dims) where {T} + d = ArrayDomain(map(x->1:x, dims)) + part = partition(p, d) + f = function (_, T, sz) + Array{T, length(sz)}(undef, sz...) + end + a = AllocateArray(T, f, d, part, p) + return _to_darray(a) +end + +DArray(::UndefInitializer, p::BlocksOrAuto, dims::Integer...) = DArray{Float64}(undef, p, dims) +DArray(::UndefInitializer, p::BlocksOrAuto, dims::Tuple) = DArray{Float64}(undef, p, dims) +DArray{T}(::UndefInitializer, p::BlocksOrAuto, dims::Integer...) where {T} = DArray{T}(undef, p, dims) +DArray{T}(::UndefInitializer, p::BlocksOrAuto, dims::Tuple) where {T} = DArray{T}(undef, p, dims) +DArray{T}(::UndefInitializer, p::AutoBlocks, dims::Tuple) where {T} = DArray{T}(undef, auto_blocks(dims), dims) + function Base.rand(p::Blocks, eltype::Type, dims::Dims) d = ArrayDomain(map(x->1:x, dims)) a = AllocateArray(eltype, (_, x...) -> rand(x...), d, partition(p, d), p) From b41b2b66d1476a894495d63abb26ae3fa873e5fe Mon Sep 17 00:00:00 2001 From: fda-tome Date: Wed, 5 Jun 2024 17:18:03 -0300 Subject: [PATCH 03/34] DArray: slicing bug fix --- src/array/darray.jl | 22 ++++++++++++++++++---- src/array/indexing.jl | 19 ++++++++++++++++++- 2 files changed, 36 insertions(+), 5 deletions(-) diff --git a/src/array/darray.jl b/src/array/darray.jl index 6207ee245..67cc57017 100644 --- a/src/array/darray.jl +++ b/src/array/darray.jl @@ -63,6 +63,10 @@ ArrayDomain((1:15), (1:80)) alignfirst(a::ArrayDomain) = ArrayDomain(map(r->1:length(r), indexes(a))) +alignfirst(a::CartesianIndices{N}) where N = + ArrayDomain(map(r->1:length(r), a.indices)) + + function size(a::ArrayDomain, dim) idxs = indexes(a) length(idxs) < dim ? 1 : length(idxs[dim]) @@ -351,7 +355,7 @@ function group_indices(cumlength, idxs::AbstractRange) end _cumsum(x::AbstractArray) = length(x) == 0 ? Int[] : cumsum(x) -function lookup_parts(A::DArray, ps::AbstractArray, subdmns::DomainBlocks{N}, d::ArrayDomain{N}) where N +function lookup_parts(A::DArray, ps::AbstractArray, subdmns::DomainBlocks{N}, d::ArrayDomain{N}; slice::Bool=false) where N groups = map(group_indices, subdmns.cumlength, indexes(d)) sz = map(length, groups) pieces = Array{Any}(undef, sz) @@ -359,21 +363,31 @@ function lookup_parts(A::DArray, ps::AbstractArray, subdmns::DomainBlocks{N}, d: idx_and_dmn = map(getindex, groups, i.I) idx = map(x->x[1], idx_and_dmn) dmn = ArrayDomain(map(x->x[2], idx_and_dmn)) - pieces[i] = Dagger.@spawn getindex(ps[idx...], project(subdmns[idx...], dmn)) + if slice + pieces[i] = Dagger.@spawn getindex(ps[idx...], project(subdmns[idx...], dmn)) + else + pieces[i] = Dagger.@spawn view(ps[idx...], project(subdmns[idx...], dmn).indexes...) + end end out_cumlength = map(g->_cumsum(map(x->length(x[2]), g)), groups) out_dmn = DomainBlocks(ntuple(x->1,Val(N)), out_cumlength) return pieces, out_dmn end -function lookup_parts(A::DArray, ps::AbstractArray, subdmns::DomainBlocks{N}, d::ArrayDomain{S}) where {N,S} +function lookup_parts(A::DArray, ps::AbstractArray, subdmns::DomainBlocks{N}, d::ArrayDomain{S}; slice::Bool=false) where {N,S} if S != 1 throw(BoundsError(A, d.indexes)) end inds = CartesianIndices(A)[d.indexes...] new_d = ntuple(i->first(inds).I[i]:last(inds).I[i], N) - return lookup_parts(A, ps, subdmns, ArrayDomain(new_d)) + return lookup_parts(A, ps, subdmns, ArrayDomain(new_d); slice) end +function lookup_parts(A::DArray, ps::AbstractArray, subdmns::DomainBlocks{N}, d::CartesianIndices; slice::Bool=false) where N + return lookup_parts(A, ps, subdmns, ArrayDomain(d.indices); slice) +end + + + """ Base.fetch(c::DArray) diff --git a/src/array/indexing.jl b/src/array/indexing.jl index 82f44fbff..505e04f87 100644 --- a/src/array/indexing.jl +++ b/src/array/indexing.jl @@ -10,6 +10,16 @@ end GetIndex(input::ArrayOp, idx::Tuple) = GetIndex{eltype(input), ndims(input)}(input, idx) +function flatten(subdomains, subchunks, partitioning) + valdim = findfirst(j -> j != 1:1, subdomains[1].indexes) + flatc = [] + flats = Array{ArrayDomain{1, Tuple{UnitRange{Int64}}}}(undef, 0) + map(x -> push!(flats, ArrayDomain(x.indexes[valdim])), subdomains) + map(x -> push!(flatc, x), subchunks) + newb = Blocks(partitioning.blocksize[valdim]) + return flats, flatc, newb +end + function stage(ctx::Context, gidx::GetIndex) inp = stage(ctx, gidx.input) @@ -21,7 +31,14 @@ function stage(ctx::Context, gidx::GetIndex) end for i in 1:length(gidx.idx)] # Figure out output dimension - view(inp, ArrayDomain(idxs)) + d = ArrayDomain(idxs) + subchunks, subdomains = Dagger.lookup_parts(inp, chunks(inp), domainchunks(inp), d; slice = true) + d1 = alignfirst(d) + newb = inp.partitioning + if ndims(d1) != ndims(subdomains) + subdomains, subchunks, newb = flatten(subdomains, subchunks, inp.partitioning) + end + DArray(eltype(inp), d1, subdomains, subchunks, newb) end function size(x::GetIndex) From 400286c0e96809cf0627122465b6d52fba5dc903 Mon Sep 17 00:00:00 2001 From: fda-tome Date: Thu, 6 Jun 2024 08:30:24 -0300 Subject: [PATCH 04/34] DArray: Tile QR Implementation --- src/Dagger.jl | 1 + src/array/coreblas/coreblas_gemm.jl | 21 +++ src/array/coreblas/coreblas_geqrt.jl | 32 ++++ src/array/coreblas/coreblas_ormqr.jl | 45 +++++ src/array/coreblas/coreblas_tsmqr.jl | 49 ++++++ src/array/coreblas/coreblas_tsqrt.jl | 35 ++++ src/array/coreblas/coreblas_ttmqr.jl | 51 ++++++ src/array/coreblas/coreblas_ttqrt.jl | 32 ++++ src/array/qr.jl | 241 +++++++++++++++++++++++++++ test/array/linalg/qr.jl | 36 ++++ test/runtests.jl | 1 + 11 files changed, 544 insertions(+) create mode 100644 src/array/coreblas/coreblas_gemm.jl create mode 100644 src/array/coreblas/coreblas_geqrt.jl create mode 100644 src/array/coreblas/coreblas_ormqr.jl create mode 100644 src/array/coreblas/coreblas_tsmqr.jl create mode 100644 src/array/coreblas/coreblas_tsqrt.jl create mode 100644 src/array/coreblas/coreblas_ttmqr.jl create mode 100644 src/array/coreblas/coreblas_ttqrt.jl create mode 100644 src/array/qr.jl create mode 100644 test/array/linalg/qr.jl diff --git a/src/Dagger.jl b/src/Dagger.jl index b52a96e23..a447d6442 100644 --- a/src/Dagger.jl +++ b/src/Dagger.jl @@ -84,6 +84,7 @@ include("array/mul.jl") include("array/trapezoidal.jl") include("array/triangular.jl") include("array/cholesky.jl") +include("array/qr.jl") # Visualization include("visualization.jl") diff --git a/src/array/coreblas/coreblas_gemm.jl b/src/array/coreblas/coreblas_gemm.jl new file mode 100644 index 000000000..ad77be40d --- /dev/null +++ b/src/array/coreblas/coreblas_gemm.jl @@ -0,0 +1,21 @@ +using libblastrampoline_jll +using LinearAlgebra +using libcoreblas_jll + +for (gemm, T) in + ((:coreblas_dgemm, Float64), + (:coreblas_sgemm, Float32), + (:coreblas_cgemm, ComplexF32), + (:coreblas_zgemm, ComplexF64)) + @eval begin + function coreblas_gemm!(transa::Int64, transb::Int64, + alpha::$T, A::AbstractMatrix{$T}, B::AbstractMatrix{$T}, beta::$T, C::AbstractMatrix{$T}) + m, k = size(A) + k, n = size(B) + ccall(($gemm, "libcoreblas.so"), Cvoid, + (Int64, Int64, Int64, Int64, Int64, $T, Ptr{$T}, Int64, Ptr{$T}, Int64, + $T, Ptr{$T}, Int64), + transa, transb, m, n, k, alpha, A, m, B, k, beta, C, m) + end + end +end diff --git a/src/array/coreblas/coreblas_geqrt.jl b/src/array/coreblas/coreblas_geqrt.jl new file mode 100644 index 000000000..99f86d442 --- /dev/null +++ b/src/array/coreblas/coreblas_geqrt.jl @@ -0,0 +1,32 @@ +for (geqrt, T) in + ((:coreblas_dgeqrt, Float64), + (:coreblas_sgeqrt, Float32), + (:coreblas_cgeqrt, ComplexF32), + (:coreblas_zgeqrt, ComplexF64)) + @eval begin + function coreblas_geqrt!(A::AbstractMatrix{$T}, + Tau::AbstractMatrix{$T}) + require_one_based_indexing(A, Tau) + chkstride1(A) + m, n = size(A) + ib, nb = size(Tau) + lda = max(1, stride(A,2)) + ldt = max(1, stride(Tau,2)) + work = Vector{$T}(undef, (ib)*n) + ttau = Vector{$T}(undef, n) + + err = ccall(($(QuoteNode(geqrt)), libcoreblas), Int64, + (Int64, Int64, Int64, + Ptr{$T}, Int64, Ptr{$T}, Int64, + Ptr{$T}, Ptr{$T}), + m, n, ib, + A, lda, + Tau, ldt, + ttau, work) + if err != 0 + throw(ArgumentError("coreblas_geqrt failed. Error number: $err")) + end + end + end +end + diff --git a/src/array/coreblas/coreblas_ormqr.jl b/src/array/coreblas/coreblas_ormqr.jl new file mode 100644 index 000000000..5bcb7b1bb --- /dev/null +++ b/src/array/coreblas/coreblas_ormqr.jl @@ -0,0 +1,45 @@ +for (geormqr, T) in + ((:coreblas_dormqr, Float64), + (:coreblas_sormqr, Float32), + (:coreblas_zunmqr, ComplexF64), + (:coreblas_cunmqr, ComplexF32)) + @eval begin + function coreblas_ormqr!(side::Char, trans::Char, A::AbstractMatrix{$T}, + Tau::AbstractMatrix{$T}, C::AbstractMatrix{$T}) + + m, n = size(C) + ib, nb = size(Tau) + k = nb + if $T <: Complex + transnum = trans == 'N' ? 111 : 113 + else + transnum = trans == 'N' ? 111 : 112 + end + sidenum = side == 'L' ? 141 : 142 + + lda = max(1, stride(A,2)) + ldt = max(1, stride(Tau,2)) + ldc = max(1, stride(C,2)) + ldwork = side == 'L' ? n : m + work = Vector{$T}(undef, ib*nb) + + + err = ccall(($(QuoteNode(geormqr)), libcoreblas), Int64, + (Int64, Int64, Int64, Int64, + Int64, Int64, + Ptr{$T}, Int64, Ptr{$T}, Int64, + Ptr{$T}, Int64, Ptr{$T}, Int64), + sidenum, transnum, + m, n, + k, ib, + A, lda, + Tau, ldt, + C, ldc, + work, ldwork) + if err != 0 + throw(ArgumentError("coreblas_ormqr failed. Error number: $err")) + end + end + end +end + diff --git a/src/array/coreblas/coreblas_tsmqr.jl b/src/array/coreblas/coreblas_tsmqr.jl new file mode 100644 index 000000000..1f647f7c1 --- /dev/null +++ b/src/array/coreblas/coreblas_tsmqr.jl @@ -0,0 +1,49 @@ +for (getsmqr, T) in + ((:coreblas_dtsmqr, Float64), + (:coreblas_ctsmqr, ComplexF32), + (:coreblas_ztsmqr, ComplexF64), + (:coreblas_stsmqr, Float32)) + @eval begin + function coreblas_tsmqr!(side::Char, trans::Char, A1::AbstractMatrix{$T}, + A2::AbstractMatrix{$T}, V::AbstractMatrix{$T}, Tau::AbstractMatrix{$T}) + m1, n1 = size(A1) + m2, n2 = size(A2) + ib, nb = size(Tau) + k = nb + + if $T <: Complex + transnum = trans == 'N' ? 111 : 113 + else + transnum = trans == 'N' ? 111 : 112 + end + + sidenum = side == 'L' ? 141 : 142 + + lda1 = max(1, stride(A1,2)) + lda2 = max(1, stride(A2,2)) + ldv = max(1, stride(V,2)) + ldt = max(1, stride(Tau,2)) + ldwork = side == 'L' ? ib : m1 + work = Vector{$T}(undef, ib*nb) + + + err = ccall(($(QuoteNode(getsmqr)), libcoreblas), Int64, + (Int64, Int64, Int64, Int64, + Int64, Int64, Int64, Int64, + Ptr{$T}, Int64, Ptr{$T}, Int64, + Ptr{$T}, Int64, Ptr{$T}, Int64, Ptr{$T}, Int64), + sidenum, transnum, + m1, n1, + m2, n2, + k, ib, + A1, lda1, + A2, lda2, + V, ldv, + Tau, ldt, + work, ldwork) + if err != 0 + throw(ArgumentError("coreblas_tsmqr failed. Error number: $err")) + end + end + end +end diff --git a/src/array/coreblas/coreblas_tsqrt.jl b/src/array/coreblas/coreblas_tsqrt.jl new file mode 100644 index 000000000..e644465a3 --- /dev/null +++ b/src/array/coreblas/coreblas_tsqrt.jl @@ -0,0 +1,35 @@ + +for (getsqrt,T) in + ((:coreblas_dtsqrt, Float64), + (:coreblas_stsqrt, Float32), + (:coreblas_ctsqrt, ComplexF32), + (:coreblas_ztsqrt, ComplexF64)) + @eval begin + function coreblas_tsqrt!(A1::AbstractMatrix{$T}, A2::AbstractMatrix{$T}, + Tau::AbstractMatrix{$T}) + m = size(A2)[1] + n = size(A1)[2] + ib, nb = size(Tau) + lda1 = max(1, stride(A1,2)) + lda2 = max(1, stride(A2,2)) + ldt = max(1, stride(Tau,2)) + work = Vector{$T}(undef, (ib)*n) + ttau = Vector{$T}(undef, n) + + err = ccall(($(QuoteNode(getsqrt)), libcoreblas), Int64, + (Int64, Int64, Int64, + Ptr{$T}, Int64, Ptr{$T}, Int64, + Ptr{$T}, Int64, Ptr{$T}, Ptr{$T}), + m, n, ib, + A1, lda1, + A2, lda2, + Tau, ldt, + ttau, work) + if err != 0 + throw(ArgumentError("coreblas_tsqrt failed. Error number: $err")) + end + end + end +end + + diff --git a/src/array/coreblas/coreblas_ttmqr.jl b/src/array/coreblas/coreblas_ttmqr.jl new file mode 100644 index 000000000..32c0f4eb3 --- /dev/null +++ b/src/array/coreblas/coreblas_ttmqr.jl @@ -0,0 +1,51 @@ +using libcoreblas_jll +for (gettmqr, T) in + ((:coreblas_dttmqr, Float64), + (:coreblas_sttmqr, Float32), + (:coreblas_cttmqr, ComplexF32), + (:coreblas_zttmqr, ComplexF64)) + @eval begin + function coreblas_ttmqr!(side::Char, trans::Char, A1::AbstractMatrix{$T}, + A2::AbstractMatrix{$T}, V::AbstractMatrix{$T}, Tau::AbstractMatrix{$T}) + m1, n1 = size(A1) + m2, n2 = size(A2) + ib, nb = size(Tau) + k=nb + if $T <: Complex + transnum = trans == 'N' ? 111 : 113 + else + transnum = trans == 'N' ? 111 : 112 + end + + sidenum = side == 'L' ? 141 : 142 + + ldv = max(1, stride(V,2)) + ldt = max(1, stride(Tau,2)) + lda1 = max(1, stride(A1,2)) + lda2 = max(1, stride(A2,2)) + ldwork = side == 'L' ? max(1,ib) : max(1,m1) + workdim = side == 'L' ? n1 : ib + work = Vector{$T}(undef, ldwork*workdim) + + err = ccall(($(QuoteNode(gettmqr)), libcoreblas), Int64, + (Int64, Int64, Int64, Int64, + Int64, Int64, Int64, Int64, + Ptr{$T}, Int64, Ptr{$T}, Int64, + Ptr{$T}, Int64, Ptr{$T}, Int64, + Ptr{$T}, Int64), + sidenum, transnum, + m1, n1, + m2, n2, + k, ib, + A1, lda1, + A2, lda2, + V, ldv, + Tau, ldt, + work, ldwork) + if err != 0 + throw(ArgumentError("coreblas_ttmqr, failed. Error number: $err")) + end + end + end +end + diff --git a/src/array/coreblas/coreblas_ttqrt.jl b/src/array/coreblas/coreblas_ttqrt.jl new file mode 100644 index 000000000..f90373517 --- /dev/null +++ b/src/array/coreblas/coreblas_ttqrt.jl @@ -0,0 +1,32 @@ + +for (gettqrt, T) in + ((:coreblas_dttqrt, Float64), + (:coreblas_sttqrt, Float32), + (:coreblas_cttqrt, ComplexF32), + (:coreblas_zttqrt, ComplexF64)) + @eval begin + function coreblas_ttqrt!(A1::AbstractMatrix{$T}, + A2::AbstractMatrix{$T}, triT::AbstractMatrix{$T}) + m1, n1 = size(A1) + m2, n2 = size(A2) + ib, nb = size(triT) + + lwork = nb + ib*nb + tau = Vector{$T}(undef, nb) + work = Vector{$T}(undef, (ib+1)*nb) + lda1 = max(1, stride(A1, 2)) + lda2 = max(1, stride(A2, 2)) + ldt = max(1, stride(triT, 2)) + + + err = ccall(($(QuoteNode(gettqrt)), libcoreblas), Int64, + (Int64, Int64, Int64, Ptr{$T}, Int64, Ptr{$T}, Int64, + Ptr{$T}, Int64, Ptr{$T}, Ptr{$T}), + m1, n1, ib, A1, lda1, A2, lda2, triT, ldt, tau, work) + + if err != 0 + throw(ArgumentError("coreblas_ttqrt failed. Error number: $err")) + end + end + end +end diff --git a/src/array/qr.jl b/src/array/qr.jl new file mode 100644 index 000000000..cc825157e --- /dev/null +++ b/src/array/qr.jl @@ -0,0 +1,241 @@ +export geqrf!, porgqr!, pormqr!, cageqrf! +import LinearAlgebra: QRCompactWY, AdjointQ, BlasFloat, QRCompactWYQ, AbstractQ, StridedVecOrMat, I +import Base.:* +include("coreblas/coreblas_ormqr.jl") +include("coreblas/coreblas_ttqrt.jl") +include("coreblas/coreblas_ttmqr.jl") +include("coreblas/coreblas_geqrt.jl") +include("coreblas/coreblas_tsqrt.jl") +include("coreblas/coreblas_tsmqr.jl") + +(*)(Q::QRCompactWYQ{T, M}, b::Number) where {T<:Number, M<:DMatrix{T}} = DMatrix(Q) * b +(*)(b::Number, Q::QRCompactWYQ{T, M}) where {T<:Number, M<:DMatrix{T}} = DMatrix(Q) * b + +(*)(Q::AdjointQ{T, QRCompactWYQ{T, M, C}}, b::Number) where {T<:Number, M<:DMatrix{T}, C<:LowerTrapezoidal{T, M}} = DMatrix(Q) * b +(*)(b::Number, Q::AdjointQ{T, QRCompactWYQ{T, M, C}}) where {T<:Number, M<:DMatrix{T}, C<:LowerTrapezoidal{T, M}} = DMatrix(Q) * b + +LinearAlgebra.lmul!(B::QRCompactWYQ{T, M}, A::M) where {T, M<:DMatrix{T}} = pormqr!('L', 'N', B.factors, B.T, A) +function LinearAlgebra.lmul!(B::AdjointQ{T, <:QRCompactWYQ{T, M}}, A::M) where {T, M<:Dagger.DMatrix{T}} + trans = T <: Complex ? 'C' : 'T' + pormqr!('L', trans, B.Q.factors, B.Q.T, A) +end + +LinearAlgebra.rmul!(A::Dagger.DMatrix{T}, B::QRCompactWYQ{T, M}) where {T, M<:Dagger.DMatrix{T}} = pormqr!('R', 'N', B.factors, B.T, A) +function LinearAlgebra.rmul!(A::Dagger.DArray{T,2}, B::AdjointQ{T, <:QRCompactWYQ{T, M}}) where {T, M<:Dagger.DMatrix{T}} + trans = T <: Complex ? 'C' : 'T' + pormqr!('R', trans, B.Q.factors, B.Q.T, A) +end + +function Dagger.DMatrix(Q::QRCompactWYQ{T, <:Dagger.DArray{T, 2}}) where {T} + DQ = distribute(Matrix(I*one(T), size(Q.factors)[1], size(Q.factors)[1]), Q.factors.partitioning) + porgqr!('N', Q.factors, Q.T, DQ) + return DQ +end + +function Dagger.DMatrix(AQ::AdjointQ{T, <:QRCompactWYQ{T, <:Dagger.DArray{T, 2}}}) where {T} + DQ = distribute(Matrix(I*one(T), size(AQ.Q.factors)[1], size(AQ.Q.factors)[1]), AQ.Q.factors.partitioning) + trans = T <: Complex ? 'C' : 'T' + porgqr!(trans, AQ.Q.factors, AQ.Q.T, DQ) + return DQ +end + +Base.collect(Q::QRCompactWYQ{T, <:Dagger.DArray{T, 2}}) where {T} = collect(Dagger.DMatrix(Q)) +Base.collect(AQ::AdjointQ{T, <:QRCompactWYQ{T, <:Dagger.DArray{T, 2}}}) where {T} = collect(Dagger.DMatrix(AQ)) + +function pormqr!(side::Char, trans::Char, A::Dagger.DArray{T, 2}, Tm::LowerTrapezoidal{T, <:Dagger.DMatrix{T}}, C::Dagger.DArray{T, 2}) where {T<:Number} + m, n = size(C) + Ac = A.chunks + Tc = Tm.data.chunks + Cc = C.chunks + + Amt, Ant = size(Ac) + Tmt, Tnt = size(Tc) + Cmt, Cnt = size(Cc) + minMT = min(Amt, Ant) + + Dagger.spawn_datadeps() do + if side == 'L' + if (trans == 'T' || trans == 'C') + for k in 1:minMT + for n in 1:Cnt + Dagger.@spawn coreblas_ormqr!(side, trans, In(Ac[k, k]), In(Tc[k, k]), InOut(Cc[k,n])) + end + for m in k+1:Cmt, n in 1:Cnt + Dagger.@spawn coreblas_tsmqr!(side, trans, InOut(Cc[k, n]), InOut(Cc[m, n]), In(Ac[m, k]), In(Tc[m, k])) + end + end + end + if trans == 'N' + for k in minMT:-1:1 + for m in Cmt:-1:k+1, n in 1:Cnt + Dagger.@spawn coreblas_tsmqr!(side, trans, InOut(Cc[k, n]), InOut(Cc[m, n]), In(Ac[m, k]), In(Tc[m, k])) + end + for n in 1:Cnt + Dagger.@spawn coreblas_ormqr!(side, trans, In(Ac[k, k]), In(Tc[k, k]), InOut(Cc[k, n])) + end + end + end + else + if side == 'R' + if trans == 'T' || trans == 'C' + for k in minMT:-1:1 + for n in Cmt:-1:k+1, m in 1:Cmt + Dagger.@spawn coreblas_tsmqr!(side, trans, InOut(Cc[m, k]), InOut(Cc[m, n]), In(Ac[n, k]), In(Tc[n, k])) + end + for m in 1:Cmt + Dagger.@spawn coreblas_ormqr!(side, trans, In(Ac[k, k]), In(Tc[k, k]), InOut(Cc[m, k])) + end + end + end + if trans == 'N' + for k in 1:minMT + for m in 1:Cmt + Dagger.@spawn coreblas_ormqr!(side, trans, In(Ac[k, k]), In(Tc[k, k]), InOut(Cc[m, k])) + end + for n in k+1:Cmt, m in 1:Cmt + Dagger.@spawn coreblas_tsmqr!(side, trans, InOut(Cc[m, k]), InOut(Cc[m, n]), In(Ac[n, k]), In(Tc[n, k])) + end + end + end + end + end + end + return C +end + +function cageqrf!(A::Dagger.DArray{T, 2}, Tm::LowerTrapezoidal{T, <:Dagger.DMatrix{T}}; static::Bool=true, traversal::Symbol=:inorder, p::Int64=1) where {T<: Number} + if p == 1 + return geqrf!(A, Tm; static, traversal) + end + Ac = A.chunks + mt, nt = size(Ac) + @assert mt % p == 0 "Number of tiles must be divisible by the number of domains" + mtd = Int64(mt/p) + Tc = Tm.data.chunks + proot = 1 + nxtmt = mtd + trans = T <: Complex ? 'C' : 'T' + Dagger.spawn_datadeps(;static, traversal) do + for k in 1:min(mt, nt) + if k > nxtmt + proot += 1 + nxtmt += mtd + end + for pt in proot:p + ibeg = 1 + (pt-1) * mtd + if pt == proot + ibeg = k + end + Dagger.@spawn coreblas_geqrt!(InOut(Ac[ibeg, k]), Out(Tc[ibeg,k])) + for n in k+1:nt + Dagger.@spawn coreblas_ormqr!('L', trans, In(Ac[ibeg, k]), In(Tc[ibeg,k]), InOut(Ac[ibeg, n])) + end + for m in ibeg+1:(pt * mtd) + Dagger.@spawn coreblas_tsqrt!(InOut(Ac[ibeg, k]), InOut(Ac[m, k]), Out(Tc[m,k])) + for n in k+1:nt + Dagger.@spawn coreblas_tsmqr!('L', trans, InOut(Ac[ibeg, n]), InOut(Ac[m, n]), In(Ac[m, k]), In(Tc[m,k])) + end + end + end + for m in 1:ceil(Int64, log2(p-proot+1)) + p1 = proot + p2 = p1 + 2^(m-1) + while p2 ≤ p + i1 = 1 + (p1-1) * mtd + i2 = 1 + (p2-1) * mtd + if p1 == proot + i1 = k + end + Dagger.@spawn coreblas_ttqrt!(InOut(Ac[i1, k]), InOut(Ac[i2, k]), Out(Tc[i2, k])) + for n in k+1:nt + Dagger.@spawn coreblas_ttmqr!('L', trans, InOut(Ac[i1, n]), InOut(Ac[i2, n]), In(Ac[i2, k]), In(Tc[i2, k])) + end + p1 += 2^m + p2 += 2^m + end + end + end + end +end + +function geqrf!(A::Dagger.DArray{T, 2}, Tm::LowerTrapezoidal{T, <:Dagger.DMatrix{T}}; static::Bool=true, traversal::Symbol=:inorder) where {T<: Number} + Ac = A.chunks + mt, nt = size(Ac) + Tc = Tm.data.chunks + trans = T <: Complex ? 'C' : 'T' + + Ccopy = Dagger.DArray{T}(undef, A.partitioning, A.partitioning.blocksize[1], min(mt, nt) * A.partitioning.blocksize[2]) + Cc = Ccopy.chunks + Dagger.spawn_datadeps(;static, traversal) do + for k in 1:min(mt, nt) + Dagger.@spawn coreblas_geqrt!(InOut(Ac[k, k]), Out(Tc[k,k])) + # FIXME: This is a hack to avoid aliasing + Dagger.@spawn copyto!(InOut(Cc[1,k]), In(Ac[k, k])) + for n in k+1:nt + #FIXME: Change Cc[1,k] to upper triangular of Ac[k,k] + Dagger.@spawn coreblas_ormqr!('L', trans, In(Cc[1, k]), In(Tc[k,k]), InOut(Ac[k, n])) + end + for m in k+1:mt + Dagger.@spawn coreblas_tsqrt!(InOut(Ac[k, k]), InOut(Ac[m, k]), Out(Tc[m,k])) + for n in k+1:nt + Dagger.@spawn coreblas_tsmqr!('L', trans, InOut(Ac[k, n]), InOut(Ac[m, n]), In(Ac[m, k]), In(Tc[m,k])) + end + end + end + end +end + +function porgqr!(trans::Char, A::Dagger.DArray{T, 2}, Tm::LowerTrapezoidal{T, <:Dagger.DMatrix{T}}, Q::Dagger.DArray{T, 2}; static::Bool=true, traversal::Symbol=:inorder) where {T<:Number} + Ac = A.chunks + Tc = Tm.data.chunks + Qc = Q.chunks + mt, nt = size(Ac) + qmt, qnt = size(Qc) + + Dagger.spawn_datadeps(;static, traversal) do + if trans == 'N' + for k in min(mt, nt):-1:1 + for m in qmt:-1:k + 1, n in k:qnt + Dagger.@spawn coreblas_tsmqr!('L', trans, InOut(Qc[k, n]), InOut(Qc[m, n]), In(Ac[m, k]), In(Tc[m, k])) + end + for n in k:qnt + Dagger.@spawn coreblas_ormqr!('L', trans, In(Ac[k, k]), + In(Tc[k, k]), InOut(Qc[k, n])) + end + end + else + for k in 1:min(mt, nt) + for n in 1:k + Dagger.@spawn coreblas_ormqr!('L', trans, In(Ac[k, k]), + In(Tc[k, k]), InOut(Qc[k, n])) + end + for m in k+1:qmt, n in 1:qnt + Dagger.@spawn coreblas_tsmqr!('L', trans, InOut(Qc[k, n]), InOut(Qc[m, n]), In(Ac[m, k]), In(Tc[m, k])) + end + end + end + end +end + +function meas_ws(A::Dagger.DArray{T, 2}, ib::Int64) where {T<: Number} + mb, nb = A.partitioning.blocksize + m, n = size(A) + MT = (mod(m,nb)==0) ? floor(Int64, (m / mb)) : floor(Int64, (m / mb) + 1) + NT = (mod(n,nb)==0) ? floor(Int64,(n / nb)) : floor(Int64, (n / nb) + 1) * 2 + lm = ib * MT; + ln = nb * NT; + lm, ln +end + +function LinearAlgebra.qr!(A::Dagger.DArray{T, 2}; ib::Int64=1, p::Int64=1) where {T<:Number} + lm, ln = meas_ws(A, ib) + Ac = A.chunks + nb = A.partitioning.blocksize[2] + mt, nt = size(Ac) + st = nb * (nt - 1) + Tm = LowerTrapezoidal(zeros, Blocks(ib, nb), T, st, lm, ln) + geqrf!(A, Tm) + return QRCompactWY(A, Tm); +end + + diff --git a/test/array/linalg/qr.jl b/test/array/linalg/qr.jl new file mode 100644 index 000000000..2a347d164 --- /dev/null +++ b/test/array/linalg/qr.jl @@ -0,0 +1,36 @@ + @testset "Tile QR: $T" for T in (Float32, Float64, ComplexF32, ComplexF64) + ## Square matrices + A = rand(T, 128, 128) + Q, R = qr(A) + DA = distribute(A, Blocks(32,32)) + DQ, DR = qr!(DA) + @test abs.(DQ) ≈ abs.(Q) + @test abs.(DR) ≈ abs.(R) + @test DQ * DR ≈ A + @test DQ' * DQ ≈ I + @test I * abs.(DQ) ≈ abs.(Q) + @test I * abs.(DQ') ≈ abs.(Q') + + ## Rectangular matrices (block and element wise) + # Tall Element and Block + A = rand(T, 128, 64) + Q, R = qr(A) + DA = distribute(A, Blocks(32,32)) + DQ, DR = qr!(DA) + @test abs.(DR) ≈ abs.(R) + @test DQ * DR ≈ A + @test DQ' * DQ ≈ I + @test I * DQ ≈ collect(DQ) + @test I * DQ' ≈ collect(DQ') + + # Wide Element and Block + A = rand(T, 64, 128) + Q, R = qr(A) + DA = distribute(A, Blocks(16,16)) + DQ, DR = qr!(DA) + @test abs.(DR) ≈ abs.(R) + @test DQ * DR ≈ A + @test DQ' * DQ ≈ I + @test I * DQ ≈ collect(DQ) + @test I * DQ' ≈ collect(DQ') +end diff --git a/test/runtests.jl b/test/runtests.jl index ba646b159..3742c3315 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -18,6 +18,7 @@ tests = [ ("Array - MapReduce", "array/mapreduce.jl"), ("Array - LinearAlgebra - Matmul", "array/linalg/matmul.jl"), ("Array - LinearAlgebra - Cholesky", "array/linalg/cholesky.jl"), + ("Array - LinearAlgebra - QR", "array/linalg/qr.jl"), ("Caching", "cache.jl"), ("Disk Caching", "diskcaching.jl"), ("File IO", "file-io.jl"), From 99c7ba6414f3610f80dbeafb995ede715c745742 Mon Sep 17 00:00:00 2001 From: fda-tome Date: Thu, 6 Jun 2024 09:00:24 -0300 Subject: [PATCH 05/34] Rebasing commit, solving conflicts --- src/Dagger.jl | 7 +- src/array/trapezoidal.jl | 142 +++++++++++++++++++++++++++++++++++++++ src/array/triangular.jl | 83 +++++++++++++++++++++++ 3 files changed, 231 insertions(+), 1 deletion(-) create mode 100644 src/array/trapezoidal.jl create mode 100644 src/array/triangular.jl diff --git a/src/Dagger.jl b/src/Dagger.jl index 97f7dd44d..7211ad277 100644 --- a/src/Dagger.jl +++ b/src/Dagger.jl @@ -7,13 +7,14 @@ import SparseArrays: sprand, SparseMatrixCSC import MemPool import MemPool: DRef, FileRef, poolget, poolset -import Base: collect, reduce +import Base: collect, reduce, require_one_based_indexing import Distributed import Distributed: Future, RemoteChannel, myid, workers, nworkers, procs, remotecall, remotecall_wait, remotecall_fetch import LinearAlgebra import LinearAlgebra: Adjoint, BLAS, Diagonal, Bidiagonal, Tridiagonal, LAPACK, LowerTriangular, PosDefException, Transpose, UpperTriangular, UnitLowerTriangular, UnitUpperTriangular, diagind, ishermitian, issymmetric + import UUIDs: UUID, uuid4 if !isdefined(Base, :ScopedValues) @@ -77,8 +78,12 @@ include("array/setindex.jl") include("array/matrix.jl") include("array/sparse_partition.jl") include("array/sort.jl") + +# Linear algebra include("array/linalg.jl") include("array/mul.jl") +include("array/trapezoidal.jl") +include("array/triangular.jl") include("array/cholesky.jl") # Visualization diff --git a/src/array/trapezoidal.jl b/src/array/trapezoidal.jl new file mode 100644 index 000000000..3497fa26a --- /dev/null +++ b/src/array/trapezoidal.jl @@ -0,0 +1,142 @@ +export LowerTrapezoidal, UnitLowerTrapezoidal, UpperTrapezoidal, UnitUpperTrapezoidal, trau!, trau, tral!, tral +import LinearAlgebra: triu!, tril!, triu, tril +abstract type AbstractTrapezoidal{T} <: AbstractMatrix{T} end + +# First loop through all methods that don't need special care for upper/lower and unit diagonal +for t in (:LowerTrapezoidal, :UnitLowerTrapezoidal, :UpperTrapezoidal, :UnitUpperTrapezoidal) + @eval begin + struct $t{T,S<:AbstractMatrix{T}} <: AbstractTrapezoidal{T} + data::S + + function $t{T,S}(data) where {T,S<:AbstractMatrix{T}} + Base.require_one_based_indexing(data) + new{T,S}(data) + end + end + $t(A::$t) = A + $t{T}(A::$t{T}) where {T} = A + $t(A::AbstractMatrix) = $t{eltype(A), typeof(A)}(A) + $t{T}(A::AbstractMatrix) where {T} = $t(convert(AbstractMatrix{T}, A)) + $t{T}(A::$t) where {T} = $t(convert(AbstractMatrix{T}, A.data)) + + AbstractMatrix{T}(A::$t) where {T} = $t{T}(A) + AbstractMatrix{T}(A::$t{T}) where {T} = copy(A) + + Base.size(A::$t) = size(A.data) + Base.axes(A::$t) = axes(A.data) + + Base.similar(A::$t, ::Type{T}) where {T} = $t(similar(parent(A), T)) + Base.similar(A::$t, ::Type{T}, dims::Dims{N}) where {T,N} = similar(parent(A), T, dims) + Base.parent(A::$t) = A.data + + Base.copy(A::$t) = $t(copy(A.data)) + + Base.real(A::$t{<:Real}) = A + Base.real(A::$t{<:Complex}) = (B = real(A.data); $t(B)) + end +end + +Base.getindex(A::UnitLowerTrapezoidal{T}, i::Integer, j::Integer) where {T} = + i > j ? ifelse(A.data[i,j] == nothing, zero(eltype(A.data)), A.data[i,j]) : ifelse(i == j, oneunit(T), zero(T)) +Base.getindex(A::LowerTrapezoidal, i::Integer, j::Integer) = +i >= j ? ifelse(A.data[i,j] == nothing, zero(eltype(A.data)), A.data[i,j]) : zero(eltype(A.data)) +Base.getindex(A::UnitUpperTrapezoidal{T}, i::Integer, j::Integer) where {T} = + i < j ? ifelse(A.data[i,j] == nothing, zero(eltype(A.data)), A.data[i,j]) : ifelse(i == j, oneunit(T), zero(T)) +Base.getindex(A::UpperTrapezoidal, i::Integer, j::Integer) = +i <= j ? ifelse(A.data[i,j] == nothing, zero(eltype(A.data)), A.data[i,j]) : zero(eltype(A.data)) + +function _DiagBuild(blockdom::Tuple, alloc::AbstractMatrix{T}, diag::Vector{Tuple{Int,Int}}, transform::Function) where {T} + diagind = findfirst(x-> x[1] in blockdom[1] && x[2] in blockdom[2], diag) + blockind = (diag[diagind][1] - first(blockdom[1]) + 1, diag[diagind][2] - first(blockdom[2]) + 1) + return Dagger.@spawn transform(alloc, blockind[2] - blockind[1]) +end + +function _GenericTrapezoidal(f::Union{Function, UndefInitializer}, p::Blocks{2}, eltype::Type, wrap, k::Integer, dims) + d = ArrayDomain(map(x->1:x, dims)) + dc = partition(p, d) + if f isa UndefInitializer + f = (eltype, x...) -> Array{eltype}(undef, x...) + end + m, n = dims + if k < 0 + diag = [(i-k, i) for i in 1:min(m, n)] + else + diag = [(i, i+k) for i in 1:min(m, n)] + end + thunks = [] + alloc(sz) = f(eltype, sz) + transform = (wrap == LowerTrapezoidal) ? tril! : triu! + compar = (wrap == LowerTrapezoidal) ? (>) : (<) + for c in dc + sz = size(c) + if any(x -> x[1] in c.indexes[1] && x[2] in c.indexes[2], diag) + push!(thunks, _DiagBuild(c.indexes, alloc(sz), diag, transform)) + else + mt, nt = k<0 ? (first(c.indexes[1]), first(c.indexes[2])-k) : (first(c.indexes[1])+k, first(c.indexes[2])) + if compar(mt, nt) + push!(thunks, Dagger.@spawn alloc(sz)) + else + push!(thunks, Dagger.@spawn zeros(eltype, sz)) + end + end + end + thunks = reshape(thunks, size(dc)) + return wrap(Dagger.DArray(eltype, d, dc, thunks, p)) +end + +LowerTrapezoidal(f::Union{Function, UndefInitializer}, p::Blocks{2}, dims::Integer...) = _GenericTrapezoidal(f, p, eltype, LowerTrapezoidal, 0, dims) +LowerTrapezoidal(f::Union{Function, UndefInitializer}, p::Blocks{2}, dims::Tuple) = _GenericTrapezoidal(f, p, eltype, LowerTrapezoidal, 0, dims) +LowerTrapezoidal(f::Union{Function, UndefInitializer}, p::Blocks{2}, eltype::Type, dims::Integer...) = _GenericTrapezoidal(f, p, eltype, LowerTrapezoidal, 0, dims) +LowerTrapezoidal(f::Union{Function, UndefInitializer}, p::Blocks{2}, eltype::Type, dims::Tuple) = _GenericTrapezoidal(f, p, eltype, LowerTrapezoidal, 0, dims) +LowerTrapezoidal(f::Union{Function, UndefInitializer}, p::Blocks{2}, k::Integer, dims::Integer...) = _GenericTrapezoidal(f, p, eltype, LowerTrapezoidal, k, dims) +LowerTrapezoidal(f::Union{Function, UndefInitializer}, p::Blocks{2}, k::Integer, dims::Tuple) = _GenericTrapezoidal(f, p, eltype, LowerTrapezoidal, k, dims) +LowerTrapezoidal(f::Union{Function, UndefInitializer}, p::Blocks{2}, eltype::Type, k::Integer, dims::Integer...) = _GenericTrapezoidal(f, p, eltype, LowerTrapezoidal, k, dims) +LowerTrapezoidal(f::Union{Function, UndefInitializer}, p::Blocks{2}, eltype::Type, k::Integer, dims::Tuple) = _GenericTrapezoidal(f, p, eltype, LowerTrapezoidal, k, dims) + +UpperTrapezoidal(f::Union{Function, UndefInitializer}, p::Blocks{2}, dims::Integer...) = _GenericTrapezoidal(f, p, eltype, UpperTrapezoidal, 0, dims) +UpperTrapezoidal(f::Union{Function, UndefInitializer}, p::Blocks{2}, dims::Tuple) = _GenericTrapezoidal(f, p, eltype, UpperTrapezoidal, 0, dims) +UpperTrapezoidal(f::Union{Function, UndefInitializer}, p::Blocks{2}, eltype::Type, dims::Integer...) = _GenericTrapezoidal(f, p, eltype, UpperTrapezoidal, 0, dims) +UpperTrapezoidal(f::Union{Function, UndefInitializer}, p::Blocks{2}, eltype::Type, dims::Tuple) = _GenericTrapezoidal(f, p, eltype, UpperTrapezoidal, 0, dims) +UpperTrapezoidal(f::Union{Function, UndefInitializer}, p::Blocks{2}, k::Integer, dims::Integer...) = _GenericTrapezoidal(f, p, eltype, UpperTrapezoidal, k, dims) +UpperTrapezoidal(f::Union{Function, UndefInitializer}, p::Blocks{2}, k::Integer, dims::Tuple) = _GenericTrapezoidal(f, p, eltype, UpperTrapezoidal, k, dims) +UpperTrapezoidal(f::Union{Function, UndefInitializer}, p::Blocks{2}, eltype::Type, k::Integer, dims::Integer...) = _GenericTrapezoidal(f, p, eltype, UpperTrapezoidal, k, dims) +UpperTrapezoidal(f::Union{Function, UndefInitializer}, p::Blocks{2}, eltype::Type, k::Integer, dims::Tuple) = _GenericTrapezoidal(f, p, eltype, UpperTrapezoidal, k, dims) + +function _GenericTra!(A::Dagger.DArray{T, 2}, wrap::Function, k::Integer) where {T} + d = A.domain + dc = A.subdomains + Ac = A.chunks + m, n = size(A) + if k < 0 + diag = [(i-k, i) for i in 1:min(m, n)] + else + diag = [(i, i+k) for i in 1:min(m, n)] + end + compar = (wrap == tril!) ? (≤) : (≥) + for ind in CartesianIndices(dc) + sz = size(dc[ind]) + if any(x -> x[1] in dc[ind].indexes[1] && x[2] in dc[ind].indexes[2], diag) + Ac[ind] = _DiagBuild(dc[ind].indexes, fetch(Ac[ind]), diag, wrap) + else + mt, nt = k<0 ? (first(dc[ind].indexes[1]), first(dc[ind].indexes[2])-k) : (first(dc[ind].indexes[1])+k, first(dc[ind].indexes[2])) + if compar(mt, nt) + Ac[ind] = Dagger.@spawn zeros(T, sz) + end + end + end + return A +end + + + +trau!(A::Dagger.DArray{T,2}, k::Integer) where {T} = _GenericTra!(A, triu!, k) +trau!(A::Dagger.DArray{T,2}) where {T} = _GenericTra!(A, triu!, 0) +trau(A::Dagger.DArray{T,2}) where {T} = trau!(copy(A)) +trau(A::Dagger.DArray{T,2}, k::Integer) where {T} = trau!(copy(A), k) + +tral!(A::Dagger.DArray{T,2}, k::Integer) where {T} = _GenericTra!(A, tril!, k) +tral!(A::Dagger.DArray{T,2}) where {T} = _GenericTra!(A, tril!, 0) +tral(A::Dagger.DArray{T,2}) where {T} = tral!(copy(A)) +tral(A::Dagger.DArray{T,2}, k::Integer) where {T} = tral!(copy(A), k) + +#TODO: map, reduce, sum, mean, prod, reducedim, collect, distribute diff --git a/src/array/triangular.jl b/src/array/triangular.jl new file mode 100644 index 000000000..47613df2d --- /dev/null +++ b/src/array/triangular.jl @@ -0,0 +1,83 @@ +export LowerTriangular, UpperTriangular, tril, triu, tril!, triu! + +function _GenericTriangular(f::Union{Function, UndefInitializer}, p::Blocks{2}, eltype::Type, compar::Function, wrap, dims) + @assert dims[1] == dims[2] "matrix is not square: dimensions are $dims (try using trapezoidal)" + d = ArrayDomain(map(x->1:x, dims)) + dc = partition(p, d) + if f isa UndefInitializer + f = (eltype, x...) -> Array{eltype}(undef, x...) + end + diag = [(i,i) for i in 1:min(size(d)...)] + thunks = [] + alloc(sz) = f(eltype, sz) + transform = (wrap == LowerTrapezoidal) ? tril! : triu! + compar = (wrap == LowerTrapezoidal) ? (>) : (<) + for c in dc + sz = size(c) + if any(x -> x[1] in c.indexes[1] && x[2] in c.indexes[2], diag) + push!(thunks, _DiagBuild(c.indexes, alloc(sz), diag, transform)) + else + if compar(first(c.indexes[1]), first(c.indexes[2])) + push!(thunks, Dagger.@spawn alloc(sz)) + else + push!(thunks, Dagger.@spawn zeros(eltype, sz)) + end + end + end + thunks = reshape(thunks, size(dc)) + return wrap(Dagger.DArray(eltype, d, dc, thunks, p)) +end + +LinearAlgebra.LowerTriangular(f::Union{Function, UndefInitializer}, p::Blocks{2}, dims::Integer...) = _GenericTriangular(f, p, eltype, LowerTriangular, dims) +LinearAlgebra.LowerTriangular(f::Union{Function, UndefInitializer}, p::Blocks{2}, dims::Tuple) = _GenericTriangular(f, p, eltype, LowerTriangular, dims) +LinearAlgebra.LowerTriangular(f::Union{Function, UndefInitializer}, p::Blocks{2}, eltype::Type, dims::Integer...) = _GenericTriangular(f, p, eltype, LowerTriangular, dims) +LinearAlgebra.LowerTriangular(f::Union{Function, UndefInitializer}, p::Blocks{2}, eltype::Type, dims::Tuple) = _GenericTriangular(f, p, eltype, LowerTriangular, dims) + +LinearAlgebra.UpperTriangular(f::Union{Function, UndefInitializer}, p::Blocks{2}, dims::Integer...) = _GenericTriangular(f, p, eltype, UpperTriangular, dims) +LinearAlgebra.UpperTriangular(f::Union{Function, UndefInitializer}, p::Blocks{2}, dims::Tuple) = _GenericTriangular(f, p, eltype, UpperTriangular, dims) +LinearAlgebra.UpperTriangular(f::Union{Function, UndefInitializer}, p::Blocks{2}, eltype::Type, dims::Integer...) = _GenericTriangular(f, p, eltype, UpperTriangular, dims) +LinearAlgebra.UpperTriangular(f::Union{Function, UndefInitializer}, p::Blocks{2}, eltype::Type, dims::Tuple) = _GenericTriangular(f, p, eltype, UpperTriangular, dims) + +function _GenericTri!(A::Dagger.DArray{T, 2}, wrap) where {T} + LinearAlgebra.checksquare(A) + d = A.domain + dc = A.subdomains + Ac = A.chunks + diag = [(i,i) for i in 1:min(size(d)...)] + compar = (wrap == tril!) ? (>) : (<) + for ind in CartesianIndices(dc) + sz = size(dc[ind]) + if any(x -> x[1] in dc[ind].indexes[1] && x[2] in dc[ind].indexes[2], diag) + Ac[ind] = _DiagBuild(dc[ind].indexes, fetch(Ac[ind]), diag, wrap) + else + if compar(first(dc[ind].indexes[2]), first(dc[ind].indexes[1])) + Ac[ind] = Dagger.@spawn zeros(T, sz) + end + end + end + return A +end + +function LinearAlgebra.triu!(A::Dagger.DArray{T,2}) where {T} + if size(A, 1) != size(A, 2) + trau!(A) + else + _GenericTri!(A, triu!) + end +end +LinearAlgebra.triu(A::Dagger.DArray{T,2}) where {T} = triu!(copy(A)) + +function LinearAlgebra.tril!(A::Dagger.DArray{T,2}) where {T} + if size(A, 1) != size(A, 2) + tral!(A) + else + _GenericTri!(A, tril!) + end +end +LinearAlgebra.tril(A::Dagger.DArray{T,2}) where {T} = tril!(copy(A)) + +#TODO: matmul + + + + From 185e61186f1ea0bdfe4600342fec0077f85d569c Mon Sep 17 00:00:00 2001 From: fda-tome Date: Wed, 5 Jun 2024 17:04:19 -0300 Subject: [PATCH 06/34] DArray: UndefInitializer --- src/array/alloc.jl | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/array/alloc.jl b/src/array/alloc.jl index 5b881bf0d..3457e9ac7 100644 --- a/src/array/alloc.jl +++ b/src/array/alloc.jl @@ -31,6 +31,22 @@ end const BlocksOrAuto = Union{Blocks{N} where N, AutoBlocks} +function DArray{T}(::UndefInitializer, p::Blocks, dims) where {T} + d = ArrayDomain(map(x->1:x, dims)) + part = partition(p, d) + f = function (_, T, sz) + Array{T, length(sz)}(undef, sz...) + end + a = AllocateArray(T, f, d, part, p) + return _to_darray(a) +end + +DArray(::UndefInitializer, p::BlocksOrAuto, dims::Integer...) = DArray{Float64}(undef, p, dims) +DArray(::UndefInitializer, p::BlocksOrAuto, dims::Tuple) = DArray{Float64}(undef, p, dims) +DArray{T}(::UndefInitializer, p::BlocksOrAuto, dims::Integer...) where {T} = DArray{T}(undef, p, dims) +DArray{T}(::UndefInitializer, p::BlocksOrAuto, dims::Tuple) where {T} = DArray{T}(undef, p, dims) +DArray{T}(::UndefInitializer, p::AutoBlocks, dims::Tuple) where {T} = DArray{T}(undef, auto_blocks(dims), dims) + function Base.rand(p::Blocks, eltype::Type, dims::Dims) d = ArrayDomain(map(x->1:x, dims)) a = AllocateArray(eltype, (_, x...) -> rand(x...), d, partition(p, d), p) From afed3562dab055838ee06c448ad43910aabe3a91 Mon Sep 17 00:00:00 2001 From: fda-tome Date: Wed, 5 Jun 2024 17:18:03 -0300 Subject: [PATCH 07/34] DArray: slicing bug fix --- src/array/darray.jl | 22 ++++++++++++++++++---- src/array/indexing.jl | 19 ++++++++++++++++++- 2 files changed, 36 insertions(+), 5 deletions(-) diff --git a/src/array/darray.jl b/src/array/darray.jl index d4343b8ee..3a02e8143 100644 --- a/src/array/darray.jl +++ b/src/array/darray.jl @@ -65,6 +65,10 @@ ArrayDomain((1:15), (1:80)) alignfirst(a::ArrayDomain) = ArrayDomain(map(r->1:length(r), indexes(a))) +alignfirst(a::CartesianIndices{N}) where N = + ArrayDomain(map(r->1:length(r), a.indices)) + + function size(a::ArrayDomain, dim) idxs = indexes(a) length(idxs) < dim ? 1 : length(idxs[dim]) @@ -369,7 +373,7 @@ function group_indices(cumlength, idxs::AbstractRange) end _cumsum(x::AbstractArray) = length(x) == 0 ? Int[] : cumsum(x) -function lookup_parts(A::DArray, ps::AbstractArray, subdmns::DomainBlocks{N}, d::ArrayDomain{N}) where N +function lookup_parts(A::DArray, ps::AbstractArray, subdmns::DomainBlocks{N}, d::ArrayDomain{N}; slice::Bool=false) where N groups = map(group_indices, subdmns.cumlength, indexes(d)) sz = map(length, groups) pieces = Array{Any}(undef, sz) @@ -377,21 +381,31 @@ function lookup_parts(A::DArray, ps::AbstractArray, subdmns::DomainBlocks{N}, d: idx_and_dmn = map(getindex, groups, i.I) idx = map(x->x[1], idx_and_dmn) dmn = ArrayDomain(map(x->x[2], idx_and_dmn)) - pieces[i] = Dagger.@spawn getindex(ps[idx...], project(subdmns[idx...], dmn)) + if slice + pieces[i] = Dagger.@spawn getindex(ps[idx...], project(subdmns[idx...], dmn)) + else + pieces[i] = Dagger.@spawn view(ps[idx...], project(subdmns[idx...], dmn).indexes...) + end end out_cumlength = map(g->_cumsum(map(x->length(x[2]), g)), groups) out_dmn = DomainBlocks(ntuple(x->1,Val(N)), out_cumlength) return pieces, out_dmn end -function lookup_parts(A::DArray, ps::AbstractArray, subdmns::DomainBlocks{N}, d::ArrayDomain{S}) where {N,S} +function lookup_parts(A::DArray, ps::AbstractArray, subdmns::DomainBlocks{N}, d::ArrayDomain{S}; slice::Bool=false) where {N,S} if S != 1 throw(BoundsError(A, d.indexes)) end inds = CartesianIndices(A)[d.indexes...] new_d = ntuple(i->first(inds).I[i]:last(inds).I[i], N) - return lookup_parts(A, ps, subdmns, ArrayDomain(new_d)) + return lookup_parts(A, ps, subdmns, ArrayDomain(new_d); slice) end +function lookup_parts(A::DArray, ps::AbstractArray, subdmns::DomainBlocks{N}, d::CartesianIndices; slice::Bool=false) where N + return lookup_parts(A, ps, subdmns, ArrayDomain(d.indices); slice) +end + + + """ Base.fetch(c::DArray) diff --git a/src/array/indexing.jl b/src/array/indexing.jl index 82f44fbff..505e04f87 100644 --- a/src/array/indexing.jl +++ b/src/array/indexing.jl @@ -10,6 +10,16 @@ end GetIndex(input::ArrayOp, idx::Tuple) = GetIndex{eltype(input), ndims(input)}(input, idx) +function flatten(subdomains, subchunks, partitioning) + valdim = findfirst(j -> j != 1:1, subdomains[1].indexes) + flatc = [] + flats = Array{ArrayDomain{1, Tuple{UnitRange{Int64}}}}(undef, 0) + map(x -> push!(flats, ArrayDomain(x.indexes[valdim])), subdomains) + map(x -> push!(flatc, x), subchunks) + newb = Blocks(partitioning.blocksize[valdim]) + return flats, flatc, newb +end + function stage(ctx::Context, gidx::GetIndex) inp = stage(ctx, gidx.input) @@ -21,7 +31,14 @@ function stage(ctx::Context, gidx::GetIndex) end for i in 1:length(gidx.idx)] # Figure out output dimension - view(inp, ArrayDomain(idxs)) + d = ArrayDomain(idxs) + subchunks, subdomains = Dagger.lookup_parts(inp, chunks(inp), domainchunks(inp), d; slice = true) + d1 = alignfirst(d) + newb = inp.partitioning + if ndims(d1) != ndims(subdomains) + subdomains, subchunks, newb = flatten(subdomains, subchunks, inp.partitioning) + end + DArray(eltype(inp), d1, subdomains, subchunks, newb) end function size(x::GetIndex) From d05fcf44cfe1715aea0328692221a734383edc64 Mon Sep 17 00:00:00 2001 From: fda-tome Date: Thu, 6 Jun 2024 08:30:24 -0300 Subject: [PATCH 08/34] DArray: Tile QR Implementation --- src/Dagger.jl | 1 + src/array/coreblas/coreblas_gemm.jl | 21 +++ src/array/coreblas/coreblas_geqrt.jl | 32 ++++ src/array/coreblas/coreblas_ormqr.jl | 45 +++++ src/array/coreblas/coreblas_tsmqr.jl | 49 ++++++ src/array/coreblas/coreblas_tsqrt.jl | 35 ++++ src/array/coreblas/coreblas_ttmqr.jl | 51 ++++++ src/array/coreblas/coreblas_ttqrt.jl | 32 ++++ src/array/qr.jl | 241 +++++++++++++++++++++++++++ test/array/linalg/qr.jl | 36 ++++ test/runtests.jl | 1 + 11 files changed, 544 insertions(+) create mode 100644 src/array/coreblas/coreblas_gemm.jl create mode 100644 src/array/coreblas/coreblas_geqrt.jl create mode 100644 src/array/coreblas/coreblas_ormqr.jl create mode 100644 src/array/coreblas/coreblas_tsmqr.jl create mode 100644 src/array/coreblas/coreblas_tsqrt.jl create mode 100644 src/array/coreblas/coreblas_ttmqr.jl create mode 100644 src/array/coreblas/coreblas_ttqrt.jl create mode 100644 src/array/qr.jl create mode 100644 test/array/linalg/qr.jl diff --git a/src/Dagger.jl b/src/Dagger.jl index 7211ad277..3c1d92eee 100644 --- a/src/Dagger.jl +++ b/src/Dagger.jl @@ -85,6 +85,7 @@ include("array/mul.jl") include("array/trapezoidal.jl") include("array/triangular.jl") include("array/cholesky.jl") +include("array/qr.jl") # Visualization include("visualization.jl") diff --git a/src/array/coreblas/coreblas_gemm.jl b/src/array/coreblas/coreblas_gemm.jl new file mode 100644 index 000000000..ad77be40d --- /dev/null +++ b/src/array/coreblas/coreblas_gemm.jl @@ -0,0 +1,21 @@ +using libblastrampoline_jll +using LinearAlgebra +using libcoreblas_jll + +for (gemm, T) in + ((:coreblas_dgemm, Float64), + (:coreblas_sgemm, Float32), + (:coreblas_cgemm, ComplexF32), + (:coreblas_zgemm, ComplexF64)) + @eval begin + function coreblas_gemm!(transa::Int64, transb::Int64, + alpha::$T, A::AbstractMatrix{$T}, B::AbstractMatrix{$T}, beta::$T, C::AbstractMatrix{$T}) + m, k = size(A) + k, n = size(B) + ccall(($gemm, "libcoreblas.so"), Cvoid, + (Int64, Int64, Int64, Int64, Int64, $T, Ptr{$T}, Int64, Ptr{$T}, Int64, + $T, Ptr{$T}, Int64), + transa, transb, m, n, k, alpha, A, m, B, k, beta, C, m) + end + end +end diff --git a/src/array/coreblas/coreblas_geqrt.jl b/src/array/coreblas/coreblas_geqrt.jl new file mode 100644 index 000000000..99f86d442 --- /dev/null +++ b/src/array/coreblas/coreblas_geqrt.jl @@ -0,0 +1,32 @@ +for (geqrt, T) in + ((:coreblas_dgeqrt, Float64), + (:coreblas_sgeqrt, Float32), + (:coreblas_cgeqrt, ComplexF32), + (:coreblas_zgeqrt, ComplexF64)) + @eval begin + function coreblas_geqrt!(A::AbstractMatrix{$T}, + Tau::AbstractMatrix{$T}) + require_one_based_indexing(A, Tau) + chkstride1(A) + m, n = size(A) + ib, nb = size(Tau) + lda = max(1, stride(A,2)) + ldt = max(1, stride(Tau,2)) + work = Vector{$T}(undef, (ib)*n) + ttau = Vector{$T}(undef, n) + + err = ccall(($(QuoteNode(geqrt)), libcoreblas), Int64, + (Int64, Int64, Int64, + Ptr{$T}, Int64, Ptr{$T}, Int64, + Ptr{$T}, Ptr{$T}), + m, n, ib, + A, lda, + Tau, ldt, + ttau, work) + if err != 0 + throw(ArgumentError("coreblas_geqrt failed. Error number: $err")) + end + end + end +end + diff --git a/src/array/coreblas/coreblas_ormqr.jl b/src/array/coreblas/coreblas_ormqr.jl new file mode 100644 index 000000000..5bcb7b1bb --- /dev/null +++ b/src/array/coreblas/coreblas_ormqr.jl @@ -0,0 +1,45 @@ +for (geormqr, T) in + ((:coreblas_dormqr, Float64), + (:coreblas_sormqr, Float32), + (:coreblas_zunmqr, ComplexF64), + (:coreblas_cunmqr, ComplexF32)) + @eval begin + function coreblas_ormqr!(side::Char, trans::Char, A::AbstractMatrix{$T}, + Tau::AbstractMatrix{$T}, C::AbstractMatrix{$T}) + + m, n = size(C) + ib, nb = size(Tau) + k = nb + if $T <: Complex + transnum = trans == 'N' ? 111 : 113 + else + transnum = trans == 'N' ? 111 : 112 + end + sidenum = side == 'L' ? 141 : 142 + + lda = max(1, stride(A,2)) + ldt = max(1, stride(Tau,2)) + ldc = max(1, stride(C,2)) + ldwork = side == 'L' ? n : m + work = Vector{$T}(undef, ib*nb) + + + err = ccall(($(QuoteNode(geormqr)), libcoreblas), Int64, + (Int64, Int64, Int64, Int64, + Int64, Int64, + Ptr{$T}, Int64, Ptr{$T}, Int64, + Ptr{$T}, Int64, Ptr{$T}, Int64), + sidenum, transnum, + m, n, + k, ib, + A, lda, + Tau, ldt, + C, ldc, + work, ldwork) + if err != 0 + throw(ArgumentError("coreblas_ormqr failed. Error number: $err")) + end + end + end +end + diff --git a/src/array/coreblas/coreblas_tsmqr.jl b/src/array/coreblas/coreblas_tsmqr.jl new file mode 100644 index 000000000..1f647f7c1 --- /dev/null +++ b/src/array/coreblas/coreblas_tsmqr.jl @@ -0,0 +1,49 @@ +for (getsmqr, T) in + ((:coreblas_dtsmqr, Float64), + (:coreblas_ctsmqr, ComplexF32), + (:coreblas_ztsmqr, ComplexF64), + (:coreblas_stsmqr, Float32)) + @eval begin + function coreblas_tsmqr!(side::Char, trans::Char, A1::AbstractMatrix{$T}, + A2::AbstractMatrix{$T}, V::AbstractMatrix{$T}, Tau::AbstractMatrix{$T}) + m1, n1 = size(A1) + m2, n2 = size(A2) + ib, nb = size(Tau) + k = nb + + if $T <: Complex + transnum = trans == 'N' ? 111 : 113 + else + transnum = trans == 'N' ? 111 : 112 + end + + sidenum = side == 'L' ? 141 : 142 + + lda1 = max(1, stride(A1,2)) + lda2 = max(1, stride(A2,2)) + ldv = max(1, stride(V,2)) + ldt = max(1, stride(Tau,2)) + ldwork = side == 'L' ? ib : m1 + work = Vector{$T}(undef, ib*nb) + + + err = ccall(($(QuoteNode(getsmqr)), libcoreblas), Int64, + (Int64, Int64, Int64, Int64, + Int64, Int64, Int64, Int64, + Ptr{$T}, Int64, Ptr{$T}, Int64, + Ptr{$T}, Int64, Ptr{$T}, Int64, Ptr{$T}, Int64), + sidenum, transnum, + m1, n1, + m2, n2, + k, ib, + A1, lda1, + A2, lda2, + V, ldv, + Tau, ldt, + work, ldwork) + if err != 0 + throw(ArgumentError("coreblas_tsmqr failed. Error number: $err")) + end + end + end +end diff --git a/src/array/coreblas/coreblas_tsqrt.jl b/src/array/coreblas/coreblas_tsqrt.jl new file mode 100644 index 000000000..e644465a3 --- /dev/null +++ b/src/array/coreblas/coreblas_tsqrt.jl @@ -0,0 +1,35 @@ + +for (getsqrt,T) in + ((:coreblas_dtsqrt, Float64), + (:coreblas_stsqrt, Float32), + (:coreblas_ctsqrt, ComplexF32), + (:coreblas_ztsqrt, ComplexF64)) + @eval begin + function coreblas_tsqrt!(A1::AbstractMatrix{$T}, A2::AbstractMatrix{$T}, + Tau::AbstractMatrix{$T}) + m = size(A2)[1] + n = size(A1)[2] + ib, nb = size(Tau) + lda1 = max(1, stride(A1,2)) + lda2 = max(1, stride(A2,2)) + ldt = max(1, stride(Tau,2)) + work = Vector{$T}(undef, (ib)*n) + ttau = Vector{$T}(undef, n) + + err = ccall(($(QuoteNode(getsqrt)), libcoreblas), Int64, + (Int64, Int64, Int64, + Ptr{$T}, Int64, Ptr{$T}, Int64, + Ptr{$T}, Int64, Ptr{$T}, Ptr{$T}), + m, n, ib, + A1, lda1, + A2, lda2, + Tau, ldt, + ttau, work) + if err != 0 + throw(ArgumentError("coreblas_tsqrt failed. Error number: $err")) + end + end + end +end + + diff --git a/src/array/coreblas/coreblas_ttmqr.jl b/src/array/coreblas/coreblas_ttmqr.jl new file mode 100644 index 000000000..32c0f4eb3 --- /dev/null +++ b/src/array/coreblas/coreblas_ttmqr.jl @@ -0,0 +1,51 @@ +using libcoreblas_jll +for (gettmqr, T) in + ((:coreblas_dttmqr, Float64), + (:coreblas_sttmqr, Float32), + (:coreblas_cttmqr, ComplexF32), + (:coreblas_zttmqr, ComplexF64)) + @eval begin + function coreblas_ttmqr!(side::Char, trans::Char, A1::AbstractMatrix{$T}, + A2::AbstractMatrix{$T}, V::AbstractMatrix{$T}, Tau::AbstractMatrix{$T}) + m1, n1 = size(A1) + m2, n2 = size(A2) + ib, nb = size(Tau) + k=nb + if $T <: Complex + transnum = trans == 'N' ? 111 : 113 + else + transnum = trans == 'N' ? 111 : 112 + end + + sidenum = side == 'L' ? 141 : 142 + + ldv = max(1, stride(V,2)) + ldt = max(1, stride(Tau,2)) + lda1 = max(1, stride(A1,2)) + lda2 = max(1, stride(A2,2)) + ldwork = side == 'L' ? max(1,ib) : max(1,m1) + workdim = side == 'L' ? n1 : ib + work = Vector{$T}(undef, ldwork*workdim) + + err = ccall(($(QuoteNode(gettmqr)), libcoreblas), Int64, + (Int64, Int64, Int64, Int64, + Int64, Int64, Int64, Int64, + Ptr{$T}, Int64, Ptr{$T}, Int64, + Ptr{$T}, Int64, Ptr{$T}, Int64, + Ptr{$T}, Int64), + sidenum, transnum, + m1, n1, + m2, n2, + k, ib, + A1, lda1, + A2, lda2, + V, ldv, + Tau, ldt, + work, ldwork) + if err != 0 + throw(ArgumentError("coreblas_ttmqr, failed. Error number: $err")) + end + end + end +end + diff --git a/src/array/coreblas/coreblas_ttqrt.jl b/src/array/coreblas/coreblas_ttqrt.jl new file mode 100644 index 000000000..f90373517 --- /dev/null +++ b/src/array/coreblas/coreblas_ttqrt.jl @@ -0,0 +1,32 @@ + +for (gettqrt, T) in + ((:coreblas_dttqrt, Float64), + (:coreblas_sttqrt, Float32), + (:coreblas_cttqrt, ComplexF32), + (:coreblas_zttqrt, ComplexF64)) + @eval begin + function coreblas_ttqrt!(A1::AbstractMatrix{$T}, + A2::AbstractMatrix{$T}, triT::AbstractMatrix{$T}) + m1, n1 = size(A1) + m2, n2 = size(A2) + ib, nb = size(triT) + + lwork = nb + ib*nb + tau = Vector{$T}(undef, nb) + work = Vector{$T}(undef, (ib+1)*nb) + lda1 = max(1, stride(A1, 2)) + lda2 = max(1, stride(A2, 2)) + ldt = max(1, stride(triT, 2)) + + + err = ccall(($(QuoteNode(gettqrt)), libcoreblas), Int64, + (Int64, Int64, Int64, Ptr{$T}, Int64, Ptr{$T}, Int64, + Ptr{$T}, Int64, Ptr{$T}, Ptr{$T}), + m1, n1, ib, A1, lda1, A2, lda2, triT, ldt, tau, work) + + if err != 0 + throw(ArgumentError("coreblas_ttqrt failed. Error number: $err")) + end + end + end +end diff --git a/src/array/qr.jl b/src/array/qr.jl new file mode 100644 index 000000000..cc825157e --- /dev/null +++ b/src/array/qr.jl @@ -0,0 +1,241 @@ +export geqrf!, porgqr!, pormqr!, cageqrf! +import LinearAlgebra: QRCompactWY, AdjointQ, BlasFloat, QRCompactWYQ, AbstractQ, StridedVecOrMat, I +import Base.:* +include("coreblas/coreblas_ormqr.jl") +include("coreblas/coreblas_ttqrt.jl") +include("coreblas/coreblas_ttmqr.jl") +include("coreblas/coreblas_geqrt.jl") +include("coreblas/coreblas_tsqrt.jl") +include("coreblas/coreblas_tsmqr.jl") + +(*)(Q::QRCompactWYQ{T, M}, b::Number) where {T<:Number, M<:DMatrix{T}} = DMatrix(Q) * b +(*)(b::Number, Q::QRCompactWYQ{T, M}) where {T<:Number, M<:DMatrix{T}} = DMatrix(Q) * b + +(*)(Q::AdjointQ{T, QRCompactWYQ{T, M, C}}, b::Number) where {T<:Number, M<:DMatrix{T}, C<:LowerTrapezoidal{T, M}} = DMatrix(Q) * b +(*)(b::Number, Q::AdjointQ{T, QRCompactWYQ{T, M, C}}) where {T<:Number, M<:DMatrix{T}, C<:LowerTrapezoidal{T, M}} = DMatrix(Q) * b + +LinearAlgebra.lmul!(B::QRCompactWYQ{T, M}, A::M) where {T, M<:DMatrix{T}} = pormqr!('L', 'N', B.factors, B.T, A) +function LinearAlgebra.lmul!(B::AdjointQ{T, <:QRCompactWYQ{T, M}}, A::M) where {T, M<:Dagger.DMatrix{T}} + trans = T <: Complex ? 'C' : 'T' + pormqr!('L', trans, B.Q.factors, B.Q.T, A) +end + +LinearAlgebra.rmul!(A::Dagger.DMatrix{T}, B::QRCompactWYQ{T, M}) where {T, M<:Dagger.DMatrix{T}} = pormqr!('R', 'N', B.factors, B.T, A) +function LinearAlgebra.rmul!(A::Dagger.DArray{T,2}, B::AdjointQ{T, <:QRCompactWYQ{T, M}}) where {T, M<:Dagger.DMatrix{T}} + trans = T <: Complex ? 'C' : 'T' + pormqr!('R', trans, B.Q.factors, B.Q.T, A) +end + +function Dagger.DMatrix(Q::QRCompactWYQ{T, <:Dagger.DArray{T, 2}}) where {T} + DQ = distribute(Matrix(I*one(T), size(Q.factors)[1], size(Q.factors)[1]), Q.factors.partitioning) + porgqr!('N', Q.factors, Q.T, DQ) + return DQ +end + +function Dagger.DMatrix(AQ::AdjointQ{T, <:QRCompactWYQ{T, <:Dagger.DArray{T, 2}}}) where {T} + DQ = distribute(Matrix(I*one(T), size(AQ.Q.factors)[1], size(AQ.Q.factors)[1]), AQ.Q.factors.partitioning) + trans = T <: Complex ? 'C' : 'T' + porgqr!(trans, AQ.Q.factors, AQ.Q.T, DQ) + return DQ +end + +Base.collect(Q::QRCompactWYQ{T, <:Dagger.DArray{T, 2}}) where {T} = collect(Dagger.DMatrix(Q)) +Base.collect(AQ::AdjointQ{T, <:QRCompactWYQ{T, <:Dagger.DArray{T, 2}}}) where {T} = collect(Dagger.DMatrix(AQ)) + +function pormqr!(side::Char, trans::Char, A::Dagger.DArray{T, 2}, Tm::LowerTrapezoidal{T, <:Dagger.DMatrix{T}}, C::Dagger.DArray{T, 2}) where {T<:Number} + m, n = size(C) + Ac = A.chunks + Tc = Tm.data.chunks + Cc = C.chunks + + Amt, Ant = size(Ac) + Tmt, Tnt = size(Tc) + Cmt, Cnt = size(Cc) + minMT = min(Amt, Ant) + + Dagger.spawn_datadeps() do + if side == 'L' + if (trans == 'T' || trans == 'C') + for k in 1:minMT + for n in 1:Cnt + Dagger.@spawn coreblas_ormqr!(side, trans, In(Ac[k, k]), In(Tc[k, k]), InOut(Cc[k,n])) + end + for m in k+1:Cmt, n in 1:Cnt + Dagger.@spawn coreblas_tsmqr!(side, trans, InOut(Cc[k, n]), InOut(Cc[m, n]), In(Ac[m, k]), In(Tc[m, k])) + end + end + end + if trans == 'N' + for k in minMT:-1:1 + for m in Cmt:-1:k+1, n in 1:Cnt + Dagger.@spawn coreblas_tsmqr!(side, trans, InOut(Cc[k, n]), InOut(Cc[m, n]), In(Ac[m, k]), In(Tc[m, k])) + end + for n in 1:Cnt + Dagger.@spawn coreblas_ormqr!(side, trans, In(Ac[k, k]), In(Tc[k, k]), InOut(Cc[k, n])) + end + end + end + else + if side == 'R' + if trans == 'T' || trans == 'C' + for k in minMT:-1:1 + for n in Cmt:-1:k+1, m in 1:Cmt + Dagger.@spawn coreblas_tsmqr!(side, trans, InOut(Cc[m, k]), InOut(Cc[m, n]), In(Ac[n, k]), In(Tc[n, k])) + end + for m in 1:Cmt + Dagger.@spawn coreblas_ormqr!(side, trans, In(Ac[k, k]), In(Tc[k, k]), InOut(Cc[m, k])) + end + end + end + if trans == 'N' + for k in 1:minMT + for m in 1:Cmt + Dagger.@spawn coreblas_ormqr!(side, trans, In(Ac[k, k]), In(Tc[k, k]), InOut(Cc[m, k])) + end + for n in k+1:Cmt, m in 1:Cmt + Dagger.@spawn coreblas_tsmqr!(side, trans, InOut(Cc[m, k]), InOut(Cc[m, n]), In(Ac[n, k]), In(Tc[n, k])) + end + end + end + end + end + end + return C +end + +function cageqrf!(A::Dagger.DArray{T, 2}, Tm::LowerTrapezoidal{T, <:Dagger.DMatrix{T}}; static::Bool=true, traversal::Symbol=:inorder, p::Int64=1) where {T<: Number} + if p == 1 + return geqrf!(A, Tm; static, traversal) + end + Ac = A.chunks + mt, nt = size(Ac) + @assert mt % p == 0 "Number of tiles must be divisible by the number of domains" + mtd = Int64(mt/p) + Tc = Tm.data.chunks + proot = 1 + nxtmt = mtd + trans = T <: Complex ? 'C' : 'T' + Dagger.spawn_datadeps(;static, traversal) do + for k in 1:min(mt, nt) + if k > nxtmt + proot += 1 + nxtmt += mtd + end + for pt in proot:p + ibeg = 1 + (pt-1) * mtd + if pt == proot + ibeg = k + end + Dagger.@spawn coreblas_geqrt!(InOut(Ac[ibeg, k]), Out(Tc[ibeg,k])) + for n in k+1:nt + Dagger.@spawn coreblas_ormqr!('L', trans, In(Ac[ibeg, k]), In(Tc[ibeg,k]), InOut(Ac[ibeg, n])) + end + for m in ibeg+1:(pt * mtd) + Dagger.@spawn coreblas_tsqrt!(InOut(Ac[ibeg, k]), InOut(Ac[m, k]), Out(Tc[m,k])) + for n in k+1:nt + Dagger.@spawn coreblas_tsmqr!('L', trans, InOut(Ac[ibeg, n]), InOut(Ac[m, n]), In(Ac[m, k]), In(Tc[m,k])) + end + end + end + for m in 1:ceil(Int64, log2(p-proot+1)) + p1 = proot + p2 = p1 + 2^(m-1) + while p2 ≤ p + i1 = 1 + (p1-1) * mtd + i2 = 1 + (p2-1) * mtd + if p1 == proot + i1 = k + end + Dagger.@spawn coreblas_ttqrt!(InOut(Ac[i1, k]), InOut(Ac[i2, k]), Out(Tc[i2, k])) + for n in k+1:nt + Dagger.@spawn coreblas_ttmqr!('L', trans, InOut(Ac[i1, n]), InOut(Ac[i2, n]), In(Ac[i2, k]), In(Tc[i2, k])) + end + p1 += 2^m + p2 += 2^m + end + end + end + end +end + +function geqrf!(A::Dagger.DArray{T, 2}, Tm::LowerTrapezoidal{T, <:Dagger.DMatrix{T}}; static::Bool=true, traversal::Symbol=:inorder) where {T<: Number} + Ac = A.chunks + mt, nt = size(Ac) + Tc = Tm.data.chunks + trans = T <: Complex ? 'C' : 'T' + + Ccopy = Dagger.DArray{T}(undef, A.partitioning, A.partitioning.blocksize[1], min(mt, nt) * A.partitioning.blocksize[2]) + Cc = Ccopy.chunks + Dagger.spawn_datadeps(;static, traversal) do + for k in 1:min(mt, nt) + Dagger.@spawn coreblas_geqrt!(InOut(Ac[k, k]), Out(Tc[k,k])) + # FIXME: This is a hack to avoid aliasing + Dagger.@spawn copyto!(InOut(Cc[1,k]), In(Ac[k, k])) + for n in k+1:nt + #FIXME: Change Cc[1,k] to upper triangular of Ac[k,k] + Dagger.@spawn coreblas_ormqr!('L', trans, In(Cc[1, k]), In(Tc[k,k]), InOut(Ac[k, n])) + end + for m in k+1:mt + Dagger.@spawn coreblas_tsqrt!(InOut(Ac[k, k]), InOut(Ac[m, k]), Out(Tc[m,k])) + for n in k+1:nt + Dagger.@spawn coreblas_tsmqr!('L', trans, InOut(Ac[k, n]), InOut(Ac[m, n]), In(Ac[m, k]), In(Tc[m,k])) + end + end + end + end +end + +function porgqr!(trans::Char, A::Dagger.DArray{T, 2}, Tm::LowerTrapezoidal{T, <:Dagger.DMatrix{T}}, Q::Dagger.DArray{T, 2}; static::Bool=true, traversal::Symbol=:inorder) where {T<:Number} + Ac = A.chunks + Tc = Tm.data.chunks + Qc = Q.chunks + mt, nt = size(Ac) + qmt, qnt = size(Qc) + + Dagger.spawn_datadeps(;static, traversal) do + if trans == 'N' + for k in min(mt, nt):-1:1 + for m in qmt:-1:k + 1, n in k:qnt + Dagger.@spawn coreblas_tsmqr!('L', trans, InOut(Qc[k, n]), InOut(Qc[m, n]), In(Ac[m, k]), In(Tc[m, k])) + end + for n in k:qnt + Dagger.@spawn coreblas_ormqr!('L', trans, In(Ac[k, k]), + In(Tc[k, k]), InOut(Qc[k, n])) + end + end + else + for k in 1:min(mt, nt) + for n in 1:k + Dagger.@spawn coreblas_ormqr!('L', trans, In(Ac[k, k]), + In(Tc[k, k]), InOut(Qc[k, n])) + end + for m in k+1:qmt, n in 1:qnt + Dagger.@spawn coreblas_tsmqr!('L', trans, InOut(Qc[k, n]), InOut(Qc[m, n]), In(Ac[m, k]), In(Tc[m, k])) + end + end + end + end +end + +function meas_ws(A::Dagger.DArray{T, 2}, ib::Int64) where {T<: Number} + mb, nb = A.partitioning.blocksize + m, n = size(A) + MT = (mod(m,nb)==0) ? floor(Int64, (m / mb)) : floor(Int64, (m / mb) + 1) + NT = (mod(n,nb)==0) ? floor(Int64,(n / nb)) : floor(Int64, (n / nb) + 1) * 2 + lm = ib * MT; + ln = nb * NT; + lm, ln +end + +function LinearAlgebra.qr!(A::Dagger.DArray{T, 2}; ib::Int64=1, p::Int64=1) where {T<:Number} + lm, ln = meas_ws(A, ib) + Ac = A.chunks + nb = A.partitioning.blocksize[2] + mt, nt = size(Ac) + st = nb * (nt - 1) + Tm = LowerTrapezoidal(zeros, Blocks(ib, nb), T, st, lm, ln) + geqrf!(A, Tm) + return QRCompactWY(A, Tm); +end + + diff --git a/test/array/linalg/qr.jl b/test/array/linalg/qr.jl new file mode 100644 index 000000000..2a347d164 --- /dev/null +++ b/test/array/linalg/qr.jl @@ -0,0 +1,36 @@ + @testset "Tile QR: $T" for T in (Float32, Float64, ComplexF32, ComplexF64) + ## Square matrices + A = rand(T, 128, 128) + Q, R = qr(A) + DA = distribute(A, Blocks(32,32)) + DQ, DR = qr!(DA) + @test abs.(DQ) ≈ abs.(Q) + @test abs.(DR) ≈ abs.(R) + @test DQ * DR ≈ A + @test DQ' * DQ ≈ I + @test I * abs.(DQ) ≈ abs.(Q) + @test I * abs.(DQ') ≈ abs.(Q') + + ## Rectangular matrices (block and element wise) + # Tall Element and Block + A = rand(T, 128, 64) + Q, R = qr(A) + DA = distribute(A, Blocks(32,32)) + DQ, DR = qr!(DA) + @test abs.(DR) ≈ abs.(R) + @test DQ * DR ≈ A + @test DQ' * DQ ≈ I + @test I * DQ ≈ collect(DQ) + @test I * DQ' ≈ collect(DQ') + + # Wide Element and Block + A = rand(T, 64, 128) + Q, R = qr(A) + DA = distribute(A, Blocks(16,16)) + DQ, DR = qr!(DA) + @test abs.(DR) ≈ abs.(R) + @test DQ * DR ≈ A + @test DQ' * DQ ≈ I + @test I * DQ ≈ collect(DQ) + @test I * DQ' ≈ collect(DQ') +end diff --git a/test/runtests.jl b/test/runtests.jl index 3f4e1b1ca..31aafc72a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -18,6 +18,7 @@ tests = [ ("Array - MapReduce", "array/mapreduce.jl"), ("Array - LinearAlgebra - Matmul", "array/linalg/matmul.jl"), ("Array - LinearAlgebra - Cholesky", "array/linalg/cholesky.jl"), + ("Array - LinearAlgebra - QR", "array/linalg/qr.jl"), ("Caching", "cache.jl"), ("Disk Caching", "diskcaching.jl"), ("File IO", "file-io.jl"), From 2022892736799836ceb6423728ecd5ebcaec7e79 Mon Sep 17 00:00:00 2001 From: fda-tome Date: Thu, 6 Jun 2024 16:11:29 -0300 Subject: [PATCH 09/34] DArray: disabling faulty views, fixing undefinit Dims arguments --- src/Dagger.jl | 4 +++- src/array/alloc.jl | 8 ++++---- src/array/darray.jl | 4 ++-- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/Dagger.jl b/src/Dagger.jl index 3c1d92eee..c3384978c 100644 --- a/src/Dagger.jl +++ b/src/Dagger.jl @@ -12,7 +12,9 @@ import Distributed import Distributed: Future, RemoteChannel, myid, workers, nworkers, procs, remotecall, remotecall_wait, remotecall_fetch import LinearAlgebra -import LinearAlgebra: Adjoint, BLAS, Diagonal, Bidiagonal, Tridiagonal, LAPACK, LowerTriangular, PosDefException, Transpose, UpperTriangular, UnitLowerTriangular, UnitUpperTriangular, diagind, ishermitian, issymmetric +import LinearAlgebra: Adjoint, BLAS, Diagonal, Bidiagonal, Tridiagonal, LAPACK, LowerTriangular, PosDefException, Transpose, UpperTriangular, UnitLowerTriangular, UnitUpperTriangular, diagind, ishermitian, issymmetric, chkstride1 + +import libcoreblas_jll import UUIDs: UUID, uuid4 diff --git a/src/array/alloc.jl b/src/array/alloc.jl index 3457e9ac7..2bdb22238 100644 --- a/src/array/alloc.jl +++ b/src/array/alloc.jl @@ -31,7 +31,7 @@ end const BlocksOrAuto = Union{Blocks{N} where N, AutoBlocks} -function DArray{T}(::UndefInitializer, p::Blocks, dims) where {T} +function DArray{T}(::UndefInitializer, p::Blocks, dims::Dims) where {T} d = ArrayDomain(map(x->1:x, dims)) part = partition(p, d) f = function (_, T, sz) @@ -42,10 +42,10 @@ function DArray{T}(::UndefInitializer, p::Blocks, dims) where {T} end DArray(::UndefInitializer, p::BlocksOrAuto, dims::Integer...) = DArray{Float64}(undef, p, dims) -DArray(::UndefInitializer, p::BlocksOrAuto, dims::Tuple) = DArray{Float64}(undef, p, dims) +DArray(::UndefInitializer, p::BlocksOrAuto, dims::Dims) = DArray{Float64}(undef, p, dims) DArray{T}(::UndefInitializer, p::BlocksOrAuto, dims::Integer...) where {T} = DArray{T}(undef, p, dims) -DArray{T}(::UndefInitializer, p::BlocksOrAuto, dims::Tuple) where {T} = DArray{T}(undef, p, dims) -DArray{T}(::UndefInitializer, p::AutoBlocks, dims::Tuple) where {T} = DArray{T}(undef, auto_blocks(dims), dims) +DArray{T}(::UndefInitializer, p::BlocksOrAuto, dims::Dims) where {T} = DArray{T}(undef, p, dims) +DArray{T}(::UndefInitializer, p::AutoBlocks, dims::Dims) where {T} = DArray{T}(undef, auto_blocks(dims), dims) function Base.rand(p::Blocks, eltype::Type, dims::Dims) d = ArrayDomain(map(x->1:x, dims)) diff --git a/src/array/darray.jl b/src/array/darray.jl index 3a02e8143..1ca3331a8 100644 --- a/src/array/darray.jl +++ b/src/array/darray.jl @@ -334,11 +334,11 @@ Base.:(/)(x::DArray{T,N,B,F}, y::U) where {T<:Real,U<:Real,N,B,F} = A `view` of a `DArray` chunk returns a `DArray` of `Thunk`s. """ -function Base.view(c::DArray, d) +#=function Base.view(c::DArray, d) subchunks, subdomains = lookup_parts(c, chunks(c), domainchunks(c), d) d1 = alignfirst(d) DArray(eltype(c), d1, subdomains, subchunks, c.partitioning, c.concat) -end +end=# function group_indices(cumlength, idxs,at=1, acc=Any[]) at > length(idxs) && return acc From 36f67854527cd3c9f3bc23e6b1fc1ad602faf340 Mon Sep 17 00:00:00 2001 From: fda-tome Date: Thu, 6 Jun 2024 16:20:05 -0300 Subject: [PATCH 10/34] DArray: project.toml changes --- Project.toml | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/Project.toml b/Project.toml index ebf03c388..e65e42783 100644 --- a/Project.toml +++ b/Project.toml @@ -3,6 +3,7 @@ uuid = "d58978e5-989f-55fb-8d15-ea34adc7bf54" version = "0.18.11" [deps] +ArgParse = "c7e460c6-2fb9-53a9-8c5b-16f535851c63" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" @@ -23,6 +24,18 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" TaskLocalValues = "ed4db957-447d-4319-bfb6-7fa9ae7ecf34" TimespanLogging = "a526e669-04d3-4846-9525-c66122c55f63" UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" +libcoreblas_jll = "339d4f0c-89b5-5ae2-b52c-218a0e582e15" + +[weakdeps] +Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" +DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" +GraphViz = "f526b714-d49f-11e8-06ff-31ed36ee7ee0" +Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" + +[extensions] +GraphVizExt = "GraphViz" +GraphVizSimpleExt = "Colors" +PlotsExt = ["DataFrames", "Plots"] [compat] Colors = "0.12" @@ -43,19 +56,8 @@ TaskLocalValues = "0.1" TimespanLogging = "0.1" julia = "1.8" -[extensions] -GraphVizExt = "GraphViz" -GraphVizSimpleExt = "Colors" -PlotsExt = ["DataFrames", "Plots"] - [extras] Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" GraphViz = "f526b714-d49f-11e8-06ff-31ed36ee7ee0" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" - -[weakdeps] -Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" -DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" -GraphViz = "f526b714-d49f-11e8-06ff-31ed36ee7ee0" -Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" From 686b37a00c8f2c1e253cde4530009d9e7d2fbd5b Mon Sep 17 00:00:00 2001 From: fda-tome Date: Mon, 10 Jun 2024 15:51:43 -0300 Subject: [PATCH 11/34] DArray: coreblas changes and inclusion of libblastrampoline --- Project.toml | 1 + src/Dagger.jl | 1 + src/array/coreblas/coreblas_geqrt.jl | 2 +- src/array/coreblas/coreblas_ormqr.jl | 2 +- src/array/coreblas/coreblas_tsmqr.jl | 2 +- src/array/coreblas/coreblas_tsqrt.jl | 2 +- src/array/coreblas/coreblas_ttmqr.jl | 3 +-- src/array/coreblas/coreblas_ttqrt.jl | 2 +- 8 files changed, 8 insertions(+), 7 deletions(-) diff --git a/Project.toml b/Project.toml index e65e42783..556697539 100644 --- a/Project.toml +++ b/Project.toml @@ -24,6 +24,7 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" TaskLocalValues = "ed4db957-447d-4319-bfb6-7fa9ae7ecf34" TimespanLogging = "a526e669-04d3-4846-9525-c66122c55f63" UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" +libblastrampoline_jll = "8e850b90-86db-534c-a0d3-1478176c7d93" libcoreblas_jll = "339d4f0c-89b5-5ae2-b52c-218a0e582e15" [weakdeps] diff --git a/src/Dagger.jl b/src/Dagger.jl index a77913ac7..670af469f 100644 --- a/src/Dagger.jl +++ b/src/Dagger.jl @@ -14,6 +14,7 @@ import Distributed: Future, RemoteChannel, myid, workers, nworkers, procs, remot import LinearAlgebra import LinearAlgebra: Adjoint, BLAS, Diagonal, Bidiagonal, Tridiagonal, LAPACK, LowerTriangular, PosDefException, Transpose, UpperTriangular, UnitLowerTriangular, UnitUpperTriangular, diagind, ishermitian, issymmetric, chkstride1 import libcoreblas_jll +import libblastrampoline_jll import UUIDs: UUID, uuid4 diff --git a/src/array/coreblas/coreblas_geqrt.jl b/src/array/coreblas/coreblas_geqrt.jl index 99f86d442..6b3ac01f7 100644 --- a/src/array/coreblas/coreblas_geqrt.jl +++ b/src/array/coreblas/coreblas_geqrt.jl @@ -15,7 +15,7 @@ for (geqrt, T) in work = Vector{$T}(undef, (ib)*n) ttau = Vector{$T}(undef, n) - err = ccall(($(QuoteNode(geqrt)), libcoreblas), Int64, + err = ccall(($(QuoteNode(geqrt)), :libcoreblas), Int64, (Int64, Int64, Int64, Ptr{$T}, Int64, Ptr{$T}, Int64, Ptr{$T}, Ptr{$T}), diff --git a/src/array/coreblas/coreblas_ormqr.jl b/src/array/coreblas/coreblas_ormqr.jl index 5bcb7b1bb..cba7237e0 100644 --- a/src/array/coreblas/coreblas_ormqr.jl +++ b/src/array/coreblas/coreblas_ormqr.jl @@ -24,7 +24,7 @@ for (geormqr, T) in work = Vector{$T}(undef, ib*nb) - err = ccall(($(QuoteNode(geormqr)), libcoreblas), Int64, + err = ccall(($(QuoteNode(geormqr)), :libcoreblas), Int64, (Int64, Int64, Int64, Int64, Int64, Int64, Ptr{$T}, Int64, Ptr{$T}, Int64, diff --git a/src/array/coreblas/coreblas_tsmqr.jl b/src/array/coreblas/coreblas_tsmqr.jl index 1f647f7c1..df92c555a 100644 --- a/src/array/coreblas/coreblas_tsmqr.jl +++ b/src/array/coreblas/coreblas_tsmqr.jl @@ -27,7 +27,7 @@ for (getsmqr, T) in work = Vector{$T}(undef, ib*nb) - err = ccall(($(QuoteNode(getsmqr)), libcoreblas), Int64, + err = ccall(($(QuoteNode(getsmqr)), :libcoreblas), Int64, (Int64, Int64, Int64, Int64, Int64, Int64, Int64, Int64, Ptr{$T}, Int64, Ptr{$T}, Int64, diff --git a/src/array/coreblas/coreblas_tsqrt.jl b/src/array/coreblas/coreblas_tsqrt.jl index e644465a3..5b3fb4ea7 100644 --- a/src/array/coreblas/coreblas_tsqrt.jl +++ b/src/array/coreblas/coreblas_tsqrt.jl @@ -16,7 +16,7 @@ for (getsqrt,T) in work = Vector{$T}(undef, (ib)*n) ttau = Vector{$T}(undef, n) - err = ccall(($(QuoteNode(getsqrt)), libcoreblas), Int64, + err = ccall(($(QuoteNode(getsqrt)), :libcoreblas), Int64, (Int64, Int64, Int64, Ptr{$T}, Int64, Ptr{$T}, Int64, Ptr{$T}, Int64, Ptr{$T}, Ptr{$T}), diff --git a/src/array/coreblas/coreblas_ttmqr.jl b/src/array/coreblas/coreblas_ttmqr.jl index 32c0f4eb3..6cd9002b5 100644 --- a/src/array/coreblas/coreblas_ttmqr.jl +++ b/src/array/coreblas/coreblas_ttmqr.jl @@ -1,4 +1,3 @@ -using libcoreblas_jll for (gettmqr, T) in ((:coreblas_dttmqr, Float64), (:coreblas_sttmqr, Float32), @@ -27,7 +26,7 @@ for (gettmqr, T) in workdim = side == 'L' ? n1 : ib work = Vector{$T}(undef, ldwork*workdim) - err = ccall(($(QuoteNode(gettmqr)), libcoreblas), Int64, + err = ccall(($(QuoteNode(gettmqr)), :libcoreblas), Int64, (Int64, Int64, Int64, Int64, Int64, Int64, Int64, Int64, Ptr{$T}, Int64, Ptr{$T}, Int64, diff --git a/src/array/coreblas/coreblas_ttqrt.jl b/src/array/coreblas/coreblas_ttqrt.jl index f90373517..19465efbb 100644 --- a/src/array/coreblas/coreblas_ttqrt.jl +++ b/src/array/coreblas/coreblas_ttqrt.jl @@ -19,7 +19,7 @@ for (gettqrt, T) in ldt = max(1, stride(triT, 2)) - err = ccall(($(QuoteNode(gettqrt)), libcoreblas), Int64, + err = ccall(($(QuoteNode(gettqrt)), :libcoreblas), Int64, (Int64, Int64, Int64, Ptr{$T}, Int64, Ptr{$T}, Int64, Ptr{$T}, Int64, Ptr{$T}, Ptr{$T}), m1, n1, ib, A1, lda1, A2, lda2, triT, ldt, tau, work) From e20fab50fbd2ab79cbd7674933496e83ebf9b096 Mon Sep 17 00:00:00 2001 From: fda-tome Date: Tue, 11 Jun 2024 14:36:07 -0300 Subject: [PATCH 12/34] DArray: adding aliasing support --- src/array/qr.jl | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/src/array/qr.jl b/src/array/qr.jl index cc825157e..28ccdc260 100644 --- a/src/array/qr.jl +++ b/src/array/qr.jl @@ -164,19 +164,14 @@ function geqrf!(A::Dagger.DArray{T, 2}, Tm::LowerTrapezoidal{T, <:Dagger.DMatrix Tc = Tm.data.chunks trans = T <: Complex ? 'C' : 'T' - Ccopy = Dagger.DArray{T}(undef, A.partitioning, A.partitioning.blocksize[1], min(mt, nt) * A.partitioning.blocksize[2]) - Cc = Ccopy.chunks Dagger.spawn_datadeps(;static, traversal) do for k in 1:min(mt, nt) Dagger.@spawn coreblas_geqrt!(InOut(Ac[k, k]), Out(Tc[k,k])) - # FIXME: This is a hack to avoid aliasing - Dagger.@spawn copyto!(InOut(Cc[1,k]), In(Ac[k, k])) for n in k+1:nt - #FIXME: Change Cc[1,k] to upper triangular of Ac[k,k] - Dagger.@spawn coreblas_ormqr!('L', trans, In(Cc[1, k]), In(Tc[k,k]), InOut(Ac[k, n])) + Dagger.@spawn coreblas_ormqr!('L', trans, Deps(Ac[k,k], In(LowerTriangular)), In(Tc[k,k]), InOut(Ac[k, n])) end for m in k+1:mt - Dagger.@spawn coreblas_tsqrt!(InOut(Ac[k, k]), InOut(Ac[m, k]), Out(Tc[m,k])) + Dagger.@spawn coreblas_tsqrt!(Deps(Ac[k, k], InOut(UpperTriangular)), InOut(Ac[m, k]), Out(Tc[m,k])) for n in k+1:nt Dagger.@spawn coreblas_tsmqr!('L', trans, InOut(Ac[k, n]), InOut(Ac[m, n]), In(Ac[m, k]), In(Tc[m,k])) end From 54cfefa089fdd067c7a9a257af901a3b8f3b1014 Mon Sep 17 00:00:00 2001 From: fda-tome Date: Tue, 11 Jun 2024 14:51:38 -0300 Subject: [PATCH 13/34] DArray: adding aliasing support to CAQR --- src/array/qr.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/array/qr.jl b/src/array/qr.jl index 28ccdc260..249b9b560 100644 --- a/src/array/qr.jl +++ b/src/array/qr.jl @@ -128,10 +128,10 @@ function cageqrf!(A::Dagger.DArray{T, 2}, Tm::LowerTrapezoidal{T, <:Dagger.DMatr end Dagger.@spawn coreblas_geqrt!(InOut(Ac[ibeg, k]), Out(Tc[ibeg,k])) for n in k+1:nt - Dagger.@spawn coreblas_ormqr!('L', trans, In(Ac[ibeg, k]), In(Tc[ibeg,k]), InOut(Ac[ibeg, n])) + Dagger.@spawn coreblas_ormqr!('L', trans, Deps(Ac[ibeg, k], In(LowerTriangular)), In(Tc[ibeg,k]), InOut(Ac[ibeg, n])) end for m in ibeg+1:(pt * mtd) - Dagger.@spawn coreblas_tsqrt!(InOut(Ac[ibeg, k]), InOut(Ac[m, k]), Out(Tc[m,k])) + Dagger.@spawn coreblas_tsqrt!(Deps(Ac[ibeg, k], InOut(UpperTriangular)), InOut(Ac[m, k]), Out(Tc[m,k])) for n in k+1:nt Dagger.@spawn coreblas_tsmqr!('L', trans, InOut(Ac[ibeg, n]), InOut(Ac[m, n]), In(Ac[m, k]), In(Tc[m,k])) end @@ -146,9 +146,9 @@ function cageqrf!(A::Dagger.DArray{T, 2}, Tm::LowerTrapezoidal{T, <:Dagger.DMatr if p1 == proot i1 = k end - Dagger.@spawn coreblas_ttqrt!(InOut(Ac[i1, k]), InOut(Ac[i2, k]), Out(Tc[i2, k])) + Dagger.@spawn coreblas_ttqrt!(Deps(Ac[i1, k], InOut(UpperTriangular)), Deps(Ac[i2, k], InOut(UpperTriangular)), Out(Tc[i2, k])) for n in k+1:nt - Dagger.@spawn coreblas_ttmqr!('L', trans, InOut(Ac[i1, n]), InOut(Ac[i2, n]), In(Ac[i2, k]), In(Tc[i2, k])) + Dagger.@spawn coreblas_ttmqr!('L', trans, InOut(Ac[i1, n]), InOut(Ac[i2, n]), Deps(Ac[i2, k], In(UpperTriangular)), In(Tc[i2, k])) end p1 += 2^m p2 += 2^m From 1fa2ea4ea9793b60a1c54fc4091f906c82d53dce Mon Sep 17 00:00:00 2001 From: JamesWrigley Date: Tue, 9 Apr 2024 01:07:26 +0200 Subject: [PATCH 14/34] Allow for workers dying in the middle of cleanup --- src/sch/Sch.jl | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/src/sch/Sch.jl b/src/sch/Sch.jl index a6b2ce6b4..cfb680bbd 100644 --- a/src/sch/Sch.jl +++ b/src/sch/Sch.jl @@ -405,7 +405,20 @@ function cleanup_proc(state, p, log_sink) delete!(WORKER_MONITOR_CHANS[wid], state.uid) end end - remote_do(_cleanup_proc, wid, state.uid, log_sink) + + # If the worker process is still alive, clean it up + if wid in workers() + try + remotecall_wait(_cleanup_proc, wid, state.uid, log_sink) + catch ex + # We allow ProcessExitedException's, which means that the worker + # shutdown halfway through cleanup. + if !(ex isa ProcessExitedException) + rethrow() + end + end + end + timespan_finish(ctx, :cleanup_proc, (;worker=wid), nothing) end From 4d123ac8935eb18f66fbb6183979a0c946c7bd12 Mon Sep 17 00:00:00 2001 From: JamesWrigley Date: Tue, 9 Apr 2024 13:06:52 +0200 Subject: [PATCH 15/34] Allow for dead workers in safepoint() --- src/sch/dynamic.jl | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/sch/dynamic.jl b/src/sch/dynamic.jl index 7c52bf748..df78a1dd1 100644 --- a/src/sch/dynamic.jl +++ b/src/sch/dynamic.jl @@ -33,9 +33,18 @@ function safepoint(state) if state.halt.set # Force dynamic thunks and listeners to terminate for (inp_chan,out_chan) in values(state.worker_chans) - close(inp_chan) - close(out_chan) + # Closing these channels will fail if the worker died, which we + # allow. + try + close(inp_chan) + close(out_chan) + catch ex + if !(ex isa ProcessExitedException) + rethrow() + end + end end + # Throw out of scheduler throw(SchedulerHaltedException()) end From dbfe428a7e86502cbbffad463c500705c5472659 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Tue, 18 Jun 2024 10:43:46 -0500 Subject: [PATCH 16/34] parser: Fix expression escaping --- src/thunk.jl | 12 +++++------- test/thunk.jl | 4 ++++ 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/thunk.jl b/src/thunk.jl index 177239979..4d94d9b96 100644 --- a/src/thunk.jl +++ b/src/thunk.jl @@ -306,7 +306,7 @@ generated thunks. macro par(exs...) opts = exs[1:end-1] ex = exs[end] - _par(ex; lazy=true, opts=opts) + return esc(_par(ex; lazy=true, opts=opts)) end """ @@ -348,7 +348,7 @@ also passes along any options in an `Options` struct. For example, macro spawn(exs...) opts = exs[1:end-1] ex = exs[end] - _par(ex; lazy=false, opts=opts) + return esc(_par(ex; lazy=false, opts=opts)) end struct ExpandedBroadcast{F} end @@ -372,17 +372,16 @@ function _par(ex::Expr; lazy=true, recur=true, opts=()) args = ex.args[2:end] kwargs = Expr(:parameters) end - opts = esc.(opts) args_ex = _par.(args; lazy=lazy, recur=false) kwargs_ex = _par.(kwargs.args; lazy=lazy, recur=false) if lazy - return :(Dagger.delayed($(esc(f)), $Options(;$(opts...)))($(args_ex...); $(kwargs_ex...))) + return :(Dagger.delayed($f, $Options(;$(opts...)))($(args_ex...); $(kwargs_ex...))) else - sync_var = esc(Base.sync_varname) + sync_var = Base.sync_varname @gensym result return quote let args = ($(args_ex...),) - $result = $spawn($(esc(f)), $Options(;$(opts...)), args...; $(kwargs_ex...)) + $result = $spawn($f, $Options(;$(opts...)), args...; $(kwargs_ex...)) if $(Expr(:islocal, sync_var)) put!($sync_var, schedule(Task(()->wait($result)))) end @@ -394,7 +393,6 @@ function _par(ex::Expr; lazy=true, recur=true, opts=()) return Expr(ex.head, _par.(ex.args, lazy=lazy, recur=recur, opts=opts)...) end end -_par(ex::Symbol; kwargs...) = esc(ex) _par(ex; kwargs...) = ex """ diff --git a/test/thunk.jl b/test/thunk.jl index db92ee340..42ed7125d 100644 --- a/test/thunk.jl +++ b/test/thunk.jl @@ -79,6 +79,10 @@ end @test fetch(@spawn A .+ B) ≈ A .+ B @test fetch(@spawn A .* B) ≈ A .* B end + @testset "inner macro" begin + A = rand(4) + @test fetch(@spawn sum(@view A[2:3])) ≈ sum(@view A[2:3]) + end @testset "waiting" begin a = @spawn sleep(1) @test !isready(a) From 8d29bd8444e1fbbeb1a559c9199ad43b0c226aec Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Tue, 18 Jun 2024 10:44:30 -0500 Subject: [PATCH 17/34] tests: Instantiate before loading packages --- test/runtests.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/runtests.jl b/test/runtests.jl index 3f4e1b1ca..98e61967d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -31,6 +31,7 @@ if PROGRAM_FILE != "" && realpath(PROGRAM_FILE) == @__FILE__ pushfirst!(LOAD_PATH, joinpath(@__DIR__, "..")) using Pkg Pkg.activate(@__DIR__) + Pkg.instantiate() using ArgParse s = ArgParseSettings(description = "Dagger Testsuite") From 515e731c5d5d7da27a0d37dd9e68a78f361f9749 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Tue, 18 Jun 2024 12:06:10 -0500 Subject: [PATCH 18/34] parser: Support do-blocks --- src/thunk.jl | 40 +++++++++++++++++++++++++--------------- test/thunk.jl | 34 +++++++++++++++++++++++++++++++++- 2 files changed, 58 insertions(+), 16 deletions(-) diff --git a/src/thunk.jl b/src/thunk.jl index 4d94d9b96..62fa31dec 100644 --- a/src/thunk.jl +++ b/src/thunk.jl @@ -363,25 +363,29 @@ function replace_broadcast(fn::Symbol) end function _par(ex::Expr; lazy=true, recur=true, opts=()) - if ex.head == :call && recur - f = replace_broadcast(ex.args[1]) - if length(ex.args) >= 2 && Meta.isexpr(ex.args[2], :parameters) - args = ex.args[3:end] - kwargs = ex.args[2] - else - args = ex.args[2:end] - kwargs = Expr(:parameters) + body = nothing + if recur && @capture(ex, f_(allargs__)) || @capture(ex, f_(allargs__) do cargs_ body_ end) + f = replace_broadcast(f) + args = filter(arg->!Meta.isexpr(arg, :parameters), allargs) + kwargs = filter(arg->Meta.isexpr(arg, :parameters), allargs) + if !isempty(kwargs) + kwargs = only(kwargs).args + end + if body !== nothing + f = quote + ($(args...); $(kwargs...))->$f($(args...); $(kwargs...)) do $cargs + $body + end + end end - args_ex = _par.(args; lazy=lazy, recur=false) - kwargs_ex = _par.(kwargs.args; lazy=lazy, recur=false) if lazy - return :(Dagger.delayed($f, $Options(;$(opts...)))($(args_ex...); $(kwargs_ex...))) + return :(Dagger.delayed($f, $Options(;$(opts...)))($(args...); $(kwargs...))) else sync_var = Base.sync_varname @gensym result return quote - let args = ($(args_ex...),) - $result = $spawn($f, $Options(;$(opts...)), args...; $(kwargs_ex...)) + let + $result = $spawn($f, $Options(;$(opts...)), $(args...); $(kwargs...)) if $(Expr(:islocal, sync_var)) put!($sync_var, schedule(Task(()->wait($result)))) end @@ -389,11 +393,17 @@ function _par(ex::Expr; lazy=true, recur=true, opts=()) end end end + elseif lazy + # Recurse into the expression + return Expr(ex.head, _par_inner.(ex.args, lazy=lazy, recur=recur, opts=opts)...) else - return Expr(ex.head, _par.(ex.args, lazy=lazy, recur=recur, opts=opts)...) + throw(ArgumentError("Invalid Dagger task expression: $ex")) end end -_par(ex; kwargs...) = ex +_par(ex; kwargs...) = throw(ArgumentError("Invalid Dagger task expression: $ex")) + +_par_inner(ex; kwargs...) = ex +_par_inner(ex::Expr; kwargs...) = _par(ex; kwargs...) """ Dagger.spawn(f, args...; kwargs...) -> DTask diff --git a/test/thunk.jl b/test/thunk.jl index 42ed7125d..66ebc01f8 100644 --- a/test/thunk.jl +++ b/test/thunk.jl @@ -81,7 +81,39 @@ end end @testset "inner macro" begin A = rand(4) - @test fetch(@spawn sum(@view A[2:3])) ≈ sum(@view A[2:3]) + t = @spawn sum(@view A[2:3]) + @test t isa Dagger.DTask + @test fetch(t) ≈ sum(@view A[2:3]) + end + @testset "do block" begin + A = rand(4) + + t = @spawn sum(A) do a + a + 1 + end + @test t isa Dagger.DTask + @test fetch(t) ≈ sum(a->a+1, A) + + t = @spawn sum(A; dims=1) do a + a + 1 + end + @test t isa Dagger.DTask + @test fetch(t) ≈ sum(a->a+1, A; dims=1) + + do_f = f -> f(42) + t = @spawn do_f() do x + x + 1 + end + @test t isa Dagger.DTask + @test fetch(t) == 43 + end + @testset "invalid expression" begin + @test_throws LoadError eval(:(@spawn 1)) + @test_throws LoadError eval(:(@spawn begin 1 end)) + @test_throws LoadError eval(:(@spawn begin + 1+1 + 1+1 + end)) end @testset "waiting" begin a = @spawn sleep(1) From 705266a6fa08d1fdde7c516138c64d9418406f31 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Tue, 18 Jun 2024 12:51:24 -0500 Subject: [PATCH 19/34] parser: Support direct anonymous function calls --- src/thunk.jl | 17 +++++++++++++---- test/thunk.jl | 11 +++++++++++ 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/src/thunk.jl b/src/thunk.jl index 62fa31dec..4f69bf1f8 100644 --- a/src/thunk.jl +++ b/src/thunk.jl @@ -363,8 +363,9 @@ function replace_broadcast(fn::Symbol) end function _par(ex::Expr; lazy=true, recur=true, opts=()) + f = nothing body = nothing - if recur && @capture(ex, f_(allargs__)) || @capture(ex, f_(allargs__) do cargs_ body_ end) + if recur && @capture(ex, f_(allargs__)) || @capture(ex, f_(allargs__) do cargs_ body_ end) || @capture(ex, allargs__->body_) f = replace_broadcast(f) args = filter(arg->!Meta.isexpr(arg, :parameters), allargs) kwargs = filter(arg->Meta.isexpr(arg, :parameters), allargs) @@ -372,9 +373,17 @@ function _par(ex::Expr; lazy=true, recur=true, opts=()) kwargs = only(kwargs).args end if body !== nothing - f = quote - ($(args...); $(kwargs...))->$f($(args...); $(kwargs...)) do $cargs - $body + if f !== nothing + f = quote + ($(args...); $(kwargs...))->$f($(args...); $(kwargs...)) do $cargs + $body + end + end + else + f = quote + ($(args...); $(kwargs...))->begin + $body + end end end end diff --git a/test/thunk.jl b/test/thunk.jl index 66ebc01f8..7d18a292a 100644 --- a/test/thunk.jl +++ b/test/thunk.jl @@ -107,6 +107,17 @@ end @test t isa Dagger.DTask @test fetch(t) == 43 end + @testset "anonymous direct call" begin + A = rand(4) + + t = @spawn A->sum(A) + @test t isa Dagger.DTask + @test fetch(t) == sum(A) + + t = @spawn A->sum(A; dims=1) + @test t isa Dagger.DTask + @test fetch(t) == sum(A; dims=1) + end @testset "invalid expression" begin @test_throws LoadError eval(:(@spawn 1)) @test_throws LoadError eval(:(@spawn begin 1 end)) From 86f2f5a1e4f468b925a01b05940a148628d879c9 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Tue, 18 Jun 2024 15:48:06 -0500 Subject: [PATCH 20/34] parser: Support getindex --- src/thunk.jl | 8 +++++++- test/thunk.jl | 17 +++++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/src/thunk.jl b/src/thunk.jl index 4f69bf1f8..017a69a9d 100644 --- a/src/thunk.jl +++ b/src/thunk.jl @@ -365,8 +365,14 @@ end function _par(ex::Expr; lazy=true, recur=true, opts=()) f = nothing body = nothing - if recur && @capture(ex, f_(allargs__)) || @capture(ex, f_(allargs__) do cargs_ body_ end) || @capture(ex, allargs__->body_) + arg1 = nothing + if recur && @capture(ex, f_(allargs__)) || @capture(ex, f_(allargs__) do cargs_ body_ end) || @capture(ex, allargs__->body_) || @capture(ex, arg1_[allargs__]) f = replace_broadcast(f) + if arg1 !== nothing + # Indexing (A[2,3]) + f = Base.getindex + pushfirst!(allargs, arg1) + end args = filter(arg->!Meta.isexpr(arg, :parameters), allargs) kwargs = filter(arg->Meta.isexpr(arg, :parameters), allargs) if !isempty(kwargs) diff --git a/test/thunk.jl b/test/thunk.jl index 7d18a292a..95763e8b9 100644 --- a/test/thunk.jl +++ b/test/thunk.jl @@ -118,6 +118,23 @@ end @test t isa Dagger.DTask @test fetch(t) == sum(A; dims=1) end + @testset "getindex" begin + A = rand(4, 4) + + t = @spawn A[1, 2] + @test t isa Dagger.DTask + @test fetch(t) == A[1, 2] + + B = Dagger.@spawn rand(4, 4) + t = @spawn B[1, 2] + @test t isa Dagger.DTask + @test fetch(t) == fetch(B)[1, 2] + + R = Ref(42) + t = @spawn R[] + @test t isa Dagger.DTask + @test fetch(t) == 42 + end @testset "invalid expression" begin @test_throws LoadError eval(:(@spawn 1)) @test_throws LoadError eval(:(@spawn begin 1 end)) From b68ad4aa76742fed7cf8c5f15e95031ecba6d637 Mon Sep 17 00:00:00 2001 From: rabab53 Date: Thu, 20 Jun 2024 15:29:53 -0500 Subject: [PATCH 21/34] aliasing: Add optimized will_alias for views Co-authored-by: Julian P Samaroo --- src/memory-spaces.jl | 24 +++++++++++--------- test/datadeps.jl | 53 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 67 insertions(+), 10 deletions(-) diff --git a/src/memory-spaces.jl b/src/memory-spaces.jl index 1a734850b..e3e90a124 100644 --- a/src/memory-spaces.jl +++ b/src/memory-spaces.jl @@ -136,7 +136,9 @@ aliasing(x::Transpose) = aliasing(parent(x)) aliasing(x::Adjoint) = aliasing(parent(x)) struct StridedAliasing{T,N,S} <: AbstractAliasing + base_ptr::RemotePtr{Cvoid,S} ptr::RemotePtr{Cvoid,S} + base_inds::NTuple{N,UnitRange{Int}} lengths::NTuple{N,Int} strides::NTuple{N,Int} end @@ -161,10 +163,12 @@ function _memory_spans(a::StridedAliasing{T,N,S}, spans, ptr, dim) where {T,N,S} return spans end -function aliasing(x::SubArray{T}) where T +function aliasing(x::SubArray{T,N,A}) where {T,N,A<:Array} if isbitstype(T) S = CPURAMMemorySpace - return StridedAliasing{T,ndims(x),S}(RemotePtr{Cvoid}(pointer(x)), + return StridedAliasing{T,ndims(x),S}(RemotePtr{Cvoid}(pointer(parent(x))), + RemotePtr{Cvoid}(pointer(x)), + parentindices(x), size(x), strides(parent(x))) else # FIXME: Also ContiguousAliasing of container @@ -172,21 +176,21 @@ function aliasing(x::SubArray{T}) where T return UnknownAliasing() end end -#= TODO: Fix and enable strided aliasing optimization function will_alias(x::StridedAliasing{T,N,S}, y::StridedAliasing{T,N,S}) where {T,N,S} - # TODO: Upgrade Contiguous/StridedAlising to same number of dims + if x.base_ptr != y.base_ptr + # FIXME: Conservatively incorrect via `unsafe_wrap` and friends + return false + end + for dim in 1:N - # FIXME: Adjust ptrs to common base - x_span = MemorySpan{S}(x.ptr, sizeof(T)*x.strides[dim]) - y_span = MemorySpan{S}(y.ptr, sizeof(T)*y.strides[dim]) - @show dim x_span y_span - if !will_alias(x_span, y_span) + if ((x.base_inds[dim].stop) < (y.base_inds[dim].start) || (y.base_inds[dim].stop) < (x.base_inds[dim].start)) return false end end + return true end -=# +# FIXME: Upgrade Contiguous/StridedAlising to same number of dims struct TriangularAliasing{T,S} <: AbstractAliasing ptr::RemotePtr{Cvoid,S} diff --git a/test/datadeps.jl b/test/datadeps.jl index e4b4c811f..271e2c667 100644 --- a/test/datadeps.jl +++ b/test/datadeps.jl @@ -337,6 +337,59 @@ function test_datadeps(;args_chunks::Bool, test_task_dominators(logs, tid_lower2, [tid_B, tid_lower, tid_unitlower, tid_diag, tid_unitlower2]; all_tids=tids_all, nondom_check=false) test_task_dominators(logs, tid_unitupper2, [tid_B, tid_upper, tid_unitupper]; all_tids=tids_all, nondom_check=false) test_task_dominators(logs, tid_upper2, [tid_B, tid_upper, tid_unitupper, tid_diag, tid_unitupper2]; all_tids=tids_all, nondom_check=false) + + # Additional aliasing tests + views_overlap(x, y) = Dagger.will_alias(Dagger.aliasing(x), Dagger.aliasing(y)) + + A = wrap_chunk_thunk(identity, B) + + A_r1 = wrap_chunk_thunk(view, A, 1:1, 1:4) + A_r2 = wrap_chunk_thunk(view, A, 2:2, 1:4) + B_r1 = wrap_chunk_thunk(view, B, 1:1, 1:4) + B_r2 = wrap_chunk_thunk(view, B, 2:2, 1:4) + + A_c1 = wrap_chunk_thunk(view, A, 1:4, 1:1) + A_c2 = wrap_chunk_thunk(view, A, 1:4, 2:2) + B_c1 = wrap_chunk_thunk(view, B, 1:4, 1:1) + B_c2 = wrap_chunk_thunk(view, B, 1:4, 2:2) + + A_mid = wrap_chunk_thunk(view, A, 2:3, 2:3) + B_mid = wrap_chunk_thunk(view, B, 2:3, 2:3) + + @test views_overlap(A_r1, A_r1) + @test views_overlap(B_r1, B_r1) + @test views_overlap(A_c1, A_c1) + @test views_overlap(B_c1, B_c1) + + @test views_overlap(A_r1, B_r1) + @test views_overlap(A_r2, B_r2) + @test views_overlap(A_c1, B_c1) + @test views_overlap(A_c2, B_c2) + + @test !views_overlap(A_r1, A_r2) + @test !views_overlap(B_r1, B_r2) + @test !views_overlap(A_c1, A_c2) + @test !views_overlap(B_c1, B_c2) + + @test views_overlap(A_r1, A_c1) + @test views_overlap(A_r1, B_c1) + @test views_overlap(A_r2, A_c2) + @test views_overlap(A_r2, B_c2) + + for (name, mid) in ((:A_mid, A_mid), (:B_mid, B_mid)) + @test !views_overlap(A_r1, mid) + @test !views_overlap(B_r1, mid) + @test !views_overlap(A_c1, mid) + @test !views_overlap(B_c1, mid) + + @test views_overlap(A_r2, mid) + @test views_overlap(B_r2, mid) + @test views_overlap(A_c2, mid) + @test views_overlap(B_c2, mid) + end + + @test views_overlap(A_mid, A_mid) + @test views_overlap(A_mid, B_mid) end # FIXME: Deps From 9443ffbec1f0eba4cdce08dc74bf30fa98f9a4d3 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Thu, 30 May 2024 13:29:22 -0700 Subject: [PATCH 22/34] DArray: Make allocations dispatchable --- src/array/alloc.jl | 29 +++++++++++++++++++++-------- src/array/darray.jl | 10 +++------- 2 files changed, 24 insertions(+), 15 deletions(-) diff --git a/src/array/alloc.jl b/src/array/alloc.jl index 5b881bf0d..4a933905b 100644 --- a/src/array/alloc.jl +++ b/src/array/alloc.jl @@ -4,7 +4,8 @@ export partition mutable struct AllocateArray{T,N} <: ArrayOp{T,N} eltype::Type{T} - f::Function + f + want_index::Bool domain::ArrayDomain{N} domainchunks partitioning::AbstractBlocks @@ -23,9 +24,21 @@ function partition(p::AbstractBlocks, dom::ArrayDomain) map(_cumlength, map(length, indexes(dom)), p.blocksize)) end +function allocate_array(f, T, idx, sz) + new_f = allocate_array_func(thunk_processor(), f) + return new_f(idx, T, sz) +end +function allocate_array(f, T, sz) + new_f = allocate_array_func(thunk_processor(), f) + return new_f(T, sz) +end +allocate_array_func(::Processor, f) = f function stage(ctx, a::AllocateArray) - alloc(idx, sz) = a.f(idx, a.eltype, sz) - thunks = [Dagger.@spawn alloc(i, size(x)) for (i, x) in enumerate(a.domainchunks)] + if a.want_index + thunks = [Dagger.@spawn allocate_array(a.f, a.eltype, i, size(x)) for (i, x) in enumerate(a.domainchunks)] + else + thunks = [Dagger.@spawn allocate_array(a.f, a.eltype, size(x)) for (i, x) in enumerate(a.domainchunks)] + end return DArray(a.eltype, a.domain, a.domainchunks, thunks, a.partitioning) end @@ -33,7 +46,7 @@ const BlocksOrAuto = Union{Blocks{N} where N, AutoBlocks} function Base.rand(p::Blocks, eltype::Type, dims::Dims) d = ArrayDomain(map(x->1:x, dims)) - a = AllocateArray(eltype, (_, x...) -> rand(x...), d, partition(p, d), p) + a = AllocateArray(eltype, rand, false, d, partition(p, d), p) return _to_darray(a) end Base.rand(p::BlocksOrAuto, T::Type, dims::Integer...) = rand(p, T, dims) @@ -45,7 +58,7 @@ Base.rand(::AutoBlocks, eltype::Type, dims::Dims) = function Base.randn(p::Blocks, eltype::Type, dims::Dims) d = ArrayDomain(map(x->1:x, dims)) - a = AllocateArray(eltype, (_, x...) -> randn(x...), d, partition(p, d), p) + a = AllocateArray(eltype, randn, false, d, partition(p, d), p) return _to_darray(a) end Base.randn(p::BlocksOrAuto, T::Type, dims::Integer...) = randn(p, T, dims) @@ -57,7 +70,7 @@ Base.randn(::AutoBlocks, eltype::Type, dims::Dims) = function sprand(p::Blocks, eltype::Type, dims::Dims, sparsity::AbstractFloat) d = ArrayDomain(map(x->1:x, dims)) - a = AllocateArray(eltype, (_, T, _dims) -> sprand(T, _dims..., sparsity), d, partition(p, d), p) + a = AllocateArray(eltype, (T, _dims) -> sprand(T, _dims..., sparsity), false, d, partition(p, d), p) return _to_darray(a) end sprand(p::BlocksOrAuto, T::Type, dims_and_sparsity::Real...) = @@ -73,7 +86,7 @@ sprand(::AutoBlocks, eltype::Type, dims::Dims, sparsity::AbstractFloat) = function Base.ones(p::Blocks, eltype::Type, dims::Dims) d = ArrayDomain(map(x->1:x, dims)) - a = AllocateArray(eltype, (_, x...) -> ones(x...), d, partition(p, d), p) + a = AllocateArray(eltype, ones, false, d, partition(p, d), p) return _to_darray(a) end Base.ones(p::BlocksOrAuto, T::Type, dims::Integer...) = ones(p, T, dims) @@ -85,7 +98,7 @@ Base.ones(::AutoBlocks, eltype::Type, dims::Dims) = function Base.zeros(p::Blocks, eltype::Type, dims::Dims) d = ArrayDomain(map(x->1:x, dims)) - a = AllocateArray(eltype, (_, x...) -> zeros(x...), d, partition(p, d), p) + a = AllocateArray(eltype, zeros, false, d, partition(p, d), p) return _to_darray(a) end Base.zeros(p::BlocksOrAuto, T::Type, dims::Integer...) = zeros(p, T, dims) diff --git a/src/array/darray.jl b/src/array/darray.jl index d03065063..9e0bc4636 100644 --- a/src/array/darray.jl +++ b/src/array/darray.jl @@ -306,16 +306,12 @@ function Base.isequal(x::ArrayOp, y::ArrayOp) x === y end -function Base.similar(x::DArray{T,N}) where {T,N} - alloc(idx, sz) = Array{T,N}(undef, sz) - thunks = [Dagger.@spawn alloc(i, size(x)) for (i, x) in enumerate(x.subdomains)] - return DArray(T, x.domain, x.subdomains, thunks, x.partitioning, x.concat) -end - +struct AllocateUndef{S} end +(::AllocateUndef{S})(T, dims::Dims{N}) where {S,N} = Array{S,N}(undef, dims) function Base.similar(A::DArray{T,N} where T, ::Type{S}, dims::Dims{N}) where {S,N} d = ArrayDomain(map(x->1:x, dims)) p = A.partitioning - a = AllocateArray(S, (_, _, x...) -> Array{S,N}(undef, x...), d, partition(p, d), p) + a = AllocateArray(S, AllocateUndef{S}(), false, d, partition(p, d), p) return _to_darray(a) end From c82220677f56f37f06ac85194fc2ac2fa05ccb5e Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Thu, 30 May 2024 13:29:35 -0700 Subject: [PATCH 23/34] DArray: Small matmul bugfix --- src/array/mul.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/array/mul.jl b/src/array/mul.jl index 4cdc7a525..44e5ab948 100644 --- a/src/array/mul.jl +++ b/src/array/mul.jl @@ -109,8 +109,8 @@ function gemm_dagger!( Bmt, Bnt = size(Bc) Cmt, Cnt = size(Cc) - alpha = _add.alpha - beta = _add.beta + alpha = T(_add.alpha) + beta = T(_add.beta) if Ant != Bmt throw(DimensionMismatch(lazy"A has number of blocks ($Amt,$Ant) but B has number of blocks ($Bmt,$Bnt)")) @@ -212,8 +212,8 @@ function syrk_dagger!( Amt, Ant = size(Ac) Cmt, Cnt = size(Cc) - alpha = _add.alpha - beta = _add.beta + alpha = T(_add.alpha) + beta = T(_add.beta) uplo = 'U' if Ant != Cmt @@ -233,7 +233,7 @@ function syrk_dagger!( Dagger.@spawn BLAS.herk!( uplo, trans, - alpha, + real(alpha), In(Ac[n, k]), mzone, InOut(Cc[n, n]), From c53fd0d41dcfb75ec4b012a8559a8382b5dd0340 Mon Sep 17 00:00:00 2001 From: fda-tome Date: Thu, 6 Jun 2024 09:00:24 -0300 Subject: [PATCH 24/34] Rebasing commit, solving conflicts --- src/Dagger.jl | 7 +- src/array/trapezoidal.jl | 142 +++++++++++++++++++++++++++++++++++++++ src/array/triangular.jl | 83 +++++++++++++++++++++++ 3 files changed, 231 insertions(+), 1 deletion(-) create mode 100644 src/array/trapezoidal.jl create mode 100644 src/array/triangular.jl diff --git a/src/Dagger.jl b/src/Dagger.jl index 97f7dd44d..7211ad277 100644 --- a/src/Dagger.jl +++ b/src/Dagger.jl @@ -7,13 +7,14 @@ import SparseArrays: sprand, SparseMatrixCSC import MemPool import MemPool: DRef, FileRef, poolget, poolset -import Base: collect, reduce +import Base: collect, reduce, require_one_based_indexing import Distributed import Distributed: Future, RemoteChannel, myid, workers, nworkers, procs, remotecall, remotecall_wait, remotecall_fetch import LinearAlgebra import LinearAlgebra: Adjoint, BLAS, Diagonal, Bidiagonal, Tridiagonal, LAPACK, LowerTriangular, PosDefException, Transpose, UpperTriangular, UnitLowerTriangular, UnitUpperTriangular, diagind, ishermitian, issymmetric + import UUIDs: UUID, uuid4 if !isdefined(Base, :ScopedValues) @@ -77,8 +78,12 @@ include("array/setindex.jl") include("array/matrix.jl") include("array/sparse_partition.jl") include("array/sort.jl") + +# Linear algebra include("array/linalg.jl") include("array/mul.jl") +include("array/trapezoidal.jl") +include("array/triangular.jl") include("array/cholesky.jl") # Visualization diff --git a/src/array/trapezoidal.jl b/src/array/trapezoidal.jl new file mode 100644 index 000000000..3497fa26a --- /dev/null +++ b/src/array/trapezoidal.jl @@ -0,0 +1,142 @@ +export LowerTrapezoidal, UnitLowerTrapezoidal, UpperTrapezoidal, UnitUpperTrapezoidal, trau!, trau, tral!, tral +import LinearAlgebra: triu!, tril!, triu, tril +abstract type AbstractTrapezoidal{T} <: AbstractMatrix{T} end + +# First loop through all methods that don't need special care for upper/lower and unit diagonal +for t in (:LowerTrapezoidal, :UnitLowerTrapezoidal, :UpperTrapezoidal, :UnitUpperTrapezoidal) + @eval begin + struct $t{T,S<:AbstractMatrix{T}} <: AbstractTrapezoidal{T} + data::S + + function $t{T,S}(data) where {T,S<:AbstractMatrix{T}} + Base.require_one_based_indexing(data) + new{T,S}(data) + end + end + $t(A::$t) = A + $t{T}(A::$t{T}) where {T} = A + $t(A::AbstractMatrix) = $t{eltype(A), typeof(A)}(A) + $t{T}(A::AbstractMatrix) where {T} = $t(convert(AbstractMatrix{T}, A)) + $t{T}(A::$t) where {T} = $t(convert(AbstractMatrix{T}, A.data)) + + AbstractMatrix{T}(A::$t) where {T} = $t{T}(A) + AbstractMatrix{T}(A::$t{T}) where {T} = copy(A) + + Base.size(A::$t) = size(A.data) + Base.axes(A::$t) = axes(A.data) + + Base.similar(A::$t, ::Type{T}) where {T} = $t(similar(parent(A), T)) + Base.similar(A::$t, ::Type{T}, dims::Dims{N}) where {T,N} = similar(parent(A), T, dims) + Base.parent(A::$t) = A.data + + Base.copy(A::$t) = $t(copy(A.data)) + + Base.real(A::$t{<:Real}) = A + Base.real(A::$t{<:Complex}) = (B = real(A.data); $t(B)) + end +end + +Base.getindex(A::UnitLowerTrapezoidal{T}, i::Integer, j::Integer) where {T} = + i > j ? ifelse(A.data[i,j] == nothing, zero(eltype(A.data)), A.data[i,j]) : ifelse(i == j, oneunit(T), zero(T)) +Base.getindex(A::LowerTrapezoidal, i::Integer, j::Integer) = +i >= j ? ifelse(A.data[i,j] == nothing, zero(eltype(A.data)), A.data[i,j]) : zero(eltype(A.data)) +Base.getindex(A::UnitUpperTrapezoidal{T}, i::Integer, j::Integer) where {T} = + i < j ? ifelse(A.data[i,j] == nothing, zero(eltype(A.data)), A.data[i,j]) : ifelse(i == j, oneunit(T), zero(T)) +Base.getindex(A::UpperTrapezoidal, i::Integer, j::Integer) = +i <= j ? ifelse(A.data[i,j] == nothing, zero(eltype(A.data)), A.data[i,j]) : zero(eltype(A.data)) + +function _DiagBuild(blockdom::Tuple, alloc::AbstractMatrix{T}, diag::Vector{Tuple{Int,Int}}, transform::Function) where {T} + diagind = findfirst(x-> x[1] in blockdom[1] && x[2] in blockdom[2], diag) + blockind = (diag[diagind][1] - first(blockdom[1]) + 1, diag[diagind][2] - first(blockdom[2]) + 1) + return Dagger.@spawn transform(alloc, blockind[2] - blockind[1]) +end + +function _GenericTrapezoidal(f::Union{Function, UndefInitializer}, p::Blocks{2}, eltype::Type, wrap, k::Integer, dims) + d = ArrayDomain(map(x->1:x, dims)) + dc = partition(p, d) + if f isa UndefInitializer + f = (eltype, x...) -> Array{eltype}(undef, x...) + end + m, n = dims + if k < 0 + diag = [(i-k, i) for i in 1:min(m, n)] + else + diag = [(i, i+k) for i in 1:min(m, n)] + end + thunks = [] + alloc(sz) = f(eltype, sz) + transform = (wrap == LowerTrapezoidal) ? tril! : triu! + compar = (wrap == LowerTrapezoidal) ? (>) : (<) + for c in dc + sz = size(c) + if any(x -> x[1] in c.indexes[1] && x[2] in c.indexes[2], diag) + push!(thunks, _DiagBuild(c.indexes, alloc(sz), diag, transform)) + else + mt, nt = k<0 ? (first(c.indexes[1]), first(c.indexes[2])-k) : (first(c.indexes[1])+k, first(c.indexes[2])) + if compar(mt, nt) + push!(thunks, Dagger.@spawn alloc(sz)) + else + push!(thunks, Dagger.@spawn zeros(eltype, sz)) + end + end + end + thunks = reshape(thunks, size(dc)) + return wrap(Dagger.DArray(eltype, d, dc, thunks, p)) +end + +LowerTrapezoidal(f::Union{Function, UndefInitializer}, p::Blocks{2}, dims::Integer...) = _GenericTrapezoidal(f, p, eltype, LowerTrapezoidal, 0, dims) +LowerTrapezoidal(f::Union{Function, UndefInitializer}, p::Blocks{2}, dims::Tuple) = _GenericTrapezoidal(f, p, eltype, LowerTrapezoidal, 0, dims) +LowerTrapezoidal(f::Union{Function, UndefInitializer}, p::Blocks{2}, eltype::Type, dims::Integer...) = _GenericTrapezoidal(f, p, eltype, LowerTrapezoidal, 0, dims) +LowerTrapezoidal(f::Union{Function, UndefInitializer}, p::Blocks{2}, eltype::Type, dims::Tuple) = _GenericTrapezoidal(f, p, eltype, LowerTrapezoidal, 0, dims) +LowerTrapezoidal(f::Union{Function, UndefInitializer}, p::Blocks{2}, k::Integer, dims::Integer...) = _GenericTrapezoidal(f, p, eltype, LowerTrapezoidal, k, dims) +LowerTrapezoidal(f::Union{Function, UndefInitializer}, p::Blocks{2}, k::Integer, dims::Tuple) = _GenericTrapezoidal(f, p, eltype, LowerTrapezoidal, k, dims) +LowerTrapezoidal(f::Union{Function, UndefInitializer}, p::Blocks{2}, eltype::Type, k::Integer, dims::Integer...) = _GenericTrapezoidal(f, p, eltype, LowerTrapezoidal, k, dims) +LowerTrapezoidal(f::Union{Function, UndefInitializer}, p::Blocks{2}, eltype::Type, k::Integer, dims::Tuple) = _GenericTrapezoidal(f, p, eltype, LowerTrapezoidal, k, dims) + +UpperTrapezoidal(f::Union{Function, UndefInitializer}, p::Blocks{2}, dims::Integer...) = _GenericTrapezoidal(f, p, eltype, UpperTrapezoidal, 0, dims) +UpperTrapezoidal(f::Union{Function, UndefInitializer}, p::Blocks{2}, dims::Tuple) = _GenericTrapezoidal(f, p, eltype, UpperTrapezoidal, 0, dims) +UpperTrapezoidal(f::Union{Function, UndefInitializer}, p::Blocks{2}, eltype::Type, dims::Integer...) = _GenericTrapezoidal(f, p, eltype, UpperTrapezoidal, 0, dims) +UpperTrapezoidal(f::Union{Function, UndefInitializer}, p::Blocks{2}, eltype::Type, dims::Tuple) = _GenericTrapezoidal(f, p, eltype, UpperTrapezoidal, 0, dims) +UpperTrapezoidal(f::Union{Function, UndefInitializer}, p::Blocks{2}, k::Integer, dims::Integer...) = _GenericTrapezoidal(f, p, eltype, UpperTrapezoidal, k, dims) +UpperTrapezoidal(f::Union{Function, UndefInitializer}, p::Blocks{2}, k::Integer, dims::Tuple) = _GenericTrapezoidal(f, p, eltype, UpperTrapezoidal, k, dims) +UpperTrapezoidal(f::Union{Function, UndefInitializer}, p::Blocks{2}, eltype::Type, k::Integer, dims::Integer...) = _GenericTrapezoidal(f, p, eltype, UpperTrapezoidal, k, dims) +UpperTrapezoidal(f::Union{Function, UndefInitializer}, p::Blocks{2}, eltype::Type, k::Integer, dims::Tuple) = _GenericTrapezoidal(f, p, eltype, UpperTrapezoidal, k, dims) + +function _GenericTra!(A::Dagger.DArray{T, 2}, wrap::Function, k::Integer) where {T} + d = A.domain + dc = A.subdomains + Ac = A.chunks + m, n = size(A) + if k < 0 + diag = [(i-k, i) for i in 1:min(m, n)] + else + diag = [(i, i+k) for i in 1:min(m, n)] + end + compar = (wrap == tril!) ? (≤) : (≥) + for ind in CartesianIndices(dc) + sz = size(dc[ind]) + if any(x -> x[1] in dc[ind].indexes[1] && x[2] in dc[ind].indexes[2], diag) + Ac[ind] = _DiagBuild(dc[ind].indexes, fetch(Ac[ind]), diag, wrap) + else + mt, nt = k<0 ? (first(dc[ind].indexes[1]), first(dc[ind].indexes[2])-k) : (first(dc[ind].indexes[1])+k, first(dc[ind].indexes[2])) + if compar(mt, nt) + Ac[ind] = Dagger.@spawn zeros(T, sz) + end + end + end + return A +end + + + +trau!(A::Dagger.DArray{T,2}, k::Integer) where {T} = _GenericTra!(A, triu!, k) +trau!(A::Dagger.DArray{T,2}) where {T} = _GenericTra!(A, triu!, 0) +trau(A::Dagger.DArray{T,2}) where {T} = trau!(copy(A)) +trau(A::Dagger.DArray{T,2}, k::Integer) where {T} = trau!(copy(A), k) + +tral!(A::Dagger.DArray{T,2}, k::Integer) where {T} = _GenericTra!(A, tril!, k) +tral!(A::Dagger.DArray{T,2}) where {T} = _GenericTra!(A, tril!, 0) +tral(A::Dagger.DArray{T,2}) where {T} = tral!(copy(A)) +tral(A::Dagger.DArray{T,2}, k::Integer) where {T} = tral!(copy(A), k) + +#TODO: map, reduce, sum, mean, prod, reducedim, collect, distribute diff --git a/src/array/triangular.jl b/src/array/triangular.jl new file mode 100644 index 000000000..47613df2d --- /dev/null +++ b/src/array/triangular.jl @@ -0,0 +1,83 @@ +export LowerTriangular, UpperTriangular, tril, triu, tril!, triu! + +function _GenericTriangular(f::Union{Function, UndefInitializer}, p::Blocks{2}, eltype::Type, compar::Function, wrap, dims) + @assert dims[1] == dims[2] "matrix is not square: dimensions are $dims (try using trapezoidal)" + d = ArrayDomain(map(x->1:x, dims)) + dc = partition(p, d) + if f isa UndefInitializer + f = (eltype, x...) -> Array{eltype}(undef, x...) + end + diag = [(i,i) for i in 1:min(size(d)...)] + thunks = [] + alloc(sz) = f(eltype, sz) + transform = (wrap == LowerTrapezoidal) ? tril! : triu! + compar = (wrap == LowerTrapezoidal) ? (>) : (<) + for c in dc + sz = size(c) + if any(x -> x[1] in c.indexes[1] && x[2] in c.indexes[2], diag) + push!(thunks, _DiagBuild(c.indexes, alloc(sz), diag, transform)) + else + if compar(first(c.indexes[1]), first(c.indexes[2])) + push!(thunks, Dagger.@spawn alloc(sz)) + else + push!(thunks, Dagger.@spawn zeros(eltype, sz)) + end + end + end + thunks = reshape(thunks, size(dc)) + return wrap(Dagger.DArray(eltype, d, dc, thunks, p)) +end + +LinearAlgebra.LowerTriangular(f::Union{Function, UndefInitializer}, p::Blocks{2}, dims::Integer...) = _GenericTriangular(f, p, eltype, LowerTriangular, dims) +LinearAlgebra.LowerTriangular(f::Union{Function, UndefInitializer}, p::Blocks{2}, dims::Tuple) = _GenericTriangular(f, p, eltype, LowerTriangular, dims) +LinearAlgebra.LowerTriangular(f::Union{Function, UndefInitializer}, p::Blocks{2}, eltype::Type, dims::Integer...) = _GenericTriangular(f, p, eltype, LowerTriangular, dims) +LinearAlgebra.LowerTriangular(f::Union{Function, UndefInitializer}, p::Blocks{2}, eltype::Type, dims::Tuple) = _GenericTriangular(f, p, eltype, LowerTriangular, dims) + +LinearAlgebra.UpperTriangular(f::Union{Function, UndefInitializer}, p::Blocks{2}, dims::Integer...) = _GenericTriangular(f, p, eltype, UpperTriangular, dims) +LinearAlgebra.UpperTriangular(f::Union{Function, UndefInitializer}, p::Blocks{2}, dims::Tuple) = _GenericTriangular(f, p, eltype, UpperTriangular, dims) +LinearAlgebra.UpperTriangular(f::Union{Function, UndefInitializer}, p::Blocks{2}, eltype::Type, dims::Integer...) = _GenericTriangular(f, p, eltype, UpperTriangular, dims) +LinearAlgebra.UpperTriangular(f::Union{Function, UndefInitializer}, p::Blocks{2}, eltype::Type, dims::Tuple) = _GenericTriangular(f, p, eltype, UpperTriangular, dims) + +function _GenericTri!(A::Dagger.DArray{T, 2}, wrap) where {T} + LinearAlgebra.checksquare(A) + d = A.domain + dc = A.subdomains + Ac = A.chunks + diag = [(i,i) for i in 1:min(size(d)...)] + compar = (wrap == tril!) ? (>) : (<) + for ind in CartesianIndices(dc) + sz = size(dc[ind]) + if any(x -> x[1] in dc[ind].indexes[1] && x[2] in dc[ind].indexes[2], diag) + Ac[ind] = _DiagBuild(dc[ind].indexes, fetch(Ac[ind]), diag, wrap) + else + if compar(first(dc[ind].indexes[2]), first(dc[ind].indexes[1])) + Ac[ind] = Dagger.@spawn zeros(T, sz) + end + end + end + return A +end + +function LinearAlgebra.triu!(A::Dagger.DArray{T,2}) where {T} + if size(A, 1) != size(A, 2) + trau!(A) + else + _GenericTri!(A, triu!) + end +end +LinearAlgebra.triu(A::Dagger.DArray{T,2}) where {T} = triu!(copy(A)) + +function LinearAlgebra.tril!(A::Dagger.DArray{T,2}) where {T} + if size(A, 1) != size(A, 2) + tral!(A) + else + _GenericTri!(A, tril!) + end +end +LinearAlgebra.tril(A::Dagger.DArray{T,2}) where {T} = tril!(copy(A)) + +#TODO: matmul + + + + From 5d4a893cedf03fad11d0cdf520832af772c5c5bd Mon Sep 17 00:00:00 2001 From: fda-tome Date: Wed, 5 Jun 2024 17:04:19 -0300 Subject: [PATCH 25/34] DArray: UndefInitializer --- src/array/alloc.jl | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/array/alloc.jl b/src/array/alloc.jl index 4a933905b..a63a06d3c 100644 --- a/src/array/alloc.jl +++ b/src/array/alloc.jl @@ -44,6 +44,22 @@ end const BlocksOrAuto = Union{Blocks{N} where N, AutoBlocks} +function DArray{T}(::UndefInitializer, p::Blocks, dims) where {T} + d = ArrayDomain(map(x->1:x, dims)) + part = partition(p, d) + f = function (_, T, sz) + Array{T, length(sz)}(undef, sz...) + end + a = AllocateArray(T, f, d, part, p) + return _to_darray(a) +end + +DArray(::UndefInitializer, p::BlocksOrAuto, dims::Integer...) = DArray{Float64}(undef, p, dims) +DArray(::UndefInitializer, p::BlocksOrAuto, dims::Tuple) = DArray{Float64}(undef, p, dims) +DArray{T}(::UndefInitializer, p::BlocksOrAuto, dims::Integer...) where {T} = DArray{T}(undef, p, dims) +DArray{T}(::UndefInitializer, p::BlocksOrAuto, dims::Tuple) where {T} = DArray{T}(undef, p, dims) +DArray{T}(::UndefInitializer, p::AutoBlocks, dims::Tuple) where {T} = DArray{T}(undef, auto_blocks(dims), dims) + function Base.rand(p::Blocks, eltype::Type, dims::Dims) d = ArrayDomain(map(x->1:x, dims)) a = AllocateArray(eltype, rand, false, d, partition(p, d), p) From 53414146b2a43c10c228677da95518c7cc7a5ee9 Mon Sep 17 00:00:00 2001 From: fda-tome Date: Wed, 5 Jun 2024 17:18:03 -0300 Subject: [PATCH 26/34] DArray: slicing bug fix --- src/array/darray.jl | 22 ++++++++++++++++++---- src/array/indexing.jl | 19 ++++++++++++++++++- 2 files changed, 36 insertions(+), 5 deletions(-) diff --git a/src/array/darray.jl b/src/array/darray.jl index 9e0bc4636..b2178060f 100644 --- a/src/array/darray.jl +++ b/src/array/darray.jl @@ -65,6 +65,10 @@ ArrayDomain((1:15), (1:80)) alignfirst(a::ArrayDomain) = ArrayDomain(map(r->1:length(r), indexes(a))) +alignfirst(a::CartesianIndices{N}) where N = + ArrayDomain(map(r->1:length(r), a.indices)) + + function size(a::ArrayDomain, dim) idxs = indexes(a) length(idxs) < dim ? 1 : length(idxs[dim]) @@ -366,7 +370,7 @@ function group_indices(cumlength, idxs::AbstractRange) end _cumsum(x::AbstractArray) = length(x) == 0 ? Int[] : cumsum(x) -function lookup_parts(A::DArray, ps::AbstractArray, subdmns::DomainBlocks{N}, d::ArrayDomain{N}) where N +function lookup_parts(A::DArray, ps::AbstractArray, subdmns::DomainBlocks{N}, d::ArrayDomain{N}; slice::Bool=false) where N groups = map(group_indices, subdmns.cumlength, indexes(d)) sz = map(length, groups) pieces = Array{Any}(undef, sz) @@ -374,21 +378,31 @@ function lookup_parts(A::DArray, ps::AbstractArray, subdmns::DomainBlocks{N}, d: idx_and_dmn = map(getindex, groups, i.I) idx = map(x->x[1], idx_and_dmn) dmn = ArrayDomain(map(x->x[2], idx_and_dmn)) - pieces[i] = Dagger.@spawn getindex(ps[idx...], project(subdmns[idx...], dmn)) + if slice + pieces[i] = Dagger.@spawn getindex(ps[idx...], project(subdmns[idx...], dmn)) + else + pieces[i] = Dagger.@spawn view(ps[idx...], project(subdmns[idx...], dmn).indexes...) + end end out_cumlength = map(g->_cumsum(map(x->length(x[2]), g)), groups) out_dmn = DomainBlocks(ntuple(x->1,Val(N)), out_cumlength) return pieces, out_dmn end -function lookup_parts(A::DArray, ps::AbstractArray, subdmns::DomainBlocks{N}, d::ArrayDomain{S}) where {N,S} +function lookup_parts(A::DArray, ps::AbstractArray, subdmns::DomainBlocks{N}, d::ArrayDomain{S}; slice::Bool=false) where {N,S} if S != 1 throw(BoundsError(A, d.indexes)) end inds = CartesianIndices(A)[d.indexes...] new_d = ntuple(i->first(inds).I[i]:last(inds).I[i], N) - return lookup_parts(A, ps, subdmns, ArrayDomain(new_d)) + return lookup_parts(A, ps, subdmns, ArrayDomain(new_d); slice) end +function lookup_parts(A::DArray, ps::AbstractArray, subdmns::DomainBlocks{N}, d::CartesianIndices; slice::Bool=false) where N + return lookup_parts(A, ps, subdmns, ArrayDomain(d.indices); slice) +end + + + """ Base.fetch(c::DArray) diff --git a/src/array/indexing.jl b/src/array/indexing.jl index 82f44fbff..505e04f87 100644 --- a/src/array/indexing.jl +++ b/src/array/indexing.jl @@ -10,6 +10,16 @@ end GetIndex(input::ArrayOp, idx::Tuple) = GetIndex{eltype(input), ndims(input)}(input, idx) +function flatten(subdomains, subchunks, partitioning) + valdim = findfirst(j -> j != 1:1, subdomains[1].indexes) + flatc = [] + flats = Array{ArrayDomain{1, Tuple{UnitRange{Int64}}}}(undef, 0) + map(x -> push!(flats, ArrayDomain(x.indexes[valdim])), subdomains) + map(x -> push!(flatc, x), subchunks) + newb = Blocks(partitioning.blocksize[valdim]) + return flats, flatc, newb +end + function stage(ctx::Context, gidx::GetIndex) inp = stage(ctx, gidx.input) @@ -21,7 +31,14 @@ function stage(ctx::Context, gidx::GetIndex) end for i in 1:length(gidx.idx)] # Figure out output dimension - view(inp, ArrayDomain(idxs)) + d = ArrayDomain(idxs) + subchunks, subdomains = Dagger.lookup_parts(inp, chunks(inp), domainchunks(inp), d; slice = true) + d1 = alignfirst(d) + newb = inp.partitioning + if ndims(d1) != ndims(subdomains) + subdomains, subchunks, newb = flatten(subdomains, subchunks, inp.partitioning) + end + DArray(eltype(inp), d1, subdomains, subchunks, newb) end function size(x::GetIndex) From fb45ed484aa85e02dabaf845e0a1b3212683364f Mon Sep 17 00:00:00 2001 From: fda-tome Date: Thu, 6 Jun 2024 08:30:24 -0300 Subject: [PATCH 27/34] DArray: Tile QR Implementation --- src/Dagger.jl | 1 + src/array/coreblas/coreblas_gemm.jl | 21 +++ src/array/coreblas/coreblas_geqrt.jl | 32 ++++ src/array/coreblas/coreblas_ormqr.jl | 45 +++++ src/array/coreblas/coreblas_tsmqr.jl | 49 ++++++ src/array/coreblas/coreblas_tsqrt.jl | 35 ++++ src/array/coreblas/coreblas_ttmqr.jl | 51 ++++++ src/array/coreblas/coreblas_ttqrt.jl | 32 ++++ src/array/qr.jl | 241 +++++++++++++++++++++++++++ test/array/linalg/qr.jl | 36 ++++ test/runtests.jl | 1 + 11 files changed, 544 insertions(+) create mode 100644 src/array/coreblas/coreblas_gemm.jl create mode 100644 src/array/coreblas/coreblas_geqrt.jl create mode 100644 src/array/coreblas/coreblas_ormqr.jl create mode 100644 src/array/coreblas/coreblas_tsmqr.jl create mode 100644 src/array/coreblas/coreblas_tsqrt.jl create mode 100644 src/array/coreblas/coreblas_ttmqr.jl create mode 100644 src/array/coreblas/coreblas_ttqrt.jl create mode 100644 src/array/qr.jl create mode 100644 test/array/linalg/qr.jl diff --git a/src/Dagger.jl b/src/Dagger.jl index 7211ad277..3c1d92eee 100644 --- a/src/Dagger.jl +++ b/src/Dagger.jl @@ -85,6 +85,7 @@ include("array/mul.jl") include("array/trapezoidal.jl") include("array/triangular.jl") include("array/cholesky.jl") +include("array/qr.jl") # Visualization include("visualization.jl") diff --git a/src/array/coreblas/coreblas_gemm.jl b/src/array/coreblas/coreblas_gemm.jl new file mode 100644 index 000000000..ad77be40d --- /dev/null +++ b/src/array/coreblas/coreblas_gemm.jl @@ -0,0 +1,21 @@ +using libblastrampoline_jll +using LinearAlgebra +using libcoreblas_jll + +for (gemm, T) in + ((:coreblas_dgemm, Float64), + (:coreblas_sgemm, Float32), + (:coreblas_cgemm, ComplexF32), + (:coreblas_zgemm, ComplexF64)) + @eval begin + function coreblas_gemm!(transa::Int64, transb::Int64, + alpha::$T, A::AbstractMatrix{$T}, B::AbstractMatrix{$T}, beta::$T, C::AbstractMatrix{$T}) + m, k = size(A) + k, n = size(B) + ccall(($gemm, "libcoreblas.so"), Cvoid, + (Int64, Int64, Int64, Int64, Int64, $T, Ptr{$T}, Int64, Ptr{$T}, Int64, + $T, Ptr{$T}, Int64), + transa, transb, m, n, k, alpha, A, m, B, k, beta, C, m) + end + end +end diff --git a/src/array/coreblas/coreblas_geqrt.jl b/src/array/coreblas/coreblas_geqrt.jl new file mode 100644 index 000000000..99f86d442 --- /dev/null +++ b/src/array/coreblas/coreblas_geqrt.jl @@ -0,0 +1,32 @@ +for (geqrt, T) in + ((:coreblas_dgeqrt, Float64), + (:coreblas_sgeqrt, Float32), + (:coreblas_cgeqrt, ComplexF32), + (:coreblas_zgeqrt, ComplexF64)) + @eval begin + function coreblas_geqrt!(A::AbstractMatrix{$T}, + Tau::AbstractMatrix{$T}) + require_one_based_indexing(A, Tau) + chkstride1(A) + m, n = size(A) + ib, nb = size(Tau) + lda = max(1, stride(A,2)) + ldt = max(1, stride(Tau,2)) + work = Vector{$T}(undef, (ib)*n) + ttau = Vector{$T}(undef, n) + + err = ccall(($(QuoteNode(geqrt)), libcoreblas), Int64, + (Int64, Int64, Int64, + Ptr{$T}, Int64, Ptr{$T}, Int64, + Ptr{$T}, Ptr{$T}), + m, n, ib, + A, lda, + Tau, ldt, + ttau, work) + if err != 0 + throw(ArgumentError("coreblas_geqrt failed. Error number: $err")) + end + end + end +end + diff --git a/src/array/coreblas/coreblas_ormqr.jl b/src/array/coreblas/coreblas_ormqr.jl new file mode 100644 index 000000000..5bcb7b1bb --- /dev/null +++ b/src/array/coreblas/coreblas_ormqr.jl @@ -0,0 +1,45 @@ +for (geormqr, T) in + ((:coreblas_dormqr, Float64), + (:coreblas_sormqr, Float32), + (:coreblas_zunmqr, ComplexF64), + (:coreblas_cunmqr, ComplexF32)) + @eval begin + function coreblas_ormqr!(side::Char, trans::Char, A::AbstractMatrix{$T}, + Tau::AbstractMatrix{$T}, C::AbstractMatrix{$T}) + + m, n = size(C) + ib, nb = size(Tau) + k = nb + if $T <: Complex + transnum = trans == 'N' ? 111 : 113 + else + transnum = trans == 'N' ? 111 : 112 + end + sidenum = side == 'L' ? 141 : 142 + + lda = max(1, stride(A,2)) + ldt = max(1, stride(Tau,2)) + ldc = max(1, stride(C,2)) + ldwork = side == 'L' ? n : m + work = Vector{$T}(undef, ib*nb) + + + err = ccall(($(QuoteNode(geormqr)), libcoreblas), Int64, + (Int64, Int64, Int64, Int64, + Int64, Int64, + Ptr{$T}, Int64, Ptr{$T}, Int64, + Ptr{$T}, Int64, Ptr{$T}, Int64), + sidenum, transnum, + m, n, + k, ib, + A, lda, + Tau, ldt, + C, ldc, + work, ldwork) + if err != 0 + throw(ArgumentError("coreblas_ormqr failed. Error number: $err")) + end + end + end +end + diff --git a/src/array/coreblas/coreblas_tsmqr.jl b/src/array/coreblas/coreblas_tsmqr.jl new file mode 100644 index 000000000..1f647f7c1 --- /dev/null +++ b/src/array/coreblas/coreblas_tsmqr.jl @@ -0,0 +1,49 @@ +for (getsmqr, T) in + ((:coreblas_dtsmqr, Float64), + (:coreblas_ctsmqr, ComplexF32), + (:coreblas_ztsmqr, ComplexF64), + (:coreblas_stsmqr, Float32)) + @eval begin + function coreblas_tsmqr!(side::Char, trans::Char, A1::AbstractMatrix{$T}, + A2::AbstractMatrix{$T}, V::AbstractMatrix{$T}, Tau::AbstractMatrix{$T}) + m1, n1 = size(A1) + m2, n2 = size(A2) + ib, nb = size(Tau) + k = nb + + if $T <: Complex + transnum = trans == 'N' ? 111 : 113 + else + transnum = trans == 'N' ? 111 : 112 + end + + sidenum = side == 'L' ? 141 : 142 + + lda1 = max(1, stride(A1,2)) + lda2 = max(1, stride(A2,2)) + ldv = max(1, stride(V,2)) + ldt = max(1, stride(Tau,2)) + ldwork = side == 'L' ? ib : m1 + work = Vector{$T}(undef, ib*nb) + + + err = ccall(($(QuoteNode(getsmqr)), libcoreblas), Int64, + (Int64, Int64, Int64, Int64, + Int64, Int64, Int64, Int64, + Ptr{$T}, Int64, Ptr{$T}, Int64, + Ptr{$T}, Int64, Ptr{$T}, Int64, Ptr{$T}, Int64), + sidenum, transnum, + m1, n1, + m2, n2, + k, ib, + A1, lda1, + A2, lda2, + V, ldv, + Tau, ldt, + work, ldwork) + if err != 0 + throw(ArgumentError("coreblas_tsmqr failed. Error number: $err")) + end + end + end +end diff --git a/src/array/coreblas/coreblas_tsqrt.jl b/src/array/coreblas/coreblas_tsqrt.jl new file mode 100644 index 000000000..e644465a3 --- /dev/null +++ b/src/array/coreblas/coreblas_tsqrt.jl @@ -0,0 +1,35 @@ + +for (getsqrt,T) in + ((:coreblas_dtsqrt, Float64), + (:coreblas_stsqrt, Float32), + (:coreblas_ctsqrt, ComplexF32), + (:coreblas_ztsqrt, ComplexF64)) + @eval begin + function coreblas_tsqrt!(A1::AbstractMatrix{$T}, A2::AbstractMatrix{$T}, + Tau::AbstractMatrix{$T}) + m = size(A2)[1] + n = size(A1)[2] + ib, nb = size(Tau) + lda1 = max(1, stride(A1,2)) + lda2 = max(1, stride(A2,2)) + ldt = max(1, stride(Tau,2)) + work = Vector{$T}(undef, (ib)*n) + ttau = Vector{$T}(undef, n) + + err = ccall(($(QuoteNode(getsqrt)), libcoreblas), Int64, + (Int64, Int64, Int64, + Ptr{$T}, Int64, Ptr{$T}, Int64, + Ptr{$T}, Int64, Ptr{$T}, Ptr{$T}), + m, n, ib, + A1, lda1, + A2, lda2, + Tau, ldt, + ttau, work) + if err != 0 + throw(ArgumentError("coreblas_tsqrt failed. Error number: $err")) + end + end + end +end + + diff --git a/src/array/coreblas/coreblas_ttmqr.jl b/src/array/coreblas/coreblas_ttmqr.jl new file mode 100644 index 000000000..32c0f4eb3 --- /dev/null +++ b/src/array/coreblas/coreblas_ttmqr.jl @@ -0,0 +1,51 @@ +using libcoreblas_jll +for (gettmqr, T) in + ((:coreblas_dttmqr, Float64), + (:coreblas_sttmqr, Float32), + (:coreblas_cttmqr, ComplexF32), + (:coreblas_zttmqr, ComplexF64)) + @eval begin + function coreblas_ttmqr!(side::Char, trans::Char, A1::AbstractMatrix{$T}, + A2::AbstractMatrix{$T}, V::AbstractMatrix{$T}, Tau::AbstractMatrix{$T}) + m1, n1 = size(A1) + m2, n2 = size(A2) + ib, nb = size(Tau) + k=nb + if $T <: Complex + transnum = trans == 'N' ? 111 : 113 + else + transnum = trans == 'N' ? 111 : 112 + end + + sidenum = side == 'L' ? 141 : 142 + + ldv = max(1, stride(V,2)) + ldt = max(1, stride(Tau,2)) + lda1 = max(1, stride(A1,2)) + lda2 = max(1, stride(A2,2)) + ldwork = side == 'L' ? max(1,ib) : max(1,m1) + workdim = side == 'L' ? n1 : ib + work = Vector{$T}(undef, ldwork*workdim) + + err = ccall(($(QuoteNode(gettmqr)), libcoreblas), Int64, + (Int64, Int64, Int64, Int64, + Int64, Int64, Int64, Int64, + Ptr{$T}, Int64, Ptr{$T}, Int64, + Ptr{$T}, Int64, Ptr{$T}, Int64, + Ptr{$T}, Int64), + sidenum, transnum, + m1, n1, + m2, n2, + k, ib, + A1, lda1, + A2, lda2, + V, ldv, + Tau, ldt, + work, ldwork) + if err != 0 + throw(ArgumentError("coreblas_ttmqr, failed. Error number: $err")) + end + end + end +end + diff --git a/src/array/coreblas/coreblas_ttqrt.jl b/src/array/coreblas/coreblas_ttqrt.jl new file mode 100644 index 000000000..f90373517 --- /dev/null +++ b/src/array/coreblas/coreblas_ttqrt.jl @@ -0,0 +1,32 @@ + +for (gettqrt, T) in + ((:coreblas_dttqrt, Float64), + (:coreblas_sttqrt, Float32), + (:coreblas_cttqrt, ComplexF32), + (:coreblas_zttqrt, ComplexF64)) + @eval begin + function coreblas_ttqrt!(A1::AbstractMatrix{$T}, + A2::AbstractMatrix{$T}, triT::AbstractMatrix{$T}) + m1, n1 = size(A1) + m2, n2 = size(A2) + ib, nb = size(triT) + + lwork = nb + ib*nb + tau = Vector{$T}(undef, nb) + work = Vector{$T}(undef, (ib+1)*nb) + lda1 = max(1, stride(A1, 2)) + lda2 = max(1, stride(A2, 2)) + ldt = max(1, stride(triT, 2)) + + + err = ccall(($(QuoteNode(gettqrt)), libcoreblas), Int64, + (Int64, Int64, Int64, Ptr{$T}, Int64, Ptr{$T}, Int64, + Ptr{$T}, Int64, Ptr{$T}, Ptr{$T}), + m1, n1, ib, A1, lda1, A2, lda2, triT, ldt, tau, work) + + if err != 0 + throw(ArgumentError("coreblas_ttqrt failed. Error number: $err")) + end + end + end +end diff --git a/src/array/qr.jl b/src/array/qr.jl new file mode 100644 index 000000000..cc825157e --- /dev/null +++ b/src/array/qr.jl @@ -0,0 +1,241 @@ +export geqrf!, porgqr!, pormqr!, cageqrf! +import LinearAlgebra: QRCompactWY, AdjointQ, BlasFloat, QRCompactWYQ, AbstractQ, StridedVecOrMat, I +import Base.:* +include("coreblas/coreblas_ormqr.jl") +include("coreblas/coreblas_ttqrt.jl") +include("coreblas/coreblas_ttmqr.jl") +include("coreblas/coreblas_geqrt.jl") +include("coreblas/coreblas_tsqrt.jl") +include("coreblas/coreblas_tsmqr.jl") + +(*)(Q::QRCompactWYQ{T, M}, b::Number) where {T<:Number, M<:DMatrix{T}} = DMatrix(Q) * b +(*)(b::Number, Q::QRCompactWYQ{T, M}) where {T<:Number, M<:DMatrix{T}} = DMatrix(Q) * b + +(*)(Q::AdjointQ{T, QRCompactWYQ{T, M, C}}, b::Number) where {T<:Number, M<:DMatrix{T}, C<:LowerTrapezoidal{T, M}} = DMatrix(Q) * b +(*)(b::Number, Q::AdjointQ{T, QRCompactWYQ{T, M, C}}) where {T<:Number, M<:DMatrix{T}, C<:LowerTrapezoidal{T, M}} = DMatrix(Q) * b + +LinearAlgebra.lmul!(B::QRCompactWYQ{T, M}, A::M) where {T, M<:DMatrix{T}} = pormqr!('L', 'N', B.factors, B.T, A) +function LinearAlgebra.lmul!(B::AdjointQ{T, <:QRCompactWYQ{T, M}}, A::M) where {T, M<:Dagger.DMatrix{T}} + trans = T <: Complex ? 'C' : 'T' + pormqr!('L', trans, B.Q.factors, B.Q.T, A) +end + +LinearAlgebra.rmul!(A::Dagger.DMatrix{T}, B::QRCompactWYQ{T, M}) where {T, M<:Dagger.DMatrix{T}} = pormqr!('R', 'N', B.factors, B.T, A) +function LinearAlgebra.rmul!(A::Dagger.DArray{T,2}, B::AdjointQ{T, <:QRCompactWYQ{T, M}}) where {T, M<:Dagger.DMatrix{T}} + trans = T <: Complex ? 'C' : 'T' + pormqr!('R', trans, B.Q.factors, B.Q.T, A) +end + +function Dagger.DMatrix(Q::QRCompactWYQ{T, <:Dagger.DArray{T, 2}}) where {T} + DQ = distribute(Matrix(I*one(T), size(Q.factors)[1], size(Q.factors)[1]), Q.factors.partitioning) + porgqr!('N', Q.factors, Q.T, DQ) + return DQ +end + +function Dagger.DMatrix(AQ::AdjointQ{T, <:QRCompactWYQ{T, <:Dagger.DArray{T, 2}}}) where {T} + DQ = distribute(Matrix(I*one(T), size(AQ.Q.factors)[1], size(AQ.Q.factors)[1]), AQ.Q.factors.partitioning) + trans = T <: Complex ? 'C' : 'T' + porgqr!(trans, AQ.Q.factors, AQ.Q.T, DQ) + return DQ +end + +Base.collect(Q::QRCompactWYQ{T, <:Dagger.DArray{T, 2}}) where {T} = collect(Dagger.DMatrix(Q)) +Base.collect(AQ::AdjointQ{T, <:QRCompactWYQ{T, <:Dagger.DArray{T, 2}}}) where {T} = collect(Dagger.DMatrix(AQ)) + +function pormqr!(side::Char, trans::Char, A::Dagger.DArray{T, 2}, Tm::LowerTrapezoidal{T, <:Dagger.DMatrix{T}}, C::Dagger.DArray{T, 2}) where {T<:Number} + m, n = size(C) + Ac = A.chunks + Tc = Tm.data.chunks + Cc = C.chunks + + Amt, Ant = size(Ac) + Tmt, Tnt = size(Tc) + Cmt, Cnt = size(Cc) + minMT = min(Amt, Ant) + + Dagger.spawn_datadeps() do + if side == 'L' + if (trans == 'T' || trans == 'C') + for k in 1:minMT + for n in 1:Cnt + Dagger.@spawn coreblas_ormqr!(side, trans, In(Ac[k, k]), In(Tc[k, k]), InOut(Cc[k,n])) + end + for m in k+1:Cmt, n in 1:Cnt + Dagger.@spawn coreblas_tsmqr!(side, trans, InOut(Cc[k, n]), InOut(Cc[m, n]), In(Ac[m, k]), In(Tc[m, k])) + end + end + end + if trans == 'N' + for k in minMT:-1:1 + for m in Cmt:-1:k+1, n in 1:Cnt + Dagger.@spawn coreblas_tsmqr!(side, trans, InOut(Cc[k, n]), InOut(Cc[m, n]), In(Ac[m, k]), In(Tc[m, k])) + end + for n in 1:Cnt + Dagger.@spawn coreblas_ormqr!(side, trans, In(Ac[k, k]), In(Tc[k, k]), InOut(Cc[k, n])) + end + end + end + else + if side == 'R' + if trans == 'T' || trans == 'C' + for k in minMT:-1:1 + for n in Cmt:-1:k+1, m in 1:Cmt + Dagger.@spawn coreblas_tsmqr!(side, trans, InOut(Cc[m, k]), InOut(Cc[m, n]), In(Ac[n, k]), In(Tc[n, k])) + end + for m in 1:Cmt + Dagger.@spawn coreblas_ormqr!(side, trans, In(Ac[k, k]), In(Tc[k, k]), InOut(Cc[m, k])) + end + end + end + if trans == 'N' + for k in 1:minMT + for m in 1:Cmt + Dagger.@spawn coreblas_ormqr!(side, trans, In(Ac[k, k]), In(Tc[k, k]), InOut(Cc[m, k])) + end + for n in k+1:Cmt, m in 1:Cmt + Dagger.@spawn coreblas_tsmqr!(side, trans, InOut(Cc[m, k]), InOut(Cc[m, n]), In(Ac[n, k]), In(Tc[n, k])) + end + end + end + end + end + end + return C +end + +function cageqrf!(A::Dagger.DArray{T, 2}, Tm::LowerTrapezoidal{T, <:Dagger.DMatrix{T}}; static::Bool=true, traversal::Symbol=:inorder, p::Int64=1) where {T<: Number} + if p == 1 + return geqrf!(A, Tm; static, traversal) + end + Ac = A.chunks + mt, nt = size(Ac) + @assert mt % p == 0 "Number of tiles must be divisible by the number of domains" + mtd = Int64(mt/p) + Tc = Tm.data.chunks + proot = 1 + nxtmt = mtd + trans = T <: Complex ? 'C' : 'T' + Dagger.spawn_datadeps(;static, traversal) do + for k in 1:min(mt, nt) + if k > nxtmt + proot += 1 + nxtmt += mtd + end + for pt in proot:p + ibeg = 1 + (pt-1) * mtd + if pt == proot + ibeg = k + end + Dagger.@spawn coreblas_geqrt!(InOut(Ac[ibeg, k]), Out(Tc[ibeg,k])) + for n in k+1:nt + Dagger.@spawn coreblas_ormqr!('L', trans, In(Ac[ibeg, k]), In(Tc[ibeg,k]), InOut(Ac[ibeg, n])) + end + for m in ibeg+1:(pt * mtd) + Dagger.@spawn coreblas_tsqrt!(InOut(Ac[ibeg, k]), InOut(Ac[m, k]), Out(Tc[m,k])) + for n in k+1:nt + Dagger.@spawn coreblas_tsmqr!('L', trans, InOut(Ac[ibeg, n]), InOut(Ac[m, n]), In(Ac[m, k]), In(Tc[m,k])) + end + end + end + for m in 1:ceil(Int64, log2(p-proot+1)) + p1 = proot + p2 = p1 + 2^(m-1) + while p2 ≤ p + i1 = 1 + (p1-1) * mtd + i2 = 1 + (p2-1) * mtd + if p1 == proot + i1 = k + end + Dagger.@spawn coreblas_ttqrt!(InOut(Ac[i1, k]), InOut(Ac[i2, k]), Out(Tc[i2, k])) + for n in k+1:nt + Dagger.@spawn coreblas_ttmqr!('L', trans, InOut(Ac[i1, n]), InOut(Ac[i2, n]), In(Ac[i2, k]), In(Tc[i2, k])) + end + p1 += 2^m + p2 += 2^m + end + end + end + end +end + +function geqrf!(A::Dagger.DArray{T, 2}, Tm::LowerTrapezoidal{T, <:Dagger.DMatrix{T}}; static::Bool=true, traversal::Symbol=:inorder) where {T<: Number} + Ac = A.chunks + mt, nt = size(Ac) + Tc = Tm.data.chunks + trans = T <: Complex ? 'C' : 'T' + + Ccopy = Dagger.DArray{T}(undef, A.partitioning, A.partitioning.blocksize[1], min(mt, nt) * A.partitioning.blocksize[2]) + Cc = Ccopy.chunks + Dagger.spawn_datadeps(;static, traversal) do + for k in 1:min(mt, nt) + Dagger.@spawn coreblas_geqrt!(InOut(Ac[k, k]), Out(Tc[k,k])) + # FIXME: This is a hack to avoid aliasing + Dagger.@spawn copyto!(InOut(Cc[1,k]), In(Ac[k, k])) + for n in k+1:nt + #FIXME: Change Cc[1,k] to upper triangular of Ac[k,k] + Dagger.@spawn coreblas_ormqr!('L', trans, In(Cc[1, k]), In(Tc[k,k]), InOut(Ac[k, n])) + end + for m in k+1:mt + Dagger.@spawn coreblas_tsqrt!(InOut(Ac[k, k]), InOut(Ac[m, k]), Out(Tc[m,k])) + for n in k+1:nt + Dagger.@spawn coreblas_tsmqr!('L', trans, InOut(Ac[k, n]), InOut(Ac[m, n]), In(Ac[m, k]), In(Tc[m,k])) + end + end + end + end +end + +function porgqr!(trans::Char, A::Dagger.DArray{T, 2}, Tm::LowerTrapezoidal{T, <:Dagger.DMatrix{T}}, Q::Dagger.DArray{T, 2}; static::Bool=true, traversal::Symbol=:inorder) where {T<:Number} + Ac = A.chunks + Tc = Tm.data.chunks + Qc = Q.chunks + mt, nt = size(Ac) + qmt, qnt = size(Qc) + + Dagger.spawn_datadeps(;static, traversal) do + if trans == 'N' + for k in min(mt, nt):-1:1 + for m in qmt:-1:k + 1, n in k:qnt + Dagger.@spawn coreblas_tsmqr!('L', trans, InOut(Qc[k, n]), InOut(Qc[m, n]), In(Ac[m, k]), In(Tc[m, k])) + end + for n in k:qnt + Dagger.@spawn coreblas_ormqr!('L', trans, In(Ac[k, k]), + In(Tc[k, k]), InOut(Qc[k, n])) + end + end + else + for k in 1:min(mt, nt) + for n in 1:k + Dagger.@spawn coreblas_ormqr!('L', trans, In(Ac[k, k]), + In(Tc[k, k]), InOut(Qc[k, n])) + end + for m in k+1:qmt, n in 1:qnt + Dagger.@spawn coreblas_tsmqr!('L', trans, InOut(Qc[k, n]), InOut(Qc[m, n]), In(Ac[m, k]), In(Tc[m, k])) + end + end + end + end +end + +function meas_ws(A::Dagger.DArray{T, 2}, ib::Int64) where {T<: Number} + mb, nb = A.partitioning.blocksize + m, n = size(A) + MT = (mod(m,nb)==0) ? floor(Int64, (m / mb)) : floor(Int64, (m / mb) + 1) + NT = (mod(n,nb)==0) ? floor(Int64,(n / nb)) : floor(Int64, (n / nb) + 1) * 2 + lm = ib * MT; + ln = nb * NT; + lm, ln +end + +function LinearAlgebra.qr!(A::Dagger.DArray{T, 2}; ib::Int64=1, p::Int64=1) where {T<:Number} + lm, ln = meas_ws(A, ib) + Ac = A.chunks + nb = A.partitioning.blocksize[2] + mt, nt = size(Ac) + st = nb * (nt - 1) + Tm = LowerTrapezoidal(zeros, Blocks(ib, nb), T, st, lm, ln) + geqrf!(A, Tm) + return QRCompactWY(A, Tm); +end + + diff --git a/test/array/linalg/qr.jl b/test/array/linalg/qr.jl new file mode 100644 index 000000000..2a347d164 --- /dev/null +++ b/test/array/linalg/qr.jl @@ -0,0 +1,36 @@ + @testset "Tile QR: $T" for T in (Float32, Float64, ComplexF32, ComplexF64) + ## Square matrices + A = rand(T, 128, 128) + Q, R = qr(A) + DA = distribute(A, Blocks(32,32)) + DQ, DR = qr!(DA) + @test abs.(DQ) ≈ abs.(Q) + @test abs.(DR) ≈ abs.(R) + @test DQ * DR ≈ A + @test DQ' * DQ ≈ I + @test I * abs.(DQ) ≈ abs.(Q) + @test I * abs.(DQ') ≈ abs.(Q') + + ## Rectangular matrices (block and element wise) + # Tall Element and Block + A = rand(T, 128, 64) + Q, R = qr(A) + DA = distribute(A, Blocks(32,32)) + DQ, DR = qr!(DA) + @test abs.(DR) ≈ abs.(R) + @test DQ * DR ≈ A + @test DQ' * DQ ≈ I + @test I * DQ ≈ collect(DQ) + @test I * DQ' ≈ collect(DQ') + + # Wide Element and Block + A = rand(T, 64, 128) + Q, R = qr(A) + DA = distribute(A, Blocks(16,16)) + DQ, DR = qr!(DA) + @test abs.(DR) ≈ abs.(R) + @test DQ * DR ≈ A + @test DQ' * DQ ≈ I + @test I * DQ ≈ collect(DQ) + @test I * DQ' ≈ collect(DQ') +end diff --git a/test/runtests.jl b/test/runtests.jl index 98e61967d..fbab200e8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -18,6 +18,7 @@ tests = [ ("Array - MapReduce", "array/mapreduce.jl"), ("Array - LinearAlgebra - Matmul", "array/linalg/matmul.jl"), ("Array - LinearAlgebra - Cholesky", "array/linalg/cholesky.jl"), + ("Array - LinearAlgebra - QR", "array/linalg/qr.jl"), ("Caching", "cache.jl"), ("Disk Caching", "diskcaching.jl"), ("File IO", "file-io.jl"), From 48de8fcba7ead5ca1c4006465506005078a81bab Mon Sep 17 00:00:00 2001 From: fda-tome Date: Thu, 6 Jun 2024 16:11:29 -0300 Subject: [PATCH 28/34] DArray: disabling faulty views, fixing undefinit Dims arguments --- src/Dagger.jl | 4 +++- src/array/alloc.jl | 8 ++++---- src/array/darray.jl | 4 ++-- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/Dagger.jl b/src/Dagger.jl index 3c1d92eee..c3384978c 100644 --- a/src/Dagger.jl +++ b/src/Dagger.jl @@ -12,7 +12,9 @@ import Distributed import Distributed: Future, RemoteChannel, myid, workers, nworkers, procs, remotecall, remotecall_wait, remotecall_fetch import LinearAlgebra -import LinearAlgebra: Adjoint, BLAS, Diagonal, Bidiagonal, Tridiagonal, LAPACK, LowerTriangular, PosDefException, Transpose, UpperTriangular, UnitLowerTriangular, UnitUpperTriangular, diagind, ishermitian, issymmetric +import LinearAlgebra: Adjoint, BLAS, Diagonal, Bidiagonal, Tridiagonal, LAPACK, LowerTriangular, PosDefException, Transpose, UpperTriangular, UnitLowerTriangular, UnitUpperTriangular, diagind, ishermitian, issymmetric, chkstride1 + +import libcoreblas_jll import UUIDs: UUID, uuid4 diff --git a/src/array/alloc.jl b/src/array/alloc.jl index a63a06d3c..b3f534039 100644 --- a/src/array/alloc.jl +++ b/src/array/alloc.jl @@ -44,7 +44,7 @@ end const BlocksOrAuto = Union{Blocks{N} where N, AutoBlocks} -function DArray{T}(::UndefInitializer, p::Blocks, dims) where {T} +function DArray{T}(::UndefInitializer, p::Blocks, dims::Dims) where {T} d = ArrayDomain(map(x->1:x, dims)) part = partition(p, d) f = function (_, T, sz) @@ -55,10 +55,10 @@ function DArray{T}(::UndefInitializer, p::Blocks, dims) where {T} end DArray(::UndefInitializer, p::BlocksOrAuto, dims::Integer...) = DArray{Float64}(undef, p, dims) -DArray(::UndefInitializer, p::BlocksOrAuto, dims::Tuple) = DArray{Float64}(undef, p, dims) +DArray(::UndefInitializer, p::BlocksOrAuto, dims::Dims) = DArray{Float64}(undef, p, dims) DArray{T}(::UndefInitializer, p::BlocksOrAuto, dims::Integer...) where {T} = DArray{T}(undef, p, dims) -DArray{T}(::UndefInitializer, p::BlocksOrAuto, dims::Tuple) where {T} = DArray{T}(undef, p, dims) -DArray{T}(::UndefInitializer, p::AutoBlocks, dims::Tuple) where {T} = DArray{T}(undef, auto_blocks(dims), dims) +DArray{T}(::UndefInitializer, p::BlocksOrAuto, dims::Dims) where {T} = DArray{T}(undef, p, dims) +DArray{T}(::UndefInitializer, p::AutoBlocks, dims::Dims) where {T} = DArray{T}(undef, auto_blocks(dims), dims) function Base.rand(p::Blocks, eltype::Type, dims::Dims) d = ArrayDomain(map(x->1:x, dims)) diff --git a/src/array/darray.jl b/src/array/darray.jl index b2178060f..36a673c4b 100644 --- a/src/array/darray.jl +++ b/src/array/darray.jl @@ -331,11 +331,11 @@ Base.:(/)(x::DArray{T,N,B,F}, y::U) where {T<:Real,U<:Real,N,B,F} = A `view` of a `DArray` chunk returns a `DArray` of `Thunk`s. """ -function Base.view(c::DArray, d) +#=function Base.view(c::DArray, d) subchunks, subdomains = lookup_parts(c, chunks(c), domainchunks(c), d) d1 = alignfirst(d) DArray(eltype(c), d1, subdomains, subchunks, c.partitioning, c.concat) -end +end=# function group_indices(cumlength, idxs,at=1, acc=Any[]) at > length(idxs) && return acc From 10c3428366fca2189730ebeaf232e4d96ee2af7f Mon Sep 17 00:00:00 2001 From: fda-tome Date: Wed, 5 Jun 2024 16:55:05 -0300 Subject: [PATCH 29/34] DArray: Trapezoidal and Triangular wrappers --- src/Dagger.jl | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/Dagger.jl b/src/Dagger.jl index c3384978c..2cdc61498 100644 --- a/src/Dagger.jl +++ b/src/Dagger.jl @@ -14,9 +14,6 @@ import Distributed: Future, RemoteChannel, myid, workers, nworkers, procs, remot import LinearAlgebra import LinearAlgebra: Adjoint, BLAS, Diagonal, Bidiagonal, Tridiagonal, LAPACK, LowerTriangular, PosDefException, Transpose, UpperTriangular, UnitLowerTriangular, UnitUpperTriangular, diagind, ishermitian, issymmetric, chkstride1 -import libcoreblas_jll - - import UUIDs: UUID, uuid4 if !isdefined(Base, :ScopedValues) From 407fe830e82f4d1c945b955de917dc95ae37c797 Mon Sep 17 00:00:00 2001 From: fda-tome Date: Thu, 6 Jun 2024 16:20:05 -0300 Subject: [PATCH 30/34] DArray: project.toml changes --- Project.toml | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/Project.toml b/Project.toml index ebf03c388..e65e42783 100644 --- a/Project.toml +++ b/Project.toml @@ -3,6 +3,7 @@ uuid = "d58978e5-989f-55fb-8d15-ea34adc7bf54" version = "0.18.11" [deps] +ArgParse = "c7e460c6-2fb9-53a9-8c5b-16f535851c63" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" @@ -23,6 +24,18 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" TaskLocalValues = "ed4db957-447d-4319-bfb6-7fa9ae7ecf34" TimespanLogging = "a526e669-04d3-4846-9525-c66122c55f63" UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" +libcoreblas_jll = "339d4f0c-89b5-5ae2-b52c-218a0e582e15" + +[weakdeps] +Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" +DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" +GraphViz = "f526b714-d49f-11e8-06ff-31ed36ee7ee0" +Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" + +[extensions] +GraphVizExt = "GraphViz" +GraphVizSimpleExt = "Colors" +PlotsExt = ["DataFrames", "Plots"] [compat] Colors = "0.12" @@ -43,19 +56,8 @@ TaskLocalValues = "0.1" TimespanLogging = "0.1" julia = "1.8" -[extensions] -GraphVizExt = "GraphViz" -GraphVizSimpleExt = "Colors" -PlotsExt = ["DataFrames", "Plots"] - [extras] Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" GraphViz = "f526b714-d49f-11e8-06ff-31ed36ee7ee0" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" - -[weakdeps] -Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" -DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" -GraphViz = "f526b714-d49f-11e8-06ff-31ed36ee7ee0" -Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" From 9c52ef8f978e689e8c8d7508ce8b022146399562 Mon Sep 17 00:00:00 2001 From: fda-tome Date: Mon, 10 Jun 2024 15:51:43 -0300 Subject: [PATCH 31/34] DArray: coreblas changes and inclusion of libblastrampoline --- Project.toml | 1 + src/array/coreblas/coreblas_geqrt.jl | 2 +- src/array/coreblas/coreblas_ormqr.jl | 2 +- src/array/coreblas/coreblas_tsmqr.jl | 2 +- src/array/coreblas/coreblas_tsqrt.jl | 2 +- src/array/coreblas/coreblas_ttmqr.jl | 3 +-- src/array/coreblas/coreblas_ttqrt.jl | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) diff --git a/Project.toml b/Project.toml index e65e42783..556697539 100644 --- a/Project.toml +++ b/Project.toml @@ -24,6 +24,7 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" TaskLocalValues = "ed4db957-447d-4319-bfb6-7fa9ae7ecf34" TimespanLogging = "a526e669-04d3-4846-9525-c66122c55f63" UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" +libblastrampoline_jll = "8e850b90-86db-534c-a0d3-1478176c7d93" libcoreblas_jll = "339d4f0c-89b5-5ae2-b52c-218a0e582e15" [weakdeps] diff --git a/src/array/coreblas/coreblas_geqrt.jl b/src/array/coreblas/coreblas_geqrt.jl index 99f86d442..6b3ac01f7 100644 --- a/src/array/coreblas/coreblas_geqrt.jl +++ b/src/array/coreblas/coreblas_geqrt.jl @@ -15,7 +15,7 @@ for (geqrt, T) in work = Vector{$T}(undef, (ib)*n) ttau = Vector{$T}(undef, n) - err = ccall(($(QuoteNode(geqrt)), libcoreblas), Int64, + err = ccall(($(QuoteNode(geqrt)), :libcoreblas), Int64, (Int64, Int64, Int64, Ptr{$T}, Int64, Ptr{$T}, Int64, Ptr{$T}, Ptr{$T}), diff --git a/src/array/coreblas/coreblas_ormqr.jl b/src/array/coreblas/coreblas_ormqr.jl index 5bcb7b1bb..cba7237e0 100644 --- a/src/array/coreblas/coreblas_ormqr.jl +++ b/src/array/coreblas/coreblas_ormqr.jl @@ -24,7 +24,7 @@ for (geormqr, T) in work = Vector{$T}(undef, ib*nb) - err = ccall(($(QuoteNode(geormqr)), libcoreblas), Int64, + err = ccall(($(QuoteNode(geormqr)), :libcoreblas), Int64, (Int64, Int64, Int64, Int64, Int64, Int64, Ptr{$T}, Int64, Ptr{$T}, Int64, diff --git a/src/array/coreblas/coreblas_tsmqr.jl b/src/array/coreblas/coreblas_tsmqr.jl index 1f647f7c1..df92c555a 100644 --- a/src/array/coreblas/coreblas_tsmqr.jl +++ b/src/array/coreblas/coreblas_tsmqr.jl @@ -27,7 +27,7 @@ for (getsmqr, T) in work = Vector{$T}(undef, ib*nb) - err = ccall(($(QuoteNode(getsmqr)), libcoreblas), Int64, + err = ccall(($(QuoteNode(getsmqr)), :libcoreblas), Int64, (Int64, Int64, Int64, Int64, Int64, Int64, Int64, Int64, Ptr{$T}, Int64, Ptr{$T}, Int64, diff --git a/src/array/coreblas/coreblas_tsqrt.jl b/src/array/coreblas/coreblas_tsqrt.jl index e644465a3..5b3fb4ea7 100644 --- a/src/array/coreblas/coreblas_tsqrt.jl +++ b/src/array/coreblas/coreblas_tsqrt.jl @@ -16,7 +16,7 @@ for (getsqrt,T) in work = Vector{$T}(undef, (ib)*n) ttau = Vector{$T}(undef, n) - err = ccall(($(QuoteNode(getsqrt)), libcoreblas), Int64, + err = ccall(($(QuoteNode(getsqrt)), :libcoreblas), Int64, (Int64, Int64, Int64, Ptr{$T}, Int64, Ptr{$T}, Int64, Ptr{$T}, Int64, Ptr{$T}, Ptr{$T}), diff --git a/src/array/coreblas/coreblas_ttmqr.jl b/src/array/coreblas/coreblas_ttmqr.jl index 32c0f4eb3..6cd9002b5 100644 --- a/src/array/coreblas/coreblas_ttmqr.jl +++ b/src/array/coreblas/coreblas_ttmqr.jl @@ -1,4 +1,3 @@ -using libcoreblas_jll for (gettmqr, T) in ((:coreblas_dttmqr, Float64), (:coreblas_sttmqr, Float32), @@ -27,7 +26,7 @@ for (gettmqr, T) in workdim = side == 'L' ? n1 : ib work = Vector{$T}(undef, ldwork*workdim) - err = ccall(($(QuoteNode(gettmqr)), libcoreblas), Int64, + err = ccall(($(QuoteNode(gettmqr)), :libcoreblas), Int64, (Int64, Int64, Int64, Int64, Int64, Int64, Int64, Int64, Ptr{$T}, Int64, Ptr{$T}, Int64, diff --git a/src/array/coreblas/coreblas_ttqrt.jl b/src/array/coreblas/coreblas_ttqrt.jl index f90373517..19465efbb 100644 --- a/src/array/coreblas/coreblas_ttqrt.jl +++ b/src/array/coreblas/coreblas_ttqrt.jl @@ -19,7 +19,7 @@ for (gettqrt, T) in ldt = max(1, stride(triT, 2)) - err = ccall(($(QuoteNode(gettqrt)), libcoreblas), Int64, + err = ccall(($(QuoteNode(gettqrt)), :libcoreblas), Int64, (Int64, Int64, Int64, Ptr{$T}, Int64, Ptr{$T}, Int64, Ptr{$T}, Int64, Ptr{$T}, Ptr{$T}), m1, n1, ib, A1, lda1, A2, lda2, triT, ldt, tau, work) From 016fae3b0fd821054dc4530b127c71d7c1236535 Mon Sep 17 00:00:00 2001 From: fda-tome Date: Tue, 11 Jun 2024 14:36:07 -0300 Subject: [PATCH 32/34] DArray: adding aliasing support --- src/array/qr.jl | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/src/array/qr.jl b/src/array/qr.jl index cc825157e..28ccdc260 100644 --- a/src/array/qr.jl +++ b/src/array/qr.jl @@ -164,19 +164,14 @@ function geqrf!(A::Dagger.DArray{T, 2}, Tm::LowerTrapezoidal{T, <:Dagger.DMatrix Tc = Tm.data.chunks trans = T <: Complex ? 'C' : 'T' - Ccopy = Dagger.DArray{T}(undef, A.partitioning, A.partitioning.blocksize[1], min(mt, nt) * A.partitioning.blocksize[2]) - Cc = Ccopy.chunks Dagger.spawn_datadeps(;static, traversal) do for k in 1:min(mt, nt) Dagger.@spawn coreblas_geqrt!(InOut(Ac[k, k]), Out(Tc[k,k])) - # FIXME: This is a hack to avoid aliasing - Dagger.@spawn copyto!(InOut(Cc[1,k]), In(Ac[k, k])) for n in k+1:nt - #FIXME: Change Cc[1,k] to upper triangular of Ac[k,k] - Dagger.@spawn coreblas_ormqr!('L', trans, In(Cc[1, k]), In(Tc[k,k]), InOut(Ac[k, n])) + Dagger.@spawn coreblas_ormqr!('L', trans, Deps(Ac[k,k], In(LowerTriangular)), In(Tc[k,k]), InOut(Ac[k, n])) end for m in k+1:mt - Dagger.@spawn coreblas_tsqrt!(InOut(Ac[k, k]), InOut(Ac[m, k]), Out(Tc[m,k])) + Dagger.@spawn coreblas_tsqrt!(Deps(Ac[k, k], InOut(UpperTriangular)), InOut(Ac[m, k]), Out(Tc[m,k])) for n in k+1:nt Dagger.@spawn coreblas_tsmqr!('L', trans, InOut(Ac[k, n]), InOut(Ac[m, n]), In(Ac[m, k]), In(Tc[m,k])) end From be5c810e8ba51ebcc3d31e635ea5587a5bb4108d Mon Sep 17 00:00:00 2001 From: fda-tome Date: Tue, 11 Jun 2024 14:51:38 -0300 Subject: [PATCH 33/34] DArray: adding aliasing support to CAQR --- src/array/qr.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/array/qr.jl b/src/array/qr.jl index 28ccdc260..249b9b560 100644 --- a/src/array/qr.jl +++ b/src/array/qr.jl @@ -128,10 +128,10 @@ function cageqrf!(A::Dagger.DArray{T, 2}, Tm::LowerTrapezoidal{T, <:Dagger.DMatr end Dagger.@spawn coreblas_geqrt!(InOut(Ac[ibeg, k]), Out(Tc[ibeg,k])) for n in k+1:nt - Dagger.@spawn coreblas_ormqr!('L', trans, In(Ac[ibeg, k]), In(Tc[ibeg,k]), InOut(Ac[ibeg, n])) + Dagger.@spawn coreblas_ormqr!('L', trans, Deps(Ac[ibeg, k], In(LowerTriangular)), In(Tc[ibeg,k]), InOut(Ac[ibeg, n])) end for m in ibeg+1:(pt * mtd) - Dagger.@spawn coreblas_tsqrt!(InOut(Ac[ibeg, k]), InOut(Ac[m, k]), Out(Tc[m,k])) + Dagger.@spawn coreblas_tsqrt!(Deps(Ac[ibeg, k], InOut(UpperTriangular)), InOut(Ac[m, k]), Out(Tc[m,k])) for n in k+1:nt Dagger.@spawn coreblas_tsmqr!('L', trans, InOut(Ac[ibeg, n]), InOut(Ac[m, n]), In(Ac[m, k]), In(Tc[m,k])) end @@ -146,9 +146,9 @@ function cageqrf!(A::Dagger.DArray{T, 2}, Tm::LowerTrapezoidal{T, <:Dagger.DMatr if p1 == proot i1 = k end - Dagger.@spawn coreblas_ttqrt!(InOut(Ac[i1, k]), InOut(Ac[i2, k]), Out(Tc[i2, k])) + Dagger.@spawn coreblas_ttqrt!(Deps(Ac[i1, k], InOut(UpperTriangular)), Deps(Ac[i2, k], InOut(UpperTriangular)), Out(Tc[i2, k])) for n in k+1:nt - Dagger.@spawn coreblas_ttmqr!('L', trans, InOut(Ac[i1, n]), InOut(Ac[i2, n]), In(Ac[i2, k]), In(Tc[i2, k])) + Dagger.@spawn coreblas_ttmqr!('L', trans, InOut(Ac[i1, n]), InOut(Ac[i2, n]), Deps(Ac[i2, k], In(UpperTriangular)), In(Tc[i2, k])) end p1 += 2^m p2 += 2^m From 2b74fbe8bcce8d3800ae6c52db4c4b4410900cfb Mon Sep 17 00:00:00 2001 From: fda-tome Date: Mon, 24 Jun 2024 16:06:22 -0300 Subject: [PATCH 34/34] DArray: adequating undefinit to the new AllocateArray --- src/array/alloc.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/array/alloc.jl b/src/array/alloc.jl index b3f534039..4155f0c6f 100644 --- a/src/array/alloc.jl +++ b/src/array/alloc.jl @@ -47,10 +47,10 @@ const BlocksOrAuto = Union{Blocks{N} where N, AutoBlocks} function DArray{T}(::UndefInitializer, p::Blocks, dims::Dims) where {T} d = ArrayDomain(map(x->1:x, dims)) part = partition(p, d) - f = function (_, T, sz) + f = function (T, sz) Array{T, length(sz)}(undef, sz...) end - a = AllocateArray(T, f, d, part, p) + a = AllocateArray(T, f, false, d, part, p) return _to_darray(a) end