Skip to content

Commit 70ac953

Browse files
authored
Fix blas tests (#627)
1 parent d3bf4c1 commit 70ac953

File tree

2 files changed

+52
-12
lines changed

2 files changed

+52
-12
lines changed

src/blas/highlevel.jl

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,15 +89,15 @@ end
8989
if VERSION v"1.10-"
9090
# multiplication
9191
LinearAlgebra.generic_trimatmul!(
92-
c::ROCVector{T}, uploc, isunitc, tfun::Function,
93-
A::ROCMatrix{T}, b::AbstractVector{T},
92+
c::StridedROCVector{T}, uploc, isunitc, tfun::Function,
93+
A::StridedROCMatrix{T}, b::StridedROCVector{T},
9494
) where T <: ROCBLASFloat = trmv!(
9595
uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C',
9696
isunitc, A, c === b ? c : copyto!(c, b))
9797
# division
9898
LinearAlgebra.generic_trimatdiv!(
99-
C::ROCVector{T}, uploc, isunitc, tfun::Function,
100-
A::ROCMatrix{T}, B::AbstractVector{T},
99+
C::StridedROCVector{T}, uploc, isunitc, tfun::Function,
100+
A::StridedROCMatrix{T}, B::StridedROCVector{T},
101101
) where T <: ROCBLASFloat = trsv!(
102102
uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C',
103103
isunitc, A, C === B ? C : copyto!(C, B))
@@ -410,3 +410,43 @@ else
410410
end
411411
end
412412
end
413+
414+
# Matrix inversion.
415+
416+
for (t, uploc, isunitc) in (
417+
(:LowerTriangular, 'U', 'N'),
418+
(:UnitLowerTriangular, 'U', 'U'),
419+
(:UpperTriangular, 'L', 'N'),
420+
(:UnitUpperTriangular, 'L', 'U'),
421+
)
422+
@eval function LinearAlgebra.inv(x::$t{T, <: ROCMatrix{T}}) where T <: ROCBLASFloat
423+
out = ROCArray{T}(I(size(x, 1)))
424+
$t(LinearAlgebra.ldiv!(x, out))
425+
end
426+
end
427+
428+
# Diagonal matrix.
429+
430+
Base.Array(D::Diagonal{T, <: ROCArray{T}}) where T = Diagonal(Array(D.diag))
431+
432+
ROCArray(D::Diagonal{T, <: Vector{T}}) where T = Diagonal(ROCArray(D.diag))
433+
434+
function LinearAlgebra.inv(D::Diagonal{T, <: ROCArray{T}}) where T
435+
Di = map(inv, D.diag)
436+
any(isinf, Di) && error("Singular Exception $Di")
437+
Diagonal(Di)
438+
end
439+
440+
function Base.:/(A::ROCArray, D::Diagonal)
441+
B = similar(A, typeof(oneunit(eltype(A)) / oneunit(eltype(D))))
442+
_rdiv!(B, A, D)
443+
end
444+
445+
function _rdiv!(B::ROCArray, A::ROCArray, D::Diagonal)
446+
m, n = size(A, 1), size(A, 2)
447+
(k = length(D.diag)) != n && throw(DimensionMismatch(
448+
"left hand side has $n columns but D is $k by $k"))
449+
450+
B .= A * inv(D)
451+
return B
452+
end

test/rocarray/blas.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@ m = 20
88
n = 35
99
k = 13
1010

11-
handle = rocBLAS.handle()
12-
1311
@testset "Build Information" begin
1412
ver = rocBLAS.version()
1513
@test ver isa VersionNumber
@@ -103,13 +101,13 @@ end
103101

104102
A = rand(T, m, m)
105103
x = rand(T, m)
106-
@testset "Triangular mul/lmul!" for TR in (
104+
@testset "Triangular mul/lmul!" for TR in (
107105
UpperTriangular, LowerTriangular,
108106
), f in (
109107
identity, adjoint, transpose,
110108
)
111-
@test testf((a, b) -> f(TR(A)) * x, A, x)
112-
@test testf((a, b) -> lmul!(f(TR(A)), b), A, copy(x))
109+
@test testf((a, b) -> f(TR(a)) * b, A, x)
110+
@test testf((a, b) -> lmul!(f(TR(a)), b), A, copy(x))
113111
end
114112

115113
A, x = rand(T, m, m), rand(T, m)
@@ -118,15 +116,17 @@ end
118116
), f in (
119117
identity, adjoint, transpose,
120118
)
121-
@test testf((a, b) -> f(TR(A)) \ x, A, x)
122-
@test testf((a, b) -> ldiv!(f(TR(A)), b), A, copy(x))
119+
@test testf((a, b) -> f(TR(a)) \ b, A, x)
120+
@test testf((a, b) -> ldiv!(f(TR(a)), b), A, copy(x))
123121
end
124122

123+
x = rand(T, m, m)
125124
@testset "inv($TR)" for TR in (
126125
UpperTriangular, LowerTriangular,
127126
UnitUpperTriangular, UnitLowerTriangular,
127+
Diagonal,
128128
)
129-
@test testf(x -> inv(TR(x)), rand(T, m, m))
129+
@test testf(a -> inv(TR(a)), x)
130130
end
131131
end
132132
end

0 commit comments

Comments
 (0)