Skip to content

Commit 07680f5

Browse files
committed
[rocSOLVER] Interface geblttrf_npvt and geblttrs_npvt
1 parent 260a295 commit 07680f5

File tree

2 files changed

+129
-0
lines changed

2 files changed

+129
-0
lines changed

src/solver/highlevel.jl

+55
Original file line number | Diff line number | Diff line change
@@ -212,6 +212,61 @@ for (fname, elty) in (
212212
end
213213
end
214214

215+
# Generate one `geblttrf!` method per supported element type, each forwarding to
# the matching rocSOLVER `*geblttrf_npvt` routine.
for (fname, elty) in (
    (:rocsolver_sgeblttrf_npvt, :Float32),
    (:rocsolver_dgeblttrf_npvt, :Float64),
    (:rocsolver_cgeblttrf_npvt, :ComplexF32),
    (:rocsolver_zgeblttrf_npvt, :ComplexF64),
)
    @eval begin
        # geblttrf!(A, B, C) -> (B, C)
        #
        # Call the rocSOLVER `geblttrf_npvt` routine on the three block stacks
        # `A`, `B` and `C` and return `(B, C)`.
        #
        # Arguments (all 3-dimensional `ROCArray`s of the same element type):
        # - `A`: blocks of size n x n, `nblocks - 1` of them.
        # - `B`: blocks of size n x n, `nblocks` of them.
        # - `C`: blocks of size n x n, `nblocks - 1` of them.
        #
        # Throws `DimensionMismatch` when the blocks are not all square of the
        # same order, or when the block counts are inconsistent.
        #
        # NOTE(review): presumably A/B/C are the sub-/main/super-diagonal blocks
        # of a block tridiagonal matrix and B, C are overwritten with the
        # factors -- confirm against the rocSOLVER reference.
        function geblttrf!(A::ROCArray{$elty,3}, B::ROCArray{$elty,3}, C::ROCArray{$elty,3})
            mA, nA, nblocksA = size(A)
            mB, nB, nblocksB = size(B)
            mC, nC, nblocksC = size(C)
            (mA == nA == mB == nB == mC == nC) || throw(DimensionMismatch("The first two dimensions of A, B and C must match"))
            (nblocksA == nblocksB - 1 == nblocksC) || throw(DimensionMismatch("Inconsistency for the last dimension of A, B and C"))

            # Leading dimensions; clamped to 1 so empty arrays still pass a valid value.
            lda = max(1, stride(A, 2))
            ldb = max(1, stride(B, 2))
            ldc = max(1, stride(C, 2))

            # One-element device buffer receiving the status code of the routine.
            devinfo = ROCArray{Cint}(undef, 1)
            info = try
                $fname(rocBLAS.handle(), mB, nblocksB, A, lda, B, ldb, C, ldc, devinfo)
                AMDGPU.@allowscalar devinfo[1]
            finally
                # Release the scratch buffer eagerly even if the call above throws.
                AMDGPU.unsafe_free!(devinfo)
            end
            chkargsok(BlasInt(info))
            return B, C
        end
    end
end
242+
243+
# Generate one `geblttrs!` method per supported element type, each forwarding to
# the matching rocSOLVER `*geblttrs_npvt` routine.
for (fname, elty) in (
    (:rocsolver_sgeblttrs_npvt, :Float32),
    (:rocsolver_dgeblttrs_npvt, :Float64),
    (:rocsolver_cgeblttrs_npvt, :ComplexF32),
    (:rocsolver_zgeblttrs_npvt, :ComplexF64),
)
    @eval begin
        # geblttrs!(A, B, C, X) -> X
        #
        # Call the rocSOLVER `geblttrs_npvt` routine with the block stacks
        # `A`, `B`, `C` and the right-hand-side array `X`, then return `X`.
        #
        # Expected sizes, as enforced by the checks below:
        # - `A` and `C`: n x n x (nblocks - 1)
        # - `B`:         n x n x nblocks
        # - `X`:         n x nblocks x nrhs
        #
        # Throws `DimensionMismatch` when any of those relations is violated.
        function geblttrs!(A::ROCArray{$elty,3}, B::ROCArray{$elty,3}, C::ROCArray{$elty,3}, X::ROCArray{$elty,3})
            rowsA, colsA, countA = size(A)
            order, colsB, nblocks = size(B)
            rowsC, colsC, countC = size(C)
            rowsX, countX, nrhs = size(X)

            # Every block of A, B and C has to be square with the same order.
            if !all(==(rowsA), (colsA, order, colsB, rowsC, colsC))
                throw(DimensionMismatch("The first two dimensions of A, B and C must match"))
            end
            if rowsX != rowsA
                throw(DimensionMismatch("The first dimension of X is inconsistent with first two dimensions of A, B and C"))
            end
            # A and C hold one block fewer than B; X holds as many block rows as B.
            if !(countA + 1 == nblocks == countX == countC + 1)
                throw(DimensionMismatch("Inconsistency for the number of blocks in A, B, C and X"))
            end

            # Leading dimensions, clamped to 1 so empty arrays still pass a valid value.
            ld_a = max(1, stride(A, 2))
            ld_b = max(1, stride(B, 2))
            ld_c = max(1, stride(C, 2))
            ld_x = max(1, stride(X, 2))

            $fname(
                rocBLAS.handle(), order, nblocks, nrhs,
                A, ld_a, B, ld_b, C, ld_c, X, ld_x,
            )
            return X
        end
    end
end
269+
215270
for (fname, elty, relty) in ((:rocsolver_sgebrd, :Float32 , :Float32),
216271
(:rocsolver_dgebrd, :Float64 , :Float64),
217272
(:rocsolver_cgebrd, :ComplexF32, :Float32),

test/rocarray/solver.jl

+74
Original file line number | Diff line number | Diff line change
@@ -200,6 +200,80 @@ end
200200
end
201201
end
202202

203+
# Round-trip test for the block tridiagonal factorization (`geblttrf!`) and
# solve (`geblttrs!`) wrappers: the factors read back from the device must
# reproduce the assembled matrix, and the device solve must match a host solve.
@testset "geblttrf! -- geblttrs!" begin
    @testset "elty = $elty" for elty in [Float32, Float64, ComplexF32, ComplexF64]
        @testset "n = $n" for n in (1, ) # 8, 16)
            nblocks = 5
            nrhs = 1
            p = n * nblocks
            A = rand(elty, n, n, nblocks-1)
            B = rand(elty, n, n, nblocks)
            C = rand(elty, n, n, nblocks-1)
            R = rand(elty, n, nblocks, nrhs)

            # Assemble the dense p x p block tridiagonal matrix M (A below the
            # diagonal, B on it, C above it) and the stacked right-hand side.
            M = zeros(elty, p, p)
            RHS = zeros(elty, p, nrhs)
            for k in 1:nblocks
                offset = (k-1)*n
                for i = 1:n
                    for j = 1:n
                        M[offset+i,offset+j] = B[i,j,k]
                        if k < nblocks
                            M[offset+n+i,offset+j] = A[i,j,k]
                            M[offset+i,offset+n+j] = C[i,j,k]
                        end
                    end
                    for j = 1:nrhs
                        RHS[offset+i,j] = R[i,k,j]
                    end
                end
            end

            d_A = ROCArray(A)
            d_B = ROCArray(B)
            d_C = ROCArray(C)
            d_R = ROCArray(R)
            rocSOLVER.geblttrf!(d_A, d_B, d_C)

            # Rebuild dense factors from the device output: L takes the updated
            # B blocks on its diagonal and the original A blocks below it; U has
            # a unit diagonal and the updated C blocks above it.
            # NOTE(review): the B2/C2 blocks are read transposed ([j,i,k]); with
            # n = 1 this cannot be told apart from [i,j,k] -- confirm the layout
            # before re-enabling the larger block sizes.
            L = zeros(elty, p, p)
            U = zeros(elty, p, p)
            B2 = collect(d_B)
            C2 = collect(d_C)
            for k in 1:nblocks
                offset = (k-1)*n
                for i = 1:n
                    for j = 1:n
                        if i == j
                            U[offset+i,offset+j] = one(elty)
                        end
                        L[offset+i,offset+j] = B2[j,i,k]
                        if k < nblocks
                            L[offset+n+i,offset+j] = A[i,j,k]
                            U[offset+i,offset+n+j] = C2[j,i,k]
                        end
                    end
                end
            end
            N = L * U
            @test N ≈ M

            # Host reference solution, reshaped to the n x nblocks x nrhs layout of R.
            X = N \ RHS
            Y = similar(R)
            for k in 1:nblocks
                for i = 1:n
                    for j = 1:nrhs
                        l = (k-1)*n + i
                        Y[i, k, j] = X[l,j]
                    end
                end
            end
            rocSOLVER.geblttrs!(d_A, d_B, d_C, d_R)
            @test Y ≈ collect(d_R)
        end
    end
end
276+
203277
@testset "gebrd!" begin
204278
@testset "elty = $elty" for elty in [Float32, Float64, ComplexF32, ComplexF64]
205279
A = rand(elty,m,n)

0 commit comments

Comments
 (0)