1
- export @groupreduce
2
-
3
- module Reduction
4
- const thread = Val (:thread )
5
- const warp = Val (:warp )
6
- end
1
+ export @groupreduce , @warp_groupreduce
7
2
8
3
"""
9
4
@groupreduce op val neutral [groupsize]
@@ -25,55 +20,21 @@ If backend supports warp reduction, it will use it instead of thread reduction.
25
20
26
21
Result of the reduction.
27
22
"""
28
- macro groupreduce (op, val, neutral)
29
- return quote
30
- if __supports_warp_reduction ()
31
- __groupreduce (
32
- $ (esc (:__ctx__ )),
33
- $ (esc (op)),
34
- $ (esc (val)),
35
- $ (esc (neutral)),
36
- Val (prod ($ groupsize ($ (esc (:__ctx__ ))))),
37
- $ (esc (Reduction. warp)),
38
- )
39
- else
40
- __groupreduce (
41
- $ (esc (:__ctx__ )),
42
- $ (esc (op)),
43
- $ (esc (val)),
44
- $ (esc (neutral)),
45
- Val (prod ($ groupsize ($ (esc (:__ctx__ ))))),
46
- $ (esc (Reduction. thread)),
47
- )
48
- end
49
- end
23
+ macro groupreduce (op, val)
24
+ :(__thread_groupreduce ($ (esc (:__ctx__ )), $ (esc (op)), $ (esc (val)), Val (prod ($ groupsize ($ (esc (:__ctx__ )))))))
25
+ end
26
+ macro groupreduce (op, val, groupsize)
27
+ :(__thread_groupreduce ($ (esc (:__ctx__ )), $ (esc (op)), $ (esc (val)), Val ($ (esc (groupsize)))))
50
28
end
51
29
52
- macro groupreduce (op, val, neutral, groupsize)
53
- return quote
54
- if __supports_warp_reduction ()
55
- __groupreduce (
56
- $ (esc (:__ctx__ )),
57
- $ (esc (op)),
58
- $ (esc (val)),
59
- $ (esc (neutral)),
60
- Val ($ (esc (groupsize))),
61
- $ (esc (Reduction. warp)),
62
- )
63
- else
64
- __groupreduce (
65
- $ (esc (:__ctx__ )),
66
- $ (esc (op)),
67
- $ (esc (val)),
68
- $ (esc (neutral)),
69
- Val ($ (esc (groupsize))),
70
- $ (esc (Reduction. thread)),
71
- )
72
- end
73
- end
30
+ macro warp_groupreduce (op, val, neutral)
31
+ :(__warp_groupreduce ($ (esc (:__ctx__ )), $ (esc (op)), $ (esc (val)), $ (esc (neutral)), Val (prod ($ groupsize ($ (esc (:__ctx__ )))))))
32
+ end
33
+ macro warp_groupreduce (op, val, neutral, groupsize)
34
+ :(__warp_groupreduce ($ (esc (:__ctx__ )), $ (esc (op)), $ (esc (val)), $ (esc (neutral)), Val ($ (esc (groupsize)))))
74
35
end
75
36
76
- function __groupreduce (__ctx__, op, val:: T , neutral :: T , :: Val{groupsize} , :: Val{:thread } ) where {T, groupsize}
37
+ function __thread_groupreduce (__ctx__, op, val:: T , :: Val{groupsize} ) where {T, groupsize}
77
38
storage = @localmem T groupsize
78
39
79
40
local_idx = @index (Local)
@@ -120,7 +81,7 @@ const __warp_bins = UInt32(32)
120
81
return val
121
82
end
122
83
123
- function __groupreduce (__ctx__, op, val:: T , neutral:: T , :: Val{groupsize} , :: Val{:warp } ) where {T, groupsize}
84
+ function __warp_groupreduce (__ctx__, op, val:: T , neutral:: T , :: Val{groupsize} ) where {T, groupsize}
124
85
storage = @localmem T __warp_bins
125
86
126
87
local_idx = @index (Local)
0 commit comments