1
1
@kernel cpu= false function groupreduce_1! (y, x, op, neutral)
2
2
i = @index (Global)
3
3
val = i > length (x) ? neutral : x[i]
4
- res = @groupreduce (op, val, neutral )
4
+ res = @groupreduce (op, val)
5
5
i == 1 && (y[1 ] = res)
6
6
end
7
7
8
8
@kernel cpu= false function groupreduce_2! (y, x, op, neutral, :: Val{groupsize} ) where {groupsize}
9
9
i = @index (Global)
10
10
val = i > length (x) ? neutral : x[i]
11
- res = @groupreduce (op, val, neutral, groupsize)
11
+ res = @groupreduce (op, val, groupsize)
12
+ i == 1 && (y[1 ] = res)
13
+ end
14
+
15
+ @kernel cpu= false function warp_groupreduce_1! (y, x, op, neutral)
16
+ i = @index (Global)
17
+ val = i > length (x) ? neutral : x[i]
18
+ res = @warp_groupreduce (op, val, neutral)
19
+ i == 1 && (y[1 ] = res)
20
+ end
21
+
22
+ @kernel cpu= false function warp_groupreduce_2! (y, x, op, neutral, :: Val{groupsize} ) where {groupsize}
23
+ i = @index (Global)
24
+ val = i > length (x) ? neutral : x[i]
25
+ res = @warp_groupreduce (op, val, neutral, groupsize)
12
26
i == 1 && (y[1 ] = res)
13
27
end
14
28
@@ -17,19 +31,40 @@ function groupreduce_testsuite(backend, AT)
17
31
groupsizes = " $backend " == " oneAPIBackend" ?
18
32
(256 ,) :
19
33
(256 , 512 , 1024 )
34
+
20
35
@testset " @groupreduce" begin
21
36
@testset " T=$T , n=$n " for T in (Float16, Float32, Int16, Int32, Int64), n in groupsizes
22
37
x = AT (ones (T, n))
23
38
y = AT (zeros (T, 1 ))
39
+ neutral = zero (T)
40
+ op = +
24
41
25
- groupreduce_1! (backend (), n)(y, x, + , zero (T) ; ndrange = n)
42
+ groupreduce_1! (backend (), n)(y, x, op, neutral ; ndrange = n)
26
43
@test Array (y)[1 ] == n
27
44
28
- groupreduce_2! (backend ())(y, x, + , zero (T), Val (128 ); ndrange = n)
29
- @test Array (y)[1 ] == 128
45
+ for groupsize in (64 , 128 )
46
+ groupreduce_2! (backend ())(y, x, op, neutral, Val (groupsize); ndrange = n)
47
+ @test Array (y)[1 ] == groupsize
48
+ end
49
+ end
50
+ end
51
+
52
+ if KernelAbstractions. supports_warp_reduction (backend ())
53
+ @testset " @warp_groupreduce" begin
54
+ @testset " T=$T , n=$n " for T in (Float16, Float32, Int16, Int32, Int64), n in groupsizes
55
+ x = AT (ones (T, n))
56
+ y = AT (zeros (T, 1 ))
57
+ neutral = zero (T)
58
+ op = +
59
+
60
+ warp_groupreduce_1! (backend (), n)(y, x, op, neutral; ndrange = n)
61
+ @test Array (y)[1 ] == n
30
62
31
- groupreduce_2! (backend ())(y, x, + , zero (T), Val (64 ); ndrange = n)
32
- @test Array (y)[1 ] == 64
63
+ for groupsize in (64 , 128 )
64
+ warp_groupreduce_2! (backend ())(y, x, op, neutral, Val (groupsize); ndrange = n)
65
+ @test Array (y)[1 ] == groupsize
66
+ end
67
+ end
33
68
end
34
69
end
35
70
end
0 commit comments