Skip to content

Commit 344d484

Browse files
committed
Add warp_groupreduce tests
1 parent 618c840 commit 344d484

File tree

2 files changed

+47
-12
lines changed

2 files changed

+47
-12
lines changed

src/reduce.jl

+5-5
Original file line number | Diff line number | Diff line change
@@ -62,10 +62,10 @@ end
6262
# Warp groupreduce.
6363

6464
# NOTE: Backends should implement these two device functions (with `@device_override`).
65-
function __shfl_down end
66-
function __supports_warp_reduction()
67-
return false
68-
end
65+
function shfl_down end
66+
supports_warp_reduction() = false
67+
# Host-variant.
68+
supports_warp_reduction(::Backend) = false
6969

7070
# Assume warp is 32 lanes.
7171
const __warpsize = UInt32(32)
@@ -75,7 +75,7 @@ const __warp_bins = UInt32(32)
7575
@inline function __warp_reduce(val, op)
7676
offset::UInt32 = __warpsize ÷ 0x02
7777
while offset > 0x00
78-
val = op(val, __shfl_down(val, offset))
78+
val = op(val, shfl_down(val, offset))
7979
offset >>= 0x01
8080
end
8181
return val

test/groupreduce.jl

+42-7
Original file line number | Diff line number | Diff line change
@@ -1,14 +1,28 @@
11
@kernel cpu=false function groupreduce_1!(y, x, op, neutral)
22
i = @index(Global)
33
val = i > length(x) ? neutral : x[i]
4-
res = @groupreduce(op, val, neutral)
4+
res = @groupreduce(op, val)
55
i == 1 && (y[1] = res)
66
end
77

88
@kernel cpu=false function groupreduce_2!(y, x, op, neutral, ::Val{groupsize}) where {groupsize}
99
i = @index(Global)
1010
val = i > length(x) ? neutral : x[i]
11-
res = @groupreduce(op, val, neutral, groupsize)
11+
res = @groupreduce(op, val, groupsize)
12+
i == 1 && (y[1] = res)
13+
end
14+
15+
@kernel cpu=false function warp_groupreduce_1!(y, x, op, neutral)
16+
i = @index(Global)
17+
val = i > length(x) ? neutral : x[i]
18+
res = @warp_groupreduce(op, val, neutral)
19+
i == 1 && (y[1] = res)
20+
end
21+
22+
@kernel cpu=false function warp_groupreduce_2!(y, x, op, neutral, ::Val{groupsize}) where {groupsize}
23+
i = @index(Global)
24+
val = i > length(x) ? neutral : x[i]
25+
res = @warp_groupreduce(op, val, neutral, groupsize)
1226
i == 1 && (y[1] = res)
1327
end
1428

@@ -17,19 +31,40 @@ function groupreduce_testsuite(backend, AT)
1731
groupsizes = "$backend" == "oneAPIBackend" ?
1832
(256,) :
1933
(256, 512, 1024)
34+
2035
@testset "@groupreduce" begin
2136
@testset "T=$T, n=$n" for T in (Float16, Float32, Int16, Int32, Int64), n in groupsizes
2237
x = AT(ones(T, n))
2338
y = AT(zeros(T, 1))
39+
neutral = zero(T)
40+
op = +
2441

25-
groupreduce_1!(backend(), n)(y, x, +, zero(T); ndrange = n)
42+
groupreduce_1!(backend(), n)(y, x, op, neutral; ndrange = n)
2643
@test Array(y)[1] == n
2744

28-
groupreduce_2!(backend())(y, x, +, zero(T), Val(128); ndrange = n)
29-
@test Array(y)[1] == 128
45+
for groupsize in (64, 128)
46+
groupreduce_2!(backend())(y, x, op, neutral, Val(groupsize); ndrange = n)
47+
@test Array(y)[1] == groupsize
48+
end
49+
end
50+
end
51+
52+
if KernelAbstractions.supports_warp_reduction(backend())
53+
@testset "@warp_groupreduce" begin
54+
@testset "T=$T, n=$n" for T in (Float16, Float32, Int16, Int32, Int64), n in groupsizes
55+
x = AT(ones(T, n))
56+
y = AT(zeros(T, 1))
57+
neutral = zero(T)
58+
op = +
59+
60+
warp_groupreduce_1!(backend(), n)(y, x, op, neutral; ndrange = n)
61+
@test Array(y)[1] == n
3062

31-
groupreduce_2!(backend())(y, x, +, zero(T), Val(64); ndrange = n)
32-
@test Array(y)[1] == 64
63+
for groupsize in (64, 128)
64+
warp_groupreduce_2!(backend())(y, x, op, neutral, Val(groupsize); ndrange = n)
65+
@test Array(y)[1] == groupsize
66+
end
67+
end
3368
end
3469
end
3570
end

0 commit comments

Comments
 (0)