@@ -366,26 +366,29 @@ function generic_matmatmul!(C::AbstractGPUMatrix{R}, A::AbstractGPUMatrix{T}, B:
366
366
# number of tiles depends on inner dimension
367
367
@uniform NUM_TILES = div (Q + TILE_DIM - 1 , TILE_DIM)
368
368
369
- glob_I = (grow - 1 ) * TILE_DIM + tile_row
370
- glob_J = (gcol - 1 ) * TILE_DIM + tile_col
371
-
372
369
# loop over all tiles needed for this calculation
373
370
for t in 0 : (NUM_TILES - 1 )
371
+ I = (grow - 1 ) * TILE_DIM + tile_row
372
+ J = (gcol - 1 ) * TILE_DIM + tile_col
373
+
374
374
# load inputs into tiles, with bounds checking for non-square matrices
375
- if glob_I <= N && t * TILE_DIM + tile_col <= Q
376
- @inbounds tile1[tile_row, tile_col] = input1[glob_I , t * TILE_DIM + tile_col]
375
+ if I <= N && t * TILE_DIM + tile_col <= Q
376
+ @inbounds tile1[tile_row, tile_col] = input1[I , t * TILE_DIM + tile_col]
377
377
else
378
378
@inbounds tile1[tile_row, tile_col] = zero (R)
379
379
end
380
- if glob_J <= M && t * TILE_DIM + tile_row <= Q
381
- @inbounds tile2[tile_row, tile_col] = input2[t * TILE_DIM + tile_row, glob_J ]
380
+ if J <= M && t * TILE_DIM + tile_row <= Q
381
+ @inbounds tile2[tile_row, tile_col] = input2[t * TILE_DIM + tile_row, J ]
382
382
else
383
383
@inbounds tile2[tile_row, tile_col] = zero (R)
384
384
end
385
385
386
386
# wait for all tiles to be loaded
387
387
@synchronize
388
388
389
+ I = (grow - 1 ) * TILE_DIM + tile_row
390
+ J = (gcol - 1 ) * TILE_DIM + tile_col
391
+
389
392
# calculate value of spot in output, use temporary value to allow for vectorization
390
393
out = zero (R)
391
394
@simd for k in 1 : TILE_DIM
@@ -396,9 +399,12 @@ function generic_matmatmul!(C::AbstractGPUMatrix{R}, A::AbstractGPUMatrix{T}, B:
396
399
@synchronize
397
400
end
398
401
402
+ I = (grow - 1 ) * TILE_DIM + tile_row
403
+ J = (gcol - 1 ) * TILE_DIM + tile_col
404
+
399
405
# save if inbounds
400
- if glob_I <= N && glob_J <= M
401
- @inbounds output[glob_I, glob_J ] = add (outval[1 ], output[glob_I, glob_J ])
406
+ if I <= N && J <= M
407
+ @inbounds output[I, J ] = add (outval[1 ], output[I, J ])
402
408
end
403
409
end
404
410
0 commit comments