Skip to content

Commit 4958827

Browse files
committed
Fix?
1 parent 811a16c commit 4958827

File tree

1 file changed

+15
-9
lines changed

1 file changed

+15
-9
lines changed

src/host/linalg.jl

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -366,26 +366,29 @@ function generic_matmatmul!(C::AbstractGPUMatrix{R}, A::AbstractGPUMatrix{T}, B:
366366
# number of tiles depends on inner dimension
367367
@uniform NUM_TILES = div(Q + TILE_DIM - 1, TILE_DIM)
368368

369-
glob_I = (grow - 1) * TILE_DIM + tile_row
370-
glob_J = (gcol - 1) * TILE_DIM + tile_col
371-
372369
# loop over all tiles needed for this calculation
373370
for t in 0:(NUM_TILES - 1)
371+
I = (grow - 1) * TILE_DIM + tile_row
372+
J = (gcol - 1) * TILE_DIM + tile_col
373+
374374
# load inputs into tiles, with bounds checking for non-square matrices
375-
if glob_I <= N && t * TILE_DIM + tile_col <= Q
376-
@inbounds tile1[tile_row, tile_col] = input1[glob_I, t * TILE_DIM + tile_col]
375+
if I <= N && t * TILE_DIM + tile_col <= Q
376+
@inbounds tile1[tile_row, tile_col] = input1[I, t * TILE_DIM + tile_col]
377377
else
378378
@inbounds tile1[tile_row, tile_col] = zero(R)
379379
end
380-
if glob_J <= M && t * TILE_DIM + tile_row <= Q
381-
@inbounds tile2[tile_row, tile_col] = input2[t * TILE_DIM + tile_row, glob_J]
380+
if J <= M && t * TILE_DIM + tile_row <= Q
381+
@inbounds tile2[tile_row, tile_col] = input2[t * TILE_DIM + tile_row, J]
382382
else
383383
@inbounds tile2[tile_row, tile_col] = zero(R)
384384
end
385385

386386
# wait for all tiles to be loaded
387387
@synchronize
388388

389+
I = (grow - 1) * TILE_DIM + tile_row
390+
J = (gcol - 1) * TILE_DIM + tile_col
391+
389392
# calculate value of spot in output, use temporary value to allow for vectorization
390393
out = zero(R)
391394
@simd for k in 1:TILE_DIM
@@ -396,9 +399,12 @@ function generic_matmatmul!(C::AbstractGPUMatrix{R}, A::AbstractGPUMatrix{T}, B:
396399
@synchronize
397400
end
398401

402+
I = (grow - 1) * TILE_DIM + tile_row
403+
J = (gcol - 1) * TILE_DIM + tile_col
404+
399405
# save if inbounds
400-
if glob_I <= N && glob_J <= M
401-
@inbounds output[glob_I, glob_J] = add(outval[1], output[glob_I, glob_J])
406+
if I <= N && J <= M
407+
@inbounds output[I, J] = add(outval[1], output[I, J])
402408
end
403409
end
404410

0 commit comments

Comments
 (0)