You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
for SparseMatrixType in (:CuSparseMatrixCSC, :CuSparseMatrixCSR)
133
-
@evalbegin
134
-
Base.$op(A::$SparseMatrixType{T}, B::$SparseMatrixType{T}) where {T <:BlasFloat} =geam(one(T), A, $(op)(one(T)), B, 'O')
135
-
136
-
Base.$op(A::$SparseMatrixType{T}, B::Adjoint{T,<:$SparseMatrixType}) where {T <:BlasFloat} =geam(one(T), A, $(op)(one(T)), _spadjoint(parent(B)), 'O')
137
-
Base.$op(A::Adjoint{T,<:$SparseMatrixType}, B::$SparseMatrixType{T}) where {T <:BlasFloat} =geam(one(T), _spadjoint(parent(A)), $(op)(one(T)), B, 'O')
138
-
Base.$op(A::Adjoint{T,<:$SparseMatrixType}, B::Adjoint{T,<:$SparseMatrixType}) where {T <:BlasFloat} =geam(one(T), _spadjoint(parent(A)), $(op)(one(T)), _spadjoint(parent(B)), 'O')
139
-
140
-
Base.$op(A::$SparseMatrixType{T}, B::Transpose{T,<:$SparseMatrixType}) where {T <:BlasFloat} =geam(one(T), A, $(op)(one(T)), _sptranspose(parent(B)), 'O')
141
-
Base.$op(A::Transpose{T,<:$SparseMatrixType}, B::$SparseMatrixType{T}) where {T <:BlasFloat} =geam(one(T), _sptranspose(parent(A)), $(op)(one(T)), B, 'O')
142
-
Base.$op(A::Transpose{T,<:$SparseMatrixType}, B::Transpose{T,<:$SparseMatrixType}) where {T <:BlasFloat} =geam(one(T), _sptranspose(parent(A)), $(op)(one(T)), _sptranspose(parent(B)), 'O')
143
-
end
144
-
end
145
-
146
-
for SparseMatrixType in (:CuSparseMatrixCSC, :CuSparseMatrixCOO, :CuSparseMatrixBSR)
147
-
@evalbegin
148
-
function Base.$op(A::CuSparseMatrixCSR{T}, B::$SparseMatrixType{T}) where {T}
149
-
csrB =CuSparseMatrixCSR(B)
150
-
returngeam(one(T), A, $(op)(one(T)), csrB, 'O')
151
-
end
152
-
function Base.$op(A::$SparseMatrixType{T}, B::CuSparseMatrixCSR{T}) where {T}
153
-
csrA =CuSparseMatrixCSR(A)
154
-
returngeam(one(T), csrA, $(op)(one(T)), B, 'O')
155
-
end
156
-
function Base.$op(A::Transpose{T, CuSparseMatrixCSR{T}}, B::$SparseMatrixType{T}) where {T}
function Base.$op(A::CuSparseMatrixCSR{T}, B::Transpose{T, $SparseMatrixType}) where {T}
174
-
csrB =CuSparseMatrixCSR(_sptranspose(parent(B)))
175
-
returngeam(one(T), A, $(op)(one(T)), csrB, 'O')
176
-
end
177
-
function Base.$op(A::Transpose{T, $SparseMatrixType}, B::CuSparseMatrixCSR{T}) where {T}
178
-
csrA =CuSparseMatrixCSR(_sptranspose(parent(A)))
179
-
returngeam(one(T), csrA, $(op)(one(T)), B, 'O')
180
-
end
181
149
182
-
function Base.$op(A::CuSparseMatrixCSR{T}, B::Adjoint{T, $SparseMatrixType}) where {T}
183
-
csrB =CuSparseMatrixCSR(_spadjoint(parent(B)))
184
-
returngeam(one(T), A, $(op)(one(T)), csrB, 'O')
185
-
end
186
-
function Base.$op(A::Adjoint{T, $SparseMatrixType}, B::CuSparseMatrixCSR{T}) where {T}
187
-
csrA =CuSparseMatrixCSR(_spadjoint(parent(A)))
188
-
returngeam(one(T), csrA, $(op)(one(T)), B, 'O')
189
-
end
150
+
function LinearAlgebra.mul!(C::$SparseMatrixType{T}, A::$SparseMatrixType{T}, B::$SparseMatrixType{T}, alpha::Number, beta::Number) where {T <:BlasFloat}
151
+
CUSPARSE.version() <v"11.1.1"&&throw(ErrorException("This operation is not supported by the current CUDA version."))
152
+
gemm!('N', 'N', alpha, A, B, beta, C, 'O')
190
153
end
191
154
end
192
155
end
193
156
194
-
195
-
function Base.reshape(A::CuSparseMatrixCOO{T,M}, dims::NTuple{N,Int}) where {T,N,M}
B::Adjoint{T,<:CuSparseMatrixCSR}) where {T,M} =mul!(Y, A, _spadjoint(parent(B)), one(t), zero(T))
224
-
225
-
function LinearAlgebra.mul!(Y::CuSparseMatrixCOO{T,M}, A::Union{CuSparseMatrixCOO{T,M}, Transpose{T,<:CuSparseMatrixCOO}, Adjoint{T,<:CuSparseMatrixCOO}},
226
-
B::Union{CuSparseMatrixCOO{T,M}, Transpose{T,<:CuSparseMatrixCOO}, Adjoint{T,<:CuSparseMatrixCOO}}, alpha::Number, beta::Number) where {T,M}
227
-
228
-
Y2 =CuSparseMatrixCSR(Y)
229
-
A2 =CuSparseMatrixCSR(A)
230
-
B2 =CuSparseMatrixCSR(B)
231
-
mul!(Y2, A2, B2, alpha, beta)
232
-
copyto!(Y, CuSparseMatrixCOO(Y2))
164
+
function LinearAlgebra.mul!(C::CuSparseMatrixCOO{T}, A::CuSparseMatrixCOO{T}, B::CuSparseMatrixCOO{T}, alpha::Number, beta::Number) where {T <:BlasFloat}
165
+
CUSPARSE.version() <v"11.1.1"&&throw(ErrorException("This operation is not supported by the current CUDA version."))
166
+
A_csr =CuSparseMatrixCSR(A)
167
+
B_csr =CuSparseMatrixCSR(B)
168
+
C_csr =CuSparseMatrixCSR(C)
169
+
mul!(C_csr, A_csr, B_csr, alpha, beta)
170
+
C =CuSparseMatrixCOO(C_csr)
233
171
end
234
-
function LinearAlgebra.mul!(Y::CuSparseMatrixCSC{T,M}, A::Union{CuSparseMatrixCSC{T,M}, Transpose{T,<:CuSparseMatrixCSC}, Adjoint{T,<:CuSparseMatrixCSC}},
235
-
B::Union{CuSparseMatrixCSC{T,M}, Transpose{T,<:CuSparseMatrixCSC}, Adjoint{T,<:CuSparseMatrixCSC}}, alpha::Number, beta::Number) where {T,M}
0 commit comments