Skip to content

Commit 7e58f28

Browse files
authored
[CUSPARSE] Add more tests (#1668)
1 parent 8951764 commit 7e58f28

File tree

6 files changed

+211
-282
lines changed

6 files changed

+211
-282
lines changed

lib/cusparse/array.jl

+1
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,7 @@ Base.similar(Mat::CuSparseMatrixCOO, T::Type) = CuSparseMatrixCOO(copy(Mat.rowIn
215215

216216
Base.similar(Mat::CuSparseMatrixCSC, T::Type, N::Int, M::Int) = CuSparseMatrixCSC(CuVector{Int32}(undef, M+1), CuVector{Int32}(undef, nnz(Mat)), CuVector{T}(undef, nnz(Mat)), (N,M))
217217
Base.similar(Mat::CuSparseMatrixCSR, T::Type, N::Int, M::Int) = CuSparseMatrixCSR(CuVector{Int32}(undef, N+1), CuVector{Int32}(undef, nnz(Mat)), CuVector{T}(undef, nnz(Mat)), (N,M))
218+
Base.similar(Mat::CuSparseMatrixCOO, T::Type, N::Int, M::Int) = CuSparseMatrixCOO(CuVector{Int32}(undef, nnz(Mat)), CuVector{Int32}(undef, nnz(Mat)), CuVector{T}(undef, nnz(Mat)), (N,M))
218219

219220
## array interface
220221

lib/cusparse/conversions.jl

+5-1
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,11 @@ function sort_coo(A::CuSparseMatrixCOO{Tv,Ti}, type::SparseChar='R') where {Tv,T
102102
sorted_rowInd = copy(A.rowInd)
103103
sorted_colInd = copy(A.colInd)
104104
function bufferSize()
105-
out = Ref{Csize_t}()
105+
# It seems that in some cases `out` is not updated
106+
# and we have the following error in the tests:
107+
# "Out of GPU memory trying to allocate 127.781 TiB".
108+
# We set 0 as default value to avoid it.
109+
out = Ref{Csize_t}(0)
106110
cusparseXcoosort_bufferSizeExt(handle(), m, n, nnz(A), A.rowInd, A.colInd, out)
107111
return out[]
108112
end

lib/cusparse/interfaces.jl

+88-186
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,31 @@ using LinearAlgebra
44
using LinearAlgebra: BlasComplex, BlasFloat, BlasReal
55
export _spadjoint, _sptranspose
66

7+
function _spadjoint(A::CuSparseMatrixCSR)
8+
Aᴴ = CuSparseMatrixCSC(A.rowPtr, A.colVal, conj(A.nzVal), reverse(size(A)))
9+
CuSparseMatrixCSR(Aᴴ)
10+
end
11+
function _sptranspose(A::CuSparseMatrixCSR)
12+
Aᵀ = CuSparseMatrixCSC(A.rowPtr, A.colVal, A.nzVal, reverse(size(A)))
13+
CuSparseMatrixCSR(Aᵀ)
14+
end
15+
function _spadjoint(A::CuSparseMatrixCSC)
16+
Aᴴ = CuSparseMatrixCSR(A.colPtr, A.rowVal, conj(A.nzVal), reverse(size(A)))
17+
CuSparseMatrixCSC(Aᴴ)
18+
end
19+
function _sptranspose(A::CuSparseMatrixCSC)
20+
Aᵀ = CuSparseMatrixCSR(A.colPtr, A.rowVal, A.nzVal, reverse(size(A)))
21+
CuSparseMatrixCSC(Aᵀ)
22+
end
23+
function _spadjoint(A::CuSparseMatrixCOO)
24+
# we use sparse instead of CuSparseMatrixCOO because we want to sort the matrix.
25+
sparse(A.colInd, A.rowInd, conj(A.nzVal), reverse(size(A))..., fmt = :coo)
26+
end
27+
function _sptranspose(A::CuSparseMatrixCOO)
28+
# we use sparse instead of CuSparseMatrixCOO because we want to sort the matrix.
29+
sparse(A.colInd, A.rowInd, A.nzVal, reverse(size(A))..., fmt = :coo)
30+
end
31+
732
function mv_wrapper(transa::SparseChar, alpha::Number, A::CuSparseMatrix, X::DenseCuVector{T},
833
beta::Number, Y::CuVector{T}) where {T}
934
mv!(transa, alpha, A, X, beta, Y, 'O')
@@ -34,11 +59,15 @@ LinearAlgebra.dot(x::DenseCuVector{T}, y::CuSparseVector{T}) where {T <: BlasCom
3459

3560
tag_wrappers = ((identity, identity),
3661
(T -> :(HermOrSym{T, <:$T}), A -> :(parent($A))))
37-
op_wrappers = (
38-
(identity, T -> 'N', identity),
39-
(T -> :(Transpose{<:T, <:$T}), T -> 'T', A -> :(parent($A))),
40-
(T -> :(Adjoint{<:T, <:$T}), T -> T <: Real ? 'T' : 'C', A -> :(parent($A)))
41-
)
62+
63+
adjtrans_wrappers = ((identity, identity),
64+
(M -> :(Transpose{T, <:$M}), M -> :(_sptranspose(parent($M)))),
65+
(M -> :(Adjoint{T, <:$M}), M -> :(_spadjoint(parent($M)))))
66+
67+
op_wrappers = ((identity, T -> 'N', identity),
68+
(T -> :(Transpose{<:T, <:$T}), T -> 'T', A -> :(parent($A))),
69+
(T -> :(Adjoint{<:T, <:$T}), T -> T <: Real ? 'T' : 'C', A -> :(parent($A))))
70+
4271
for (taga, untaga) in tag_wrappers, (wrapa, transa, unwrapa) in op_wrappers
4372
TypeA = wrapa(taga(:(CuSparseMatrix{T})))
4473

@@ -113,210 +142,83 @@ end
113142

114143
for SparseMatrixType in (:CuSparseMatrixCSC, :CuSparseMatrixCSR)
115144
@eval begin
116-
function Base.:(*)(A::$SparseMatrixType{T}, B::$SparseMatrixType{T}) where {T <: BlasFloat}
145+
function LinearAlgebra.:(*)(A::$SparseMatrixType{T}, B::$SparseMatrixType{T}) where {T <: BlasFloat}
117146
CUSPARSE.version() < v"11.1.1" && throw(ErrorException("This operation is not supported by the current CUDA version."))
118147
gemm('N', 'N', one(T), A, B, 'O')
119148
end
120-
end
121-
end
122-
123-
for op in (:(+), :(-))
124-
@eval begin
125-
Base.$op(A::CuSparseVector{T}, B::CuSparseVector{T}) where {T <: BlasFloat} = axpby(one(T), A, $(op)(one(T)), B, 'O')
126-
127-
Base.$op(A::Union{CuSparseMatrixCOO{T}, Transpose{T,<:CuSparseMatrixCOO}, Adjoint{T,<:CuSparseMatrixCOO}},
128-
B::Union{CuSparseMatrixCOO{T}, Transpose{T,<:CuSparseMatrixCOO}, Adjoint{T,<:CuSparseMatrixCOO}}) where {T} =
129-
CuSparseMatrixCOO($(op)(CuSparseMatrixCSR(A), CuSparseMatrixCSR(B)))
130-
end
131-
132-
for SparseMatrixType in (:CuSparseMatrixCSC, :CuSparseMatrixCSR)
133-
@eval begin
134-
Base.$op(A::$SparseMatrixType{T}, B::$SparseMatrixType{T}) where {T <: BlasFloat} = geam(one(T), A, $(op)(one(T)), B, 'O')
135-
136-
Base.$op(A::$SparseMatrixType{T}, B::Adjoint{T,<:$SparseMatrixType}) where {T <: BlasFloat} = geam(one(T), A, $(op)(one(T)), _spadjoint(parent(B)), 'O')
137-
Base.$op(A::Adjoint{T,<:$SparseMatrixType}, B::$SparseMatrixType{T}) where {T <: BlasFloat} = geam(one(T), _spadjoint(parent(A)), $(op)(one(T)), B, 'O')
138-
Base.$op(A::Adjoint{T,<:$SparseMatrixType}, B::Adjoint{T,<:$SparseMatrixType}) where {T <: BlasFloat} = geam(one(T), _spadjoint(parent(A)), $(op)(one(T)), _spadjoint(parent(B)), 'O')
139-
140-
Base.$op(A::$SparseMatrixType{T}, B::Transpose{T,<:$SparseMatrixType}) where {T <: BlasFloat} = geam(one(T), A, $(op)(one(T)), _sptranspose(parent(B)), 'O')
141-
Base.$op(A::Transpose{T,<:$SparseMatrixType}, B::$SparseMatrixType{T}) where {T <: BlasFloat} = geam(one(T), _sptranspose(parent(A)), $(op)(one(T)), B, 'O')
142-
Base.$op(A::Transpose{T,<:$SparseMatrixType}, B::Transpose{T,<:$SparseMatrixType}) where {T <: BlasFloat} = geam(one(T), _sptranspose(parent(A)), $(op)(one(T)), _sptranspose(parent(B)), 'O')
143-
end
144-
end
145-
146-
for SparseMatrixType in (:CuSparseMatrixCSC, :CuSparseMatrixCOO, :CuSparseMatrixBSR)
147-
@eval begin
148-
function Base.$op(A::CuSparseMatrixCSR{T}, B::$SparseMatrixType{T}) where {T}
149-
csrB = CuSparseMatrixCSR(B)
150-
return geam(one(T), A, $(op)(one(T)), csrB, 'O')
151-
end
152-
function Base.$op(A::$SparseMatrixType{T}, B::CuSparseMatrixCSR{T}) where {T}
153-
csrA = CuSparseMatrixCSR(A)
154-
return geam(one(T), csrA, $(op)(one(T)), B, 'O')
155-
end
156-
function Base.$op(A::Transpose{T, CuSparseMatrixCSR{T}}, B::$SparseMatrixType{T}) where {T}
157-
csrB = CuSparseMatrixCSR(B)
158-
return geam(one(T), _sptranspose(parent(A)), $(op)(one(T)), csrB, 'O')
159-
end
160-
function Base.$op(A::$SparseMatrixType{T}, B::Transpose{T, CuSparseMatrixCSR{T}}) where {T}
161-
csrA = CuSparseMatrixCSR(A)
162-
return geam(one(T), csrA, $(op)(one(T)), _sptranspose(parent(B)), 'O')
163-
end
164-
function Base.$op(A::Adjoint{T, CuSparseMatrixCSR{T}}, B::$SparseMatrixType{T}) where {T}
165-
csrB = CuSparseMatrixCSR(B)
166-
return geam(one(T), _spadjoint(parent(A)), $(op)(one(T)), csrB, 'O')
167-
end
168-
function Base.$op(A::$SparseMatrixType{T}, B::Adjoint{T, CuSparseMatrixCSR{T}}) where {T}
169-
csrA = CuSparseMatrixCSR(A)
170-
return geam(one(T), csrA, $(op)(one(T)), _spadjoint(parent(B)), 'O')
171-
end
172-
173-
function Base.$op(A::CuSparseMatrixCSR{T}, B::Transpose{T, $SparseMatrixType}) where {T}
174-
csrB = CuSparseMatrixCSR(_sptranspose(parent(B)))
175-
return geam(one(T), A, $(op)(one(T)), csrB, 'O')
176-
end
177-
function Base.$op(A::Transpose{T, $SparseMatrixType}, B::CuSparseMatrixCSR{T}) where {T}
178-
csrA = CuSparseMatrixCSR(_sptranspose(parent(A)))
179-
return geam(one(T), csrA, $(op)(one(T)), B, 'O')
180-
end
181149

182-
function Base.$op(A::CuSparseMatrixCSR{T}, B::Adjoint{T, $SparseMatrixType}) where {T}
183-
csrB = CuSparseMatrixCSR(_spadjoint(parent(B)))
184-
return geam(one(T), A, $(op)(one(T)), csrB, 'O')
185-
end
186-
function Base.$op(A::Adjoint{T, $SparseMatrixType}, B::CuSparseMatrixCSR{T}) where {T}
187-
csrA = CuSparseMatrixCSR(_spadjoint(parent(A)))
188-
return geam(one(T), csrA, $(op)(one(T)), B, 'O')
189-
end
150+
function LinearAlgebra.mul!(C::$SparseMatrixType{T}, A::$SparseMatrixType{T}, B::$SparseMatrixType{T}, alpha::Number, beta::Number) where {T <: BlasFloat}
151+
CUSPARSE.version() < v"11.1.1" && throw(ErrorException("This operation is not supported by the current CUDA version."))
152+
gemm!('N', 'N', alpha, A, B, beta, C, 'O')
190153
end
191154
end
192155
end
193156

194-
195-
function Base.reshape(A::CuSparseMatrixCOO{T,M}, dims::NTuple{N,Int}) where {T,N,M}
196-
nrows, ncols = size(A)
197-
flat_indices = nrows .* (A.colInd .- 1) .+ A.rowInd .- 1
198-
new_col, new_row = div.(flat_indices, dims[1]) .+ 1, rem.(flat_indices, dims[1]) .+ 1
199-
sparse(new_row, new_col, A.nzVal, dims[1], length(dims) == 1 ? 1 : dims[2], fmt = :coo)
157+
function LinearAlgebra.:(*)(A::CuSparseMatrixCOO{T}, B::CuSparseMatrixCOO{T}) where {T <: BlasFloat}
158+
CUSPARSE.version() < v"11.1.1" && throw(ErrorException("This operation is not supported by the current CUDA version."))
159+
A_csr = CuSparseMatrixCSR(A)
160+
B_csr = CuSparseMatrixCSR(B)
161+
CuSparseMatrixCOO(A_csr * B_csr)
200162
end
201163

202-
function LinearAlgebra.mul!(Y::CuSparseMatrixCSR{T,M}, A::CuSparseMatrixCSR{T,M},
203-
B::CuSparseMatrixCSR{T,M}, alpha::Number, beta::Number) where {T,M}
204-
CUSPARSE.version() < v"11.5.1" && throw(ErrorException("This operation is not
205-
supported by the current CUDA version."))
206-
gemm!('N', 'N', alpha, A, B, beta, Y, 'O')
207-
end
208-
LinearAlgebra.mul!(Y::CuSparseMatrixCSR{T,M}, A::CuSparseMatrixCSR{T,M},
209-
B::CuSparseMatrixCSR{T,M}) where {T,M} = mul!(Y, A, B, one(T), zero(T))
210-
211-
LinearAlgebra.mul!(Y::CuSparseMatrixCSR{T,M}, A::Transpose{T,<:CuSparseMatrixCSR},
212-
B::CuSparseMatrixCSR{T,M}) where {T,M} = mul!(Y, _sptranspose(parent(A)), B, one(T), zero(T))
213-
LinearAlgebra.mul!(Y::CuSparseMatrixCSR{T,M}, A::Transpose{T,<:CuSparseMatrixCSR},
214-
B::Transpose{T,<:CuSparseMatrixCSR}) where {T,M} = mul!(Y, _sptranspose(parent(A)), _sptranspose(parent(B)), one(T), zero(T))
215-
LinearAlgebra.mul!(Y::CuSparseMatrixCSR{T,M}, A::CuSparseMatrixCSR{T,M},
216-
B::Transpose{T,<:CuSparseMatrixCSR}) where {T,M} = mul!(Y, A, _sptranspose(parent(B)), one(t), zero(T))
217-
218-
LinearAlgebra.mul!(Y::CuSparseMatrixCSR{T,M}, A::Adjoint{T,<:CuSparseMatrixCSR},
219-
B::CuSparseMatrixCSR{T,M}) where {T,M} = mul!(Y, _spadjoint(parent(A)), B, one(T), zero(T))
220-
LinearAlgebra.mul!(Y::CuSparseMatrixCSR{T,M}, A::Adjoint{T,<:CuSparseMatrixCSR},
221-
B::Adjoint{T,<:CuSparseMatrixCSR}) where {T,M} = mul!(Y, _spadjoint(parent(A)), _spadjoint(parent(B)), one(T), zero(T))
222-
LinearAlgebra.mul!(Y::CuSparseMatrixCSR{T,M}, A::CuSparseMatrixCSR{T,M},
223-
B::Adjoint{T,<:CuSparseMatrixCSR}) where {T,M} = mul!(Y, A, _spadjoint(parent(B)), one(t), zero(T))
224-
225-
function LinearAlgebra.mul!(Y::CuSparseMatrixCOO{T,M}, A::Union{CuSparseMatrixCOO{T,M}, Transpose{T,<:CuSparseMatrixCOO}, Adjoint{T,<:CuSparseMatrixCOO}},
226-
B::Union{CuSparseMatrixCOO{T,M}, Transpose{T,<:CuSparseMatrixCOO}, Adjoint{T,<:CuSparseMatrixCOO}}, alpha::Number, beta::Number) where {T,M}
227-
228-
Y2 = CuSparseMatrixCSR(Y)
229-
A2 = CuSparseMatrixCSR(A)
230-
B2 = CuSparseMatrixCSR(B)
231-
mul!(Y2, A2, B2, alpha, beta)
232-
copyto!(Y, CuSparseMatrixCOO(Y2))
164+
function LinearAlgebra.mul!(C::CuSparseMatrixCOO{T}, A::CuSparseMatrixCOO{T}, B::CuSparseMatrixCOO{T}, alpha::Number, beta::Number) where {T <: BlasFloat}
165+
CUSPARSE.version() < v"11.1.1" && throw(ErrorException("This operation is not supported by the current CUDA version."))
166+
A_csr = CuSparseMatrixCSR(A)
167+
B_csr = CuSparseMatrixCSR(B)
168+
C_csr = CuSparseMatrixCSR(C)
169+
mul!(C_csr, A_csr, B_csr, alpha, beta)
170+
C = CuSparseMatrixCOO(C_csr)
233171
end
234-
function LinearAlgebra.mul!(Y::CuSparseMatrixCSC{T,M}, A::Union{CuSparseMatrixCSC{T,M}, Transpose{T,<:CuSparseMatrixCSC}, Adjoint{T,<:CuSparseMatrixCSC}},
235-
B::Union{CuSparseMatrixCSC{T,M}, Transpose{T,<:CuSparseMatrixCSC}, Adjoint{T,<:CuSparseMatrixCSC}}, alpha::Number, beta::Number) where {T,M}
236-
237-
Y2 = CuSparseMatrixCSR(Y)
238-
A2 = CuSparseMatrixCSR(A)
239-
B2 = CuSparseMatrixCSR(B)
240-
mul!(Y2, A2, B2, alpha, beta)
241-
copyto!(Y, CuSparseMatrixCSC(Y2))
242-
end
243-
244-
LinearAlgebra.mul!(Y::CuSparseMatrixCOO{T,M}, A::Union{CuSparseMatrixCOO{T,M}, Transpose{T,<:CuSparseMatrixCOO}, Adjoint{T,<:CuSparseMatrixCOO}},
245-
B::Union{CuSparseMatrixCOO{T,M}, Transpose{T,<:CuSparseMatrixCOO}, Adjoint{T,<:CuSparseMatrixCOO}}) where {T,M} = mul!(Y, A, B, one(T), zero(T))
246-
LinearAlgebra.mul!(Y::CuSparseMatrixCSC{T,M}, A::Union{CuSparseMatrixCSC{T,M}, Transpose{T,<:CuSparseMatrixCSC}, Adjoint{T,<:CuSparseMatrixCSC}},
247-
B::Union{CuSparseMatrixCSC{T,M}, Transpose{T,<:CuSparseMatrixCSC}, Adjoint{T,<:CuSparseMatrixCSC}}) where {T,M} = mul!(Y, A, B, one(T), zero(T))
248172

249-
function LinearAlgebra.:(*)(A::CuSparseMatrixCSR{T,M}, B::CuSparseMatrixCSR{T,M}) where {T,M}
250-
CUSPARSE.version() < v"11.1.1" && throw(ErrorException("This operation is not
251-
supported by the current CUDA version."))
252-
gemm('N', 'N', one(T), A, B, 'O')
253-
end
254-
function LinearAlgebra.:(*)(A::CuSparseMatrixCSC{T,M}, B::CuSparseMatrixCSC{T,M}) where {T,M}
255-
A2 = CuSparseMatrixCSR(A)
256-
B2 = CuSparseMatrixCSR(B)
257-
CuSparseMatrixCSC(gemm('N', 'N', one(T), A2, B2, 'O'))
258-
end
259-
function LinearAlgebra.:(*)(A::CuSparseMatrixCOO{T,M}, B::CuSparseMatrixCOO{T,M}) where {T,M}
260-
A2 = CuSparseMatrixCSR(A)
261-
B2 = CuSparseMatrixCSR(B)
262-
CuSparseMatrixCOO(gemm('N', 'N', one(T), A2, B2, 'O'))
173+
for (wrapa, unwrapa) in adjtrans_wrappers, (wrapb, unwrapb) in adjtrans_wrappers
174+
for SparseMatrixType in (:(CuSparseMatrixCSC{T}), :(CuSparseMatrixCSR{T}), :(CuSparseMatrixCOO{T}))
175+
TypeA = wrapa(SparseMatrixType)
176+
TypeB = wrapb(SparseMatrixType)
177+
wrapa == identity && wrapb == identity && continue
178+
@eval begin
179+
LinearAlgebra.:(*)(A::$TypeA, B::$TypeB) where {T <: BlasFloat} = $(unwrapa(:A)) * $(unwrapb(:B))
180+
LinearAlgebra.mul!(C::$SparseMatrixType, A::$TypeA, B::$TypeB, alpha::Number, beta::Number) where {T <: BlasFloat} = mul!(C, $(unwrapa(:A)), $(unwrapb(:B)), alpha, beta)
181+
end
182+
end
263183
end
264184

265-
function SparseArrays.droptol!(A::CuSparseMatrixCOO{T,M}, tol::Real) where {T,M}
266-
mask = abs.(A.nzVal) .> tol
267-
rows = A.rowInd[mask]
268-
cols = A.colInd[mask]
269-
datas = A.nzVal[mask]
270-
B = sparse(rows, cols, datas, size(A)..., fmt = :coo)
271-
copyto!(A, B)
272-
end
185+
for op in (:(+), :(-))
186+
for (wrapa, unwrapa) in adjtrans_wrappers, (wrapb, unwrapb) in adjtrans_wrappers
187+
for SparseMatrixType in (:(CuSparseMatrixCSC{T}), :(CuSparseMatrixCSR{T}))
188+
TypeA = wrapa(SparseMatrixType)
189+
TypeB = wrapb(SparseMatrixType)
190+
@eval Base.$op(A::$TypeA, B::$TypeB) where {T <: BlasFloat} = geam(one(T), $(unwrapa(:A)), $(op)(one(T)), $(unwrapb(:B)), 'O')
191+
end
192+
end
273193

274-
for SparseMatrixType in [:CuSparseMatrixCSC, :CuSparseMatrixCSR, :CuSparseMatrixCOO]
275194
@eval begin
276-
if $SparseMatrixType in [CuSparseMatrixCSC, CuSparseMatrixCSR]
277-
278-
Base.reshape(A::$SparseMatrixType{T,M}, dims::NTuple{N,Int}) where {T,N,M} =
279-
$SparseMatrixType( reshape(CuSparseMatrixCOO(A), dims) )
195+
Base.$op(A::CuSparseVector{T}, B::CuSparseVector{T}) where {T <: BlasFloat} = axpby(one(T), A, $(op)(one(T)), B, 'O')
196+
Base.$op(A::Union{CuSparseMatrixCOO{T}, Transpose{T,<:CuSparseMatrixCOO}, Adjoint{T,<:CuSparseMatrixCOO}},
197+
B::Union{CuSparseMatrixCOO{T}, Transpose{T,<:CuSparseMatrixCOO}, Adjoint{T,<:CuSparseMatrixCOO}}) where {T <: BlasFloat} =
198+
CuSparseMatrixCOO($(op)(CuSparseMatrixCSR(A), CuSparseMatrixCSR(B)))
199+
end
280200

281-
function SparseArrays.droptol!(A::$SparseMatrixType{T,M}, tol::Real) where {T,M}
282-
B = copy(CuSparseMatrixCOO(A))
283-
droptol!(B, tol)
284-
copyto!(A, $SparseMatrixType(B))
201+
for (wrap1, unwrap1) in adjtrans_wrappers, (wrap2, unwrap2) in adjtrans_wrappers
202+
for SparseMatrixType in (:(CuSparseMatrixCSC{T}), :(CuSparseMatrixCOO{T}), :(CuSparseMatrixBSR{T}))
203+
Type1 = wrap1(:(CuSparseMatrixCSR{T}))
204+
Type2 = wrap2(SparseMatrixType)
205+
@eval begin
206+
Base.$op(A::$Type1, B::$Type2) where {T <: BlasFloat} = $(op)($(unwrap1(:A)), CuSparseMatrixCSR(B))
207+
Base.$op(A::$Type2, B::$Type1) where {T <: BlasFloat} = $(op)(CuSparseMatrixCSR(A), $(unwrap1(:B)))
285208
end
286-
287209
end
288210

289-
LinearAlgebra.:(*)(A::Transpose{T,<:$SparseMatrixType}, B::$SparseMatrixType{T,M}) where {T,M} = _sptranspose(parent(A)) * B
290-
LinearAlgebra.:(*)(A::Transpose{T,<:$SparseMatrixType}, B::Transpose{T,<:$SparseMatrixType}) where {T} = _sptranspose(parent(A)) * _sptranspose(parent(B))
291-
LinearAlgebra.:(*)(A::$SparseMatrixType{T,M}, B::Transpose{T,<:$SparseMatrixType}) where {T,M} = A * _sptranspose(parent(B))
292-
LinearAlgebra.:(*)(A::Adjoint{T,<:$SparseMatrixType}, B::$SparseMatrixType{T,M}) where {T,M} = _spadjoint(parent(A)) * B
293-
LinearAlgebra.:(*)(A::Adjoint{T,<:$SparseMatrixType}, B::Adjoint{T,<:$SparseMatrixType}) where {T} = _spadjoint(parent(A)) * _spadjoint(parent(B))
294-
LinearAlgebra.:(*)(A::$SparseMatrixType{T,M}, B::Adjoint{T,<:$SparseMatrixType}) where {T,M} = A * _spadjoint(parent(B))
211+
for SparseMatrixType in (:(CuSparseMatrixCOO{T}), :(CuSparseMatrixBSR{T}))
212+
Type1 = wrap1(:(CuSparseMatrixCSC{T}))
213+
Type2 = wrap2(SparseMatrixType)
214+
@eval begin
215+
Base.$op(A::$Type1, B::$Type2) where {T <: BlasFloat} = $(op)($(unwrap1(:A)), CuSparseMatrixCSC(B))
216+
Base.$op(A::$Type2, B::$Type1) where {T <: BlasFloat} = $(op)(CuSparseMatrixCSC(A), $(unwrap1(:B)))
217+
end
218+
end
295219
end
296220
end
297221

298-
function _spadjoint(A::CuSparseMatrixCSR{T,M}) where {T,M}
299-
cscA = CuSparseMatrixCSC(conj(A))
300-
CuSparseMatrixCSR(cscA.colPtr, cscA.rowVal, cscA.nzVal, reverse(size(cscA)))
301-
end
302-
function _sptranspose(A::CuSparseMatrixCSR{T,M}) where {T,M}
303-
cscA = CuSparseMatrixCSC(A)
304-
CuSparseMatrixCSR(cscA.colPtr, cscA.rowVal, cscA.nzVal, reverse(size(cscA)))
305-
end
306-
function _spadjoint(A::CuSparseMatrixCSC{T,M}) where {T,M}
307-
CuSparseMatrixCSC(CuSparseMatrixCSR(A.colPtr, A.rowVal, conj(A.nzVal), reverse(size(A))))
308-
end
309-
function _sptranspose(A::CuSparseMatrixCSC{T,M}) where {T,M}
310-
CuSparseMatrixCSC(CuSparseMatrixCSR(A.colPtr, A.rowVal, A.nzVal, reverse(size(A))))
311-
end
312-
function _spadjoint(A::CuSparseMatrixCOO{T,M}) where {T,M}
313-
sparse(A.colInd, A.rowInd, conj(A.nzVal), size(A)..., fmt = :coo)
314-
end
315-
function _sptranspose(A::CuSparseMatrixCOO{T,M}) where {T,M}
316-
sparse(A.colInd, A.rowInd, A.nzVal, size(A)..., fmt = :coo)
317-
end
318-
319-
320222
# triangular
321223
for SparseMatrixType in (:CuSparseMatrixBSR, :CuSparseMatrixCSC, :CuSparseMatrixCSR)
322224

0 commit comments

Comments
 (0)