Skip to content

Commit 40fa8c0

Browse files
committed
Revert "Reconstruct Broadcasted in kernel to help Enzyme.jl (JuliaGPU#539)"
This reverts commit 8c5d550.
1 parent cd1f59a commit 40fa8c0

File tree

2 files changed

+7
-29
lines changed

2 files changed

+7
-29
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "GPUArrays"
22
uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
3-
version = "10.2.1"
3+
version = "10.2.2"
44

55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/host/broadcast.jl

+6-28
Original file line numberDiff line numberDiff line change
@@ -47,61 +47,39 @@ end
4747
@inline function _copyto!(dest::AbstractArray, bc::Broadcasted)
4848
axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc))
4949
isempty(dest) && return dest
50-
51-
# to help Enzyme.jl, we won't pass the broadcasted object directly
52-
# but instead pass its arguments and reconstruct the object device-side
5350
bc = Broadcast.preprocess(dest, bc)
54-
bcstyle = @static if VERSION >= v"1.10-"
55-
bc.style
56-
else
57-
typeof(BroadcastStyle(typeof(bc)))
58-
end
5951

6052
broadcast_kernel = if ndims(dest) == 1 ||
6153
(isa(IndexStyle(dest), IndexLinear) &&
6254
isa(IndexStyle(bc), IndexLinear))
63-
function (ctx, dest, nelem, bcstyle, bcf, bcaxes, bcargs...)
64-
bc′ = @static if VERSION >= v"1.10-"
65-
Broadcasted(bcstyle, bcf, bcargs, bcaxes)
66-
else
67-
Broadcasted{bcstyle}(bcf, bcargs, bcaxes)
68-
end
69-
55+
function (ctx, dest, bc, nelem)
7056
i = 1
7157
while i <= nelem
7258
I = @linearidx(dest, i)
73-
@inbounds dest[I] = bc[I]
59+
@inbounds dest[I] = bc[I]
7460
i += 1
7561
end
7662
return
7763
end
7864
else
79-
function (ctx, dest, nelem, bcstyle, bcf, bcaxes, bcargs...)
80-
bc′ = @static if VERSION >= v"1.10-"
81-
Broadcasted(bcstyle, bcf, bcargs, bcaxes)
82-
else
83-
Broadcasted{bcstyle}(bcf, bcargs, bcaxes)
84-
end
85-
65+
function (ctx, dest, bc, nelem)
8666
i = 0
8767
while i < nelem
8868
i += 1
8969
I = @cartesianidx(dest, i)
90-
@inbounds dest[I] = bc[I]
70+
@inbounds dest[I] = bc[I]
9171
end
9272
return
9373
end
9474
end
9575

9676
elements = length(dest)
9777
elements_per_thread = typemax(Int)
98-
heuristic = launch_heuristic(backend(dest), broadcast_kernel, dest, 1,
99-
bcstyle, bc.f, bc.axes, bc.args...;
78+
heuristic = launch_heuristic(backend(dest), broadcast_kernel, dest, bc, 1;
10079
elements, elements_per_thread)
10180
config = launch_configuration(backend(dest), heuristic;
10281
elements, elements_per_thread)
103-
gpu_call(broadcast_kernel, dest, config.elements_per_thread::Int,
104-
bcstyle, bc.f, bc.axes, bc.args...;
82+
gpu_call(broadcast_kernel, dest, bc, config.elements_per_thread;
10583
threads=config.threads, blocks=config.blocks)
10684

10785
if eltype(dest) <: BrokenBroadcast

0 commit comments

Comments
 (0)