338
338
339
339
340
340
# # matrix multiplication
341
-
342
- function generic_matmatmul! (C:: AbstractArray{R} , A:: AbstractArray{T} , B:: AbstractArray{S} , a:: Number , b:: Number ) where {T,S,R}
341
+ # legacy method
342
+ generic_matmatmul! (C:: AbstractArray , A:: AbstractArray , B:: AbstractArray , a:: Number , b:: Number ) =
343
+ generic_matmatmul! (C, A, B, MulAddMul (a, b))
344
+ function generic_matmatmul! (C:: AbstractArray{R} , A:: AbstractArray{T} , B:: AbstractArray{S} , add:: MulAddMul ) where {T,S,R}
343
345
if size (A,2 ) != size (B,1 )
344
346
throw (DimensionMismatch (" matrix A has dimensions $(size (A)) , matrix B has dimensions $(size (B)) " ))
345
347
end
@@ -350,20 +352,18 @@ function generic_matmatmul!(C::AbstractArray{R}, A::AbstractArray{T}, B::Abstrac
350
352
return fill! (C, zero (R))
351
353
end
352
354
353
- add = MulAddMul (a, b)
354
-
355
355
gpu_call (C, A, B; name= " matmatmul!" ) do ctx, C, A, B
356
356
idx = @linearidx C
357
357
assume .(size (C) .> 0 )
358
358
i, j = @inbounds Tuple (CartesianIndices (C)[idx])... , 1
359
359
360
360
@inbounds if i <= size (A,1 ) && j <= size (B,2 )
361
361
z2 = zero (A[i, 1 ]* B[1 , j] + A[i, 1 ]* B[1 , j])
362
- Ctmp = convert (promote_type (R, typeof (z2)), z2)
362
+ Cij = convert (promote_type (R, typeof (z2)), z2)
363
363
for k in 1 : size (A,2 )
364
- Ctmp += A[i, k]* B[k, j]
364
+ Cij += A[i, k]* B[k, j]
365
365
end
366
- C[i,j] = add (Ctmp , C[i,j])
366
+ C[i,j] = add (Cij , C[i,j])
367
367
end
368
368
369
369
return
@@ -372,42 +372,229 @@ function generic_matmatmul!(C::AbstractArray{R}, A::AbstractArray{T}, B::Abstrac
372
372
C
373
373
end
374
374
375
+ @static if VERSION < v " 1.12.0-"
375
376
function LinearAlgebra. generic_matvecmul! (C:: AbstractGPUVector , tA:: AbstractChar , A:: AbstractGPUMatrix , B:: AbstractGPUVector , _add:: MulAddMul = MulAddMul ())
376
- generic_matmatmul! (C, wrap (A, tA), B, _add. alpha, _add . beta )
377
+ generic_matmatmul! (C, wrap (A, tA), B, _add)
377
378
end
378
379
379
380
function LinearAlgebra. generic_matmatmul! (C:: AbstractGPUVecOrMat , tA, tB, A:: AbstractGPUVecOrMat , B:: AbstractGPUVecOrMat , _add:: MulAddMul = MulAddMul ())
380
- generic_matmatmul! (C, wrap (A, tA), wrap (B, tB), _add. alpha, _add. beta)
381
+ generic_matmatmul! (C, wrap (A, tA), wrap (B, tB), _add)
382
+ end
383
+ else
384
+ function LinearAlgebra. generic_matvecmul! (C:: AbstractGPUVector , tA:: AbstractChar , A:: AbstractGPUMatrix , B:: AbstractGPUVector , a:: Number , b:: Number )
385
+ LinearAlgebra. @stable_muladdmul generic_matmatmul! (C, wrap (A, tA), B, MulAddMul (a, b))
386
+ end
387
+
388
+ function LinearAlgebra. generic_matmatmul! (C:: AbstractGPUVecOrMat , tA, tB, A:: AbstractGPUVecOrMat , B:: AbstractGPUVecOrMat , a:: Number , b:: Number )
389
+ LinearAlgebra. @stable_muladdmul generic_matmatmul! (C, wrap (A, tA), wrap (B, tB), MulAddMul (a, b))
390
+ end
391
+ end
392
+
393
+ function generic_trimatmul! (C:: AbstractGPUVecOrMat{R} , uploc, isunitc, tfun:: Function , A:: AbstractGPUMatrix{T} , B:: AbstractGPUVecOrMat{S} ) where {T,S,R}
394
+ if size (A,2 ) != size (B,1 )
395
+ throw (DimensionMismatch (lazy " matrix A has dimensions $(size(A)), matrix B has dimensions $(size(B))" ))
396
+ end
397
+ if size (C,1 ) != size (A,1 ) || size (C,2 ) != size (B,2 )
398
+ throw (DimensionMismatch (lazy " result C has dimensions $(size(C)), needs $((size(A,1),size(B,2)))" ))
399
+ end
400
+ if isempty (A) || isempty (B)
401
+ return fill! (C, zero (R))
402
+ end
403
+
404
+ upper = tfun === identity ? uploc == ' U' : uploc != ' U'
405
+ unit = isunitc == ' U'
406
+
407
+ function trimatmul (ctx, C, A, B)
408
+ idx = @linearidx C
409
+ assume .(size (C) .> 0 )
410
+ i, j = @inbounds Tuple (CartesianIndices (C)[idx])... , 1
411
+ l, m, n = size (A, 1 ), size (B, 1 ), size (B, 2 )
412
+
413
+ @inbounds if i <= l && j <= n
414
+ z2 = zero (A[i,1 ] * B[1 ,j] + A[i,1 ] * B[1 ,j])
415
+ Cij = convert (promote_type (R, typeof (z2)), z2)
416
+ Cij += (unit ? one (Cij) : A[i,i]) * B[i,j]
417
+ for k in (upper ? (i + 1 ) : 1 ): (upper ? m : (i - 1 ))
418
+ Cij += A[i,k] * B[k,j]
419
+ end
420
+ C[i,j] += Cij
421
+ end
422
+
423
+ return
424
+ end
425
+
426
+ function trimatmul_t (ctx, C, A, B)
427
+ idx = @linearidx C
428
+ assume .(size (C) .> 0 )
429
+ i, j = @inbounds Tuple (CartesianIndices (C)[idx])... , 1
430
+ l, m, n = size (A, 1 ), size (B, 1 ), size (B, 2 )
431
+
432
+ @inbounds if i <= l && j <= n
433
+ z2 = zero (A[i,1 ] * B[1 ,j] + A[i,1 ] * B[1 ,j])
434
+ Cij = convert (promote_type (R, typeof (z2)), z2)
435
+ Cij += (unit ? one (Cij) : transpose (A[i,i])) * B[i,j]
436
+ for k in (upper ? (i + 1 ) : 1 ): (upper ? m : (i - 1 ))
437
+ Cij += transpose (A[k,i]) * B[k,j]
438
+ end
439
+ C[i,j] += Cij
440
+ end
441
+
442
+ return
443
+ end
444
+
445
+ function trimatmul_a (ctx, C, A, B)
446
+ idx = @linearidx C
447
+ assume .(size (C) .> 0 )
448
+ i, j = @inbounds Tuple (CartesianIndices (C)[idx])... , 1
449
+ l, m, n = size (A, 1 ), size (B, 1 ), size (B, 2 )
450
+
451
+ @inbounds if i <= l && j <= n
452
+ z2 = zero (A[i,1 ] * B[1 ,j] + A[i,1 ] * B[1 ,j])
453
+ Cij = convert (promote_type (R, typeof (z2)), z2)
454
+ Cij += (unit ? one (Cij) : adjoint (A[i,i])) * B[i,j]
455
+ for k in (upper ? (i + 1 ) : 1 ): (upper ? m : (i - 1 ))
456
+ Cij += adjoint (A[k,i]) * B[k,j]
457
+ end
458
+ C[i,j] += Cij
459
+ end
460
+
461
+ return
462
+ end
463
+
464
+ if tfun === identity
465
+ gpu_call (trimatmul, C, A, B; name= " trimatmul" )
466
+ elseif tfun == transpose
467
+ gpu_call (trimatmul_t, C, A, B; name= " trimatmul_t" )
468
+ elseif tfun === adjoint
469
+ gpu_call (trimatmul_a, C, A, B; name= " trimatmul_a" )
470
+ else
471
+ error (" Not supported" )
472
+ end
473
+
474
+ C
475
+ end
476
+
477
+ function generic_mattrimul! (C:: AbstractGPUVecOrMat{R} , uploc, isunitc, tfun:: Function , A:: AbstractGPUMatrix{T} , B:: AbstractGPUVecOrMat{S} ) where {T,S,R}
478
+ if size (A,2 ) != size (B,1 )
479
+ throw (DimensionMismatch (lazy " matrix A has dimensions $(size(A)), matrix B has dimensions $(size(B))" ))
480
+ end
481
+ if size (C,1 ) != size (A,1 ) || size (C,2 ) != size (B,2 )
482
+ throw (DimensionMismatch (lazy " result C has dimensions $(size(C)), needs $((size(A,1),size(B,2)))" ))
483
+ end
484
+ if isempty (A) || isempty (B)
485
+ return fill! (C, zero (R))
486
+ end
487
+
488
+ upper = tfun === identity ? uploc == ' U' : uploc != ' U'
489
+ unit = isunitc == ' U'
490
+
491
+ function mattrimul (ctx, C, A, B)
492
+ idx = @linearidx C
493
+ assume .(size (C) .> 0 )
494
+ i, j = @inbounds Tuple (CartesianIndices (C)[idx])... , 1
495
+ l, m, n = size (A, 1 ), size (B, 1 ), size (B, 2 )
496
+
497
+ @inbounds if i <= l && j <= n
498
+ z2 = zero (A[i,1 ] * B[1 ,j] + A[i,1 ] * B[1 ,j])
499
+ Cij = convert (promote_type (R, typeof (z2)), z2)
500
+ Cij += A[i,j] * (unit ? one (Cij) : B[j,j])
501
+ for k in (upper ? 1 : (j + 1 )): (upper ? (j - 1 ) : m)
502
+ Cij += A[i,k] * B[k,j]
503
+ end
504
+ C[i,j] += Cij
505
+ end
506
+
507
+ return
508
+ end
509
+
510
+ function mattrimul_t (ctx, C, A, B)
511
+ idx = @linearidx C
512
+ assume .(size (C) .> 0 )
513
+ i, j = @inbounds Tuple (CartesianIndices (C)[idx])... , 1
514
+ l, m, n = size (A, 1 ), size (B, 1 ), size (B, 2 )
515
+
516
+ @inbounds if i <= l && j <= n
517
+ z2 = zero (A[i,1 ] * B[1 ,j] + A[i,1 ] * B[1 ,j])
518
+ Cij = convert (promote_type (R, typeof (z2)), z2)
519
+ Cij += A[i,j] * (unit ? one (Cij) : transpose (B[j,j]))
520
+ for k in (upper ? 1 : (j + 1 ) ): (upper ? (j - 1 ) : m)
521
+ Cij += A[i,k] * transpose (B[j,k])
522
+ end
523
+ C[i,j] += Cij
524
+ end
525
+
526
+ return
527
+ end
528
+
529
+ function mattrimul_a (ctx, C, A, B)
530
+ idx = @linearidx C
531
+ assume .(size (C) .> 0 )
532
+ i, j = @inbounds Tuple (CartesianIndices (C)[idx])... , 1
533
+ l, m, n = size (A, 1 ), size (B, 1 ), size (B, 2 )
534
+
535
+ @inbounds if i <= l && j <= n
536
+ z2 = zero (A[i,1 ] * B[1 ,j] + A[i,1 ] * B[1 ,j])
537
+ Cij = convert (promote_type (R, typeof (z2)), z2)
538
+ Cij += A[i,j] * (unit ? one (Cij) : adjoint (B[j,j]))
539
+ for k in (upper ? 1 : (j + 1 )): (upper ? (j - 1 ) : m)
540
+ Cij += A[i,k] * adjoint (B[j,k])
541
+ end
542
+ C[i,j] += Cij
543
+ end
544
+
545
+ return
546
+ end
547
+
548
+ if tfun === identity
549
+ gpu_call (mattrimul, C, A, B; name= " mattrimul" )
550
+ elseif tfun == transpose
551
+ gpu_call (mattrimul_t, C, A, B; name= " mattrimul_t" )
552
+ elseif tfun === adjoint
553
+ gpu_call (mattrimul_a, C, A, B; name= " mattrimul_a" )
554
+ else
555
+ error (" Not supported" )
556
+ end
557
+
558
+ C
559
+ end
560
+
561
+ if VERSION >= v " 1.10-"
562
+ function LinearAlgebra. generic_trimatmul! (C:: AbstractGPUVecOrMat , uploc, isunitc, tfun:: Function , A:: AbstractGPUMatrix , B:: AbstractGPUVecOrMat )
563
+ generic_trimatmul! (C, uploc, isunitc, tfun, A, B)
564
+ end
565
+ function LinearAlgebra. generic_mattrimul! (C:: AbstractGPUMatrix , uploc, isunitc, tfun:: Function , A:: AbstractGPUMatrix , B:: AbstractGPUMatrix )
566
+ generic_mattrimul! (C, uploc, isunitc, tfun, A, B)
567
+ end
381
568
end
382
569
383
570
if VERSION < v " 1.10.0-DEV.1365"
384
571
# catch other functions that are called by LinearAlgebra's mul!
385
572
function LinearAlgebra. gemv! (C:: AbstractGPUVector , tA:: AbstractChar , A:: AbstractGPUMatrix , B:: AbstractGPUVector , a:: Number , b:: Number )
386
- generic_matmatmul! (C, wrap (A, tA), B, a, b)
573
+ generic_matmatmul! (C, wrap (A, tA), B, MulAddMul ( a, b) )
387
574
end
388
575
# disambiguation
389
576
function LinearAlgebra. gemv! (C:: AbstractGPUVector{T} , tA:: AbstractChar , A:: AbstractGPUMatrix{T} , B:: AbstractGPUVector{T} , a:: Number , b:: Number ) where {T<: LinearAlgebra.BlasFloat }
390
- generic_matmatmul! (C, wrap (A, tA), B, a, b)
577
+ generic_matmatmul! (C, wrap (A, tA), B, MulAddMul ( a, b) )
391
578
end
392
579
393
580
LinearAlgebra. gemm_wrapper! (C:: AbstractGPUVecOrMat , tA:: AbstractChar , tB:: AbstractChar , A:: AbstractGPUVecOrMat , B:: AbstractGPUVecOrMat , _add:: MulAddMul ) =
394
- LinearAlgebra . generic_matmatmul! (C, tA, tB, A, B , _add)
581
+ generic_matmatmul! (C, wrap (A, tA), wrap (B, tB) , _add)
395
582
# disambiguation
396
583
LinearAlgebra. gemm_wrapper! (C:: AbstractGPUVecOrMat{T} , tA:: AbstractChar , tB:: AbstractChar , A:: AbstractGPUVecOrMat{T} , B:: AbstractGPUVecOrMat{T} , _add:: MulAddMul ) where {T<: LinearAlgebra.BlasFloat } =
397
- LinearAlgebra . generic_matmatmul! (C, tA, tB, A, B , _add)
584
+ generic_matmatmul! (C, wrap (A, tA), wrap (B, tB) , _add)
398
585
399
586
function LinearAlgebra. syrk_wrapper! (C:: AbstractGPUMatrix , tA:: AbstractChar , A:: AbstractGPUVecOrMat , _add:: MulAddMul = MulAddMul ())
400
587
if tA == ' T'
401
- LinearAlgebra . generic_matmatmul! (C, ' T ' , ' N ' , A , A, _add)
588
+ generic_matmatmul! (C, wrap (A , ' T ' ) , A, _add)
402
589
else # tA == 'N'
403
- LinearAlgebra . generic_matmatmul! (C, ' N ' , ' T ' , A, A , _add)
590
+ generic_matmatmul! (C, A, wrap ( A, ' T ' ) , _add)
404
591
end
405
592
end
406
593
function LinearAlgebra. herk_wrapper! (C:: AbstractGPUMatrix , tA:: AbstractChar , A:: AbstractGPUVecOrMat , _add:: MulAddMul = MulAddMul ())
407
594
if tA == ' C'
408
- LinearAlgebra . generic_matmatmul! (C, ' C ' , ' N ' , A , A, _add)
595
+ generic_matmatmul! (C, wrap (A , ' C ' ) , A, _add)
409
596
else # tA == 'N'
410
- LinearAlgebra . generic_matmatmul! (C, ' N ' , ' C ' , A, A , _add)
597
+ generic_matmatmul! (C, A, wrap ( A, ' C ' ) , _add)
411
598
end
412
599
end
413
600
end # VERSION
0 commit comments