Skip to content

Commit 4a5ec1f

Browse files
Reverting changes to mul, keeping triu and tril, adding support for QRQ getproperty and printarray
1 parent dbc77f5 commit 4a5ec1f

File tree

1 file changed

+51
-17
lines changed

1 file changed

+51
-17
lines changed

src/host/linalg.jl

Lines changed: 51 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -206,9 +206,9 @@ if VERSION < v"1.8-"
206206
return B
207207
end
208208
else
209-
function LinearAlgebra.mul!(B::AnyGPUVecOrMat,
210-
D::Diagonal{<:Any, <:AnyGPUArray},
211-
A::AnyGPUVecOrMat)
209+
function LinearAlgebra.mul!(B::AbstractGPUVecOrMat,
210+
D::Diagonal{<:Any, <:AbstractGPUArray},
211+
A::AbstractGPUVecOrMat)
212212
dd = D.diag
213213
d = length(dd)
214214
m, n = size(A, 1), size(A, 2)
@@ -220,9 +220,9 @@ else
220220
B
221221
end
222222

223-
function LinearAlgebra.mul!(B::AnyGPUVecOrMat,
224-
D::Diagonal{<:Any, <:AnyGPUArray},
225-
A::AnyGPUVecOrMat,
223+
function LinearAlgebra.mul!(B::AbstractGPUVecOrMat,
224+
D::Diagonal{<:Any, <:AbstractGPUArray},
225+
A::AbstractGPUVecOrMat,
226226
α::Number,
227227
β::Number)
228228
dd = D.diag
@@ -236,9 +236,9 @@ else
236236
B
237237
end
238238

239-
function LinearAlgebra.mul!(B::AnyGPUVecOrMat,
240-
A::AnyGPUVecOrMat,
241-
D::Diagonal{<:Any, <:AnyGPUArray})
239+
function LinearAlgebra.mul!(B::AbstractGPUVecOrMat,
240+
A::AbstractGPUVecOrMat,
241+
D::Diagonal{<:Any, <:AbstractGPUArray})
242242
dd = D.diag
243243
d = length(dd)
244244
m, n = size(A, 1), size(A, 2)
@@ -250,9 +250,9 @@ else
250250
B
251251
end
252252

253-
function LinearAlgebra.mul!(B::AnyGPUVecOrMat,
254-
A::AnyGPUVecOrMat,
255-
D::Diagonal{<:Any, <:AnyGPUArray},
253+
function LinearAlgebra.mul!(B::AbstractGPUVecOrMat,
254+
A::AbstractGPUVecOrMat,
255+
D::Diagonal{<:Any, <:AbstractGPUArray},
256256
α::Number,
257257
β::Number)
258258
dd = D.diag
@@ -266,9 +266,9 @@ else
266266
B
267267
end
268268

269-
function LinearAlgebra.ldiv!(B::AnyGPUVecOrMat,
270-
D::Diagonal{<:Any, <:AnyGPUArray},
271-
A::AnyGPUVecOrMat)
269+
function LinearAlgebra.ldiv!(B::AbstractGPUVecOrMat,
270+
D::Diagonal{<:Any, <:AbstractGPUArray},
271+
A::AbstractGPUVecOrMat)
272272
dd = D.diag
273273
d = length(dd)
274274
m, n = size(A, 1), size(A, 2)
@@ -289,7 +289,7 @@ end
289289

290290
## matrix multiplication
291291

292-
function generic_matmatmul!(C::AnyArray{R}, A::AnyArray{T}, B::AnyArray{S}, a::Number, b::Number) where {T,S,R}
292+
function generic_matmatmul!(C::AbstractArray{R}, A::AbstractArray{T}, B::AbstractArray{S}, a::Number, b::Number) where {T,S,R}
293293
if size(A,2) != size(B,1)
294294
throw(DimensionMismatch("matrix A has dimensions $(size(A)), matrix B has dimensions $(size(B))"))
295295
end
@@ -319,7 +319,26 @@ function generic_matmatmul!(C::AnyArray{R}, A::AnyArray{T}, B::AnyArray{S}, a::N
319319
C
320320
end
321321

322-
LinearAlgebra.mul!(C::AnyGPUVecOrMat, A::AnyGPUVecOrMat, B::AnyGPUVecOrMat, a::Number, b::Number) = generic_matmatmul!(C, A, B, a, b)
322+
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::AbstractGPUVecOrMat, B::AbstractGPUVecOrMat, a::Number, b::Number) = generic_matmatmul!(C, A, B, a, b)
323+
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::AbstractGPUVecOrMat, B::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, a::Number, b::Number) = generic_matmatmul!(C, A, B, a, b)
324+
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::AbstractGPUVecOrMat, B::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, a::Number, b::Number) = generic_matmatmul!(C, A, B, a, b)
325+
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, B::AbstractGPUVecOrMat, a::Number, b::Number) = generic_matmatmul!(C, A, B, a, b)
326+
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, B::AbstractGPUVecOrMat, a::Number, b::Number) = generic_matmatmul!(C, A, B, a, b)
327+
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, B::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, a::Number, b::Number) = generic_matmatmul!(C, A, B, a, b)
328+
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, B::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, a::Number, b::Number) = generic_matmatmul!(C, A, B, a, b)
329+
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, B::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, a::Number, b::Number) = generic_matmatmul!(C, A, B, a, b)
330+
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, B::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, a::Number, b::Number) = generic_matmatmul!(C, A, B, a, b)
331+
332+
# specificity hacks
333+
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::AbstractGPUVecOrMat, B::AbstractGPUVecOrMat, a::Real, b::Real) = generic_matmatmul!(C, A, B, a, b)
334+
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::AbstractGPUVecOrMat, B::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, a::Real, b::Real) = generic_matmatmul!(C, A, B, a, b)
335+
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::AbstractGPUVecOrMat, B::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, a::Real, b::Real) = generic_matmatmul!(C, A, B, a, b)
336+
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, B::AbstractGPUVecOrMat, a::Real, b::Real) = generic_matmatmul!(C, A, B, a, b)
337+
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, B::AbstractGPUVecOrMat, a::Real, b::Real) = generic_matmatmul!(C, A, B, a, b)
338+
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, B::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, a::Real, b::Real) = generic_matmatmul!(C, A, B, a, b)
339+
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, B::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, a::Real, b::Real) = generic_matmatmul!(C, A, B, a, b)
340+
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, B::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, a::Real, b::Real) = generic_matmatmul!(C, A, B, a, b)
341+
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, B::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, a::Real, b::Real) = generic_matmatmul!(C, A, B, a, b)
323342

324343

325344
function generic_rmul!(X::AbstractArray, s::Number)
@@ -512,4 +531,19 @@ function Base.isone(x::AbstractGPUMatrix{T}) where {T}
512531
Array(y)[]
513532
end
514533

534+
#getproperty for QR
535+
import LinearAlgebra:QRPackedQ
515536

537+
function LinearAlgebra.getproperty(F::QR{T,<:AnyGPUMatrix{T},<:AnyGPUVector{T}}, d::Symbol) where {T}
538+
m, n = size(F)
539+
if d === :R
540+
return triu!(view(getfield(F, :factors),1:min(m,n), 1:n))
541+
elseif d === :Q
542+
return LinearAlgebra.QRPackedQ(getfield(F, :factors), F.τ)
543+
else
544+
getfield(F, d)
545+
end
546+
end
547+
548+
Base.print_array(io::IO, Q::QRPackedQ) =
549+
Base.print_array(io, collect(adapt(ToArray(), Q)))

0 commit comments

Comments
 (0)