|
89 | 89 | if VERSION ≥ v"1.10-"
|
90 | 90 | # multiplication
|
91 | 91 | LinearAlgebra.generic_trimatmul!(
|
92 |
| - c::ROCVector{T}, uploc, isunitc, tfun::Function, |
93 |
| - A::ROCMatrix{T}, b::AbstractVector{T}, |
| 92 | + c::StridedROCVector{T}, uploc, isunitc, tfun::Function, |
| 93 | + A::StridedROCMatrix{T}, b::StridedROCVector{T}, |
94 | 94 | ) where T <: ROCBLASFloat = trmv!(
|
95 | 95 | uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C',
|
96 | 96 | isunitc, A, c === b ? c : copyto!(c, b))
|
97 | 97 | # division
|
98 | 98 | LinearAlgebra.generic_trimatdiv!(
|
99 |
| - C::ROCVector{T}, uploc, isunitc, tfun::Function, |
100 |
| - A::ROCMatrix{T}, B::AbstractVector{T}, |
| 99 | + C::StridedROCVector{T}, uploc, isunitc, tfun::Function, |
| 100 | + A::StridedROCMatrix{T}, B::StridedROCVector{T}, |
101 | 101 | ) where T <: ROCBLASFloat = trsv!(
|
102 | 102 | uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C',
|
103 | 103 | isunitc, A, C === B ? C : copyto!(C, B))
|
@@ -410,3 +410,43 @@ else
|
410 | 410 | end
|
411 | 411 | end
|
412 | 412 | end
|
| 413 | + |
| 414 | +# Matrix inversion. |
| 415 | + |
| 416 | +for (t, uploc, isunitc) in ( |
| 417 | + (:LowerTriangular, 'U', 'N'), |
| 418 | + (:UnitLowerTriangular, 'U', 'U'), |
| 419 | + (:UpperTriangular, 'L', 'N'), |
| 420 | + (:UnitUpperTriangular, 'L', 'U'), |
| 421 | +) |
| 422 | + @eval function LinearAlgebra.inv(x::$t{T, <: ROCMatrix{T}}) where T <: ROCBLASFloat |
| 423 | + out = ROCArray{T}(I(size(x, 1))) |
| 424 | + $t(LinearAlgebra.ldiv!(x, out)) |
| 425 | + end |
| 426 | +end |
| 427 | + |
| 428 | +# Diagonal matrix. |
| 429 | + |
| 430 | +Base.Array(D::Diagonal{T, <: ROCArray{T}}) where T = Diagonal(Array(D.diag)) |
| 431 | + |
| 432 | +ROCArray(D::Diagonal{T, <: Vector{T}}) where T = Diagonal(ROCArray(D.diag)) |
| 433 | + |
| 434 | +function LinearAlgebra.inv(D::Diagonal{T, <: ROCArray{T}}) where T |
| 435 | + Di = map(inv, D.diag) |
| 436 | + any(isinf, Di) && error("Singular Exception $Di") |
| 437 | + Diagonal(Di) |
| 438 | +end |
| 439 | + |
| 440 | +function Base.:/(A::ROCArray, D::Diagonal) |
| 441 | + B = similar(A, typeof(oneunit(eltype(A)) / oneunit(eltype(D)))) |
| 442 | + _rdiv!(B, A, D) |
| 443 | +end |
| 444 | + |
| 445 | +function _rdiv!(B::ROCArray, A::ROCArray, D::Diagonal) |
| 446 | + m, n = size(A, 1), size(A, 2) |
| 447 | + (k = length(D.diag)) != n && throw(DimensionMismatch( |
| 448 | + "left hand side has $n columns but D is $k by $k")) |
| 449 | + |
| 450 | + B .= A * inv(D) |
| 451 | + return B |
| 452 | +end |
0 commit comments