Skip to content

Commit 618c840

Browse files
committed
Separate algorithms
1 parent db5abc5 commit 618c840

File tree

1 file changed

+13
-52
lines changed

1 file changed

+13
-52
lines changed

Diff for: src/reduce.jl

+13 -52
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,4 @@
1-
export @groupreduce
2-
3-
module Reduction
4-
const thread = Val(:thread)
5-
const warp = Val(:warp)
6-
end
1+
export @groupreduce, @warp_groupreduce
72

83
"""
94
@groupreduce op val neutral [groupsize]
@@ -25,55 +20,21 @@ If backend supports warp reduction, it will use it instead of thread reduction.
2520
2621
Result of the reduction.
2722
"""
28-
macro groupreduce(op, val, neutral)
29-
return quote
30-
if __supports_warp_reduction()
31-
__groupreduce(
32-
$(esc(:__ctx__)),
33-
$(esc(op)),
34-
$(esc(val)),
35-
$(esc(neutral)),
36-
Val(prod($groupsize($(esc(:__ctx__))))),
37-
$(esc(Reduction.warp)),
38-
)
39-
else
40-
__groupreduce(
41-
$(esc(:__ctx__)),
42-
$(esc(op)),
43-
$(esc(val)),
44-
$(esc(neutral)),
45-
Val(prod($groupsize($(esc(:__ctx__))))),
46-
$(esc(Reduction.thread)),
47-
)
48-
end
49-
end
23+
macro groupreduce(op, val)
24+
:(__thread_groupreduce($(esc(:__ctx__)), $(esc(op)), $(esc(val)), Val(prod($groupsize($(esc(:__ctx__)))))))
25+
end
26+
macro groupreduce(op, val, groupsize)
27+
:(__thread_groupreduce($(esc(:__ctx__)), $(esc(op)), $(esc(val)), Val($(esc(groupsize)))))
5028
end
5129

52-
macro groupreduce(op, val, neutral, groupsize)
53-
return quote
54-
if __supports_warp_reduction()
55-
__groupreduce(
56-
$(esc(:__ctx__)),
57-
$(esc(op)),
58-
$(esc(val)),
59-
$(esc(neutral)),
60-
Val($(esc(groupsize))),
61-
$(esc(Reduction.warp)),
62-
)
63-
else
64-
__groupreduce(
65-
$(esc(:__ctx__)),
66-
$(esc(op)),
67-
$(esc(val)),
68-
$(esc(neutral)),
69-
Val($(esc(groupsize))),
70-
$(esc(Reduction.thread)),
71-
)
72-
end
73-
end
30+
macro warp_groupreduce(op, val, neutral)
31+
:(__warp_groupreduce($(esc(:__ctx__)), $(esc(op)), $(esc(val)), $(esc(neutral)), Val(prod($groupsize($(esc(:__ctx__)))))))
32+
end
33+
macro warp_groupreduce(op, val, neutral, groupsize)
34+
:(__warp_groupreduce($(esc(:__ctx__)), $(esc(op)), $(esc(val)), $(esc(neutral)), Val($(esc(groupsize)))))
7435
end
7536

76-
function __groupreduce(__ctx__, op, val::T, neutral::T, ::Val{groupsize}, ::Val{:thread}) where {T, groupsize}
37+
function __thread_groupreduce(__ctx__, op, val::T, ::Val{groupsize}) where {T, groupsize}
7738
storage = @localmem T groupsize
7839

7940
local_idx = @index(Local)
@@ -120,7 +81,7 @@ const __warp_bins = UInt32(32)
12081
return val
12182
end
12283

123-
function __groupreduce(__ctx__, op, val::T, neutral::T, ::Val{groupsize}, ::Val{:warp}) where {T, groupsize}
84+
function __warp_groupreduce(__ctx__, op, val::T, neutral::T, ::Val{groupsize}) where {T, groupsize}
12485
storage = @localmem T __warp_bins
12586

12687
local_idx = @index(Local)

0 commit comments

Comments
 (0)