|
47 | 47 | @inline function _copyto!(dest::AbstractArray, bc::Broadcasted)
|
48 | 48 | axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc))
|
49 | 49 | isempty(dest) && return dest
|
50 |
| - |
51 |
| - # to help Enzyme.jl, we won't pass the broadcasted object directly |
52 |
| - # but instead pass its arguments and reconstruct the object device-side |
53 | 50 | bc = Broadcast.preprocess(dest, bc)
|
54 |
| - bcstyle = @static if VERSION >= v"1.10-" |
55 |
| - bc.style |
56 |
| - else |
57 |
| - typeof(BroadcastStyle(typeof(bc))) |
58 |
| - end |
59 | 51 |
|
60 | 52 | broadcast_kernel = if ndims(dest) == 1 ||
|
61 | 53 | (isa(IndexStyle(dest), IndexLinear) &&
|
62 | 54 | isa(IndexStyle(bc), IndexLinear))
|
63 |
| - function (ctx, dest, nelem, bcstyle, bcf, bcaxes, bcargs...) |
64 |
| - bc′ = @static if VERSION >= v"1.10-" |
65 |
| - Broadcasted(bcstyle, bcf, bcargs, bcaxes) |
66 |
| - else |
67 |
| - Broadcasted{bcstyle}(bcf, bcargs, bcaxes) |
68 |
| - end |
69 |
| - |
| 55 | + function (ctx, dest, bc, nelem) |
70 | 56 | i = 1
|
71 | 57 | while i <= nelem
|
72 | 58 | I = @linearidx(dest, i)
|
73 |
| - @inbounds dest[I] = bc′[I] |
| 59 | + @inbounds dest[I] = bc[I] |
74 | 60 | i += 1
|
75 | 61 | end
|
76 | 62 | return
|
77 | 63 | end
|
78 | 64 | else
|
79 |
| - function (ctx, dest, nelem, bcstyle, bcf, bcaxes, bcargs...) |
80 |
| - bc′ = @static if VERSION >= v"1.10-" |
81 |
| - Broadcasted(bcstyle, bcf, bcargs, bcaxes) |
82 |
| - else |
83 |
| - Broadcasted{bcstyle}(bcf, bcargs, bcaxes) |
84 |
| - end |
85 |
| - |
| 65 | + function (ctx, dest, bc, nelem) |
86 | 66 | i = 0
|
87 | 67 | while i < nelem
|
88 | 68 | i += 1
|
89 | 69 | I = @cartesianidx(dest, i)
|
90 |
| - @inbounds dest[I] = bc′[I] |
| 70 | + @inbounds dest[I] = bc[I] |
91 | 71 | end
|
92 | 72 | return
|
93 | 73 | end
|
94 | 74 | end
|
95 | 75 |
|
96 | 76 | elements = length(dest)
|
97 | 77 | elements_per_thread = typemax(Int)
|
98 |
| - heuristic = launch_heuristic(backend(dest), broadcast_kernel, dest, 1, |
99 |
| - bcstyle, bc.f, bc.axes, bc.args...; |
| 78 | + heuristic = launch_heuristic(backend(dest), broadcast_kernel, dest, bc, 1; |
100 | 79 | elements, elements_per_thread)
|
101 | 80 | config = launch_configuration(backend(dest), heuristic;
|
102 | 81 | elements, elements_per_thread)
|
103 |
| - gpu_call(broadcast_kernel, dest, config.elements_per_thread::Int, |
104 |
| - bcstyle, bc.f, bc.axes, bc.args...; |
| 82 | + gpu_call(broadcast_kernel, dest, bc, config.elements_per_thread; |
105 | 83 | threads=config.threads, blocks=config.blocks)
|
106 | 84 |
|
107 | 85 | if eltype(dest) <: BrokenBroadcast
|
|
0 commit comments