@@ -212,6 +212,61 @@ for (fname, elty) in (
212
212
end
213
213
end
214
214
215
for (fname, elty) in (
    (:rocsolver_sgeblttrf_npvt, :Float32),
    (:rocsolver_dgeblttrf_npvt, :Float64),
    (:rocsolver_cgeblttrf_npvt, :ComplexF32),
    (:rocsolver_zgeblttrf_npvt, :ComplexF64),
)
    @eval begin
        # geblttrf!(A, B, C) -> (B, C)
        #
        # Wrapper around the rocSOLVER `*geblttrf_npvt` routine selected by the
        # element type (per the rocSOLVER documentation: LU factorization of a
        # block tridiagonal matrix without pivoting).
        #
        # Arguments (all 3-dimensional `ROCArray`s with the same element type):
        # - `A`: size `(nb, nb, nblocks - 1)`
        # - `B`: size `(nb, nb, nblocks)`
        # - `C`: size `(nb, nb, nblocks - 1)`
        # NOTE(review): presumably A/B/C hold the sub-diagonal / diagonal /
        # super-diagonal blocks respectively -- confirm against the rocSOLVER docs.
        #
        # Returns the tuple `(B, C)`.
        #
        # Throws `DimensionMismatch` when the blocks are not square and of equal
        # size, or when the numbers of blocks are inconsistent. The status code
        # written by rocSOLVER is forwarded to `chkargsok`.
        function geblttrf!(
            A::ROCArray{$elty,3}, B::ROCArray{$elty,3}, C::ROCArray{$elty,3}
        )
            mA, nA, nblocksA = size(A)
            mB, nB, nblocksB = size(B)
            mC, nC, nblocksC = size(C)
            (mA == nA == mB == nB == mC == nC) ||
                throw(DimensionMismatch("The first two dimensions of A, B and C must match"))
            (nblocksA == nblocksB - 1 == nblocksC) ||
                throw(DimensionMismatch("Inconsistency for the last dimension of A, B and C"))

            # Leading dimensions: column stride of each array, at least 1.
            lda = max(1, stride(A, 2))
            ldb = max(1, stride(B, 2))
            ldc = max(1, stride(C, 2))

            # One-element device buffer receiving the status code.
            devinfo = ROCArray{Cint}(undef, 1)
            info = try
                $fname(rocBLAS.handle(), mB, nblocksB, A, lda, B, ldb, C, ldc, devinfo)
                AMDGPU.@allowscalar devinfo[1]
            finally
                # Release the temporary buffer even if the call above throws.
                AMDGPU.unsafe_free!(devinfo)
            end
            chkargsok(BlasInt(info))
            return B, C
        end
    end
end
242
+
243
for (fname, elty) in (
    (:rocsolver_sgeblttrs_npvt, :Float32),
    (:rocsolver_dgeblttrs_npvt, :Float64),
    (:rocsolver_cgeblttrs_npvt, :ComplexF32),
    (:rocsolver_zgeblttrs_npvt, :ComplexF64),
)
    @eval begin
        # geblttrs!(A, B, C, X) -> X
        #
        # Wrapper around the rocSOLVER `*geblttrs_npvt` routine selected by the
        # element type (per the rocSOLVER documentation: solves a block
        # tridiagonal system from a factorization computed without pivoting).
        #
        # Arguments (all 3-dimensional `ROCArray`s with the same element type):
        # - `A`: size `(nb, nb, nblocks - 1)`
        # - `B`: size `(nb, nb, nblocks)`
        # - `C`: size `(nb, nb, nblocks - 1)`
        # - `X`: size `(nb, nblocks, nrhs)`
        # NOTE(review): presumably A, B and C are the output of `geblttrf!` and
        # X holds the right-hand sides on entry / the solution on exit -- confirm
        # against the rocSOLVER docs.
        #
        # Returns `X`.
        #
        # Throws `DimensionMismatch` when the block sizes or block counts of the
        # four arrays are inconsistent.
        function geblttrs!(
            A::ROCArray{$elty,3},
            B::ROCArray{$elty,3},
            C::ROCArray{$elty,3},
            X::ROCArray{$elty,3},
        )
            mA, nA, nblocksA = size(A)
            mB, nB, nblocksB = size(B)
            mC, nC, nblocksC = size(C)
            mX, nblocksX, nrhs = size(X)
            (mA == nA == mB == nB == mC == nC) ||
                throw(DimensionMismatch("The first two dimensions of A, B and C must match"))
            (mX == mA) ||
                throw(DimensionMismatch("The first dimension of X is inconsistent with first two dimensions of A, B and C"))
            (nblocksA == nblocksB - 1 == nblocksX - 1 == nblocksC) ||
                throw(DimensionMismatch("Inconsistency for the number of blocks in A, B, C and X"))

            # Leading dimensions: stride along the second axis, at least 1.
            lda = max(1, stride(A, 2))
            ldb = max(1, stride(B, 2))
            ldc = max(1, stride(C, 2))
            ldx = max(1, stride(X, 2))

            $fname(rocBLAS.handle(), mB, nblocksB, nrhs, A, lda, B, ldb, C, ldc, X, ldx)
            return X
        end
    end
end
269
+
215
270
for (fname, elty, relty) in ((:rocsolver_sgebrd , :Float32 , :Float32 ),
216
271
(:rocsolver_dgebrd , :Float64 , :Float64 ),
217
272
(:rocsolver_cgebrd , :ComplexF32 , :Float32 ),
0 commit comments