Skip to content

Commit 24ad256

Browse files
authored
Keep track of stream usage for arrays (#633)
- Add `Managed` struct that keeps track of TLS streams and correctly switches between them. This in turn allows us to **not** use the global stream in finalizers.
- Remove the REPL hook.
1 parent d8428f4 commit 24ad256

File tree

13 files changed

+139
-130
lines changed

13 files changed

+139
-130
lines changed

src/AMDGPU.jl

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ include("runtime/Runtime.jl")
7272
import .Runtime
7373
import .Runtime: Mem, ROCDim, ROCDim3
7474

75-
include("stats.jl")
75+
include("memory.jl")
7676

7777
const ci_cache = GPUCompiler.CodeCache()
7878
Base.Experimental.@MethodTable(method_table)
@@ -217,12 +217,6 @@ function __init__()
217217
@warn "$name is unavailable, functionality will be disabled."
218218
end
219219
end
220-
221-
# Ensure that operations executed by the REPL backend finish before returning.
222-
# Displaying values happens on a different task.
223-
if isdefined(Base, :active_repl_backend)
224-
push!(Base.active_repl_backend.ast_transforms, synchronize_rocm_tasks)
225-
end
226220
end
227221

228222
end

src/array.jl

Lines changed: 28 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,36 @@
11
mutable struct ROCArray{T, N, B} <: AbstractGPUArray{T, N}
2-
buf::DataRef{B}
2+
buf::DataRef{Managed{B}}
33
dims::Dims{N}
44
offset::Int # Offset is in number of elements (not bytes).
55

66
function ROCArray{T, N, B}(
77
::UndefInitializer, dims::Dims{N},
88
) where {T, N, B <: Union{Mem.HIPBuffer, Mem.HostBuffer}}
99
@assert isbitstype(T) "ROCArray only supports bits types"
10-
buf = B(prod(dims) * sizeof(T); stream=stream())
11-
xs = new{T, N, B}(DataRef(_free_buf, buf), dims, 0)
12-
finalizer(unsafe_finalize!, xs)
10+
data = DataRef(pool_free, pool_alloc(B, prod(dims) * sizeof(T)))
11+
xs = new{T, N, B}(data, dims, 0)
12+
finalizer(unsafe_free!, xs)
1313
return xs
1414
end
1515

1616
function ROCArray{T, N}(
17-
buf::DataRef{B}, dims::Dims{N}; offset::Integer = 0,
17+
buf::DataRef{Managed{B}}, dims::Dims{N}; offset::Integer = 0,
1818
) where {T, N, B <: Union{Mem.HIPBuffer, Mem.HostBuffer}}
1919
@assert isbitstype(T) "ROCArray only supports bits types"
2020
xs = new{T, N, B}(buf, dims, offset)
21-
finalizer(unsafe_finalize!, xs)
21+
finalizer(unsafe_free!, xs)
2222
return xs
2323
end
2424
end
2525

26-
# Passed to `DataRef` to handle freeing.
27-
function _free_buf(buf, stream_ordered::Bool)
28-
context!(buf.ctx) do
29-
s = stream_ordered ? AMDGPU.stream() : AMDGPU.default_stream()
30-
Mem.free(buf; stream=s)
31-
end
32-
end
33-
34-
unsafe_free!(x::ROCArray) = GPUArrays.unsafe_free!(x.buf, true)
35-
unsafe_finalize!(x::ROCArray) = GPUArrays.unsafe_free!(x.buf, false) # Use global stream.
26+
unsafe_free!(x::ROCArray) = GPUArrays.unsafe_free!(x.buf)
3627

3728
"""
3829
device(A::ROCArray) -> HIPDevice
3930
4031
Return the device associated with the array `A`.
4132
"""
42-
device(A::ROCArray) = A.buf[].device
33+
device(A::ROCArray) = A.buf[].mem.device
4334

4435
buftype(x::ROCArray) = buftype(typeof(x))
4536
buftype(::Type{<:ROCArray{<:Any, <:Any, B}}) where B = B # TODO check `@isdefined`?
@@ -73,8 +64,7 @@ AnyROCVecOrMat{T} = Union{AnyROCVector{T}, AnyROCMatrix{T}}
7364

7465
# type and dimensionality specified, accepting dims as tuples of Ints
7566
function ROCArray{T,N}(::UndefInitializer, dims::Dims{N}) where {T,N}
76-
buf = Mem.HIPBuffer(prod(dims) * sizeof(T); stream=stream())
77-
ROCArray{T, N}(DataRef(_free_buf, buf), dims)
67+
ROCArray{T, N, Mem.HIPBuffer}(undef, dims)
7868
end
7969

8070
# buffer, type and dimensionality specified
@@ -98,7 +88,8 @@ ROCArray{T}(::UndefInitializer, dims::Vararg{Integer, N}) where {T, N} =
9888
# from Base arrays
9989
function ROCArray{T,N}(x::Array{T,N}, dims::Dims{N}) where {T,N}
10090
r = ROCArray{T,N}(undef, dims)
101-
Mem.upload!(r.buf[], pointer(x), sizeof(x); stream=stream())
91+
Mem.upload!(convert(Mem.AbstractAMDBuffer, r.buf[]),
92+
pointer(x), sizeof(x); stream=stream())
10293
return r
10394
end
10495

@@ -137,6 +128,8 @@ Base.convert(::Type{T}, x::T) where T <: ROCArray = x
137128

138129
## memory operations
139130

131+
# TODO rework, to pass pointers, instead of accessing .mem
132+
140133
function Base.copyto!(
141134
dest::Array{T}, d_offset::Integer,
142135
source::ROCArray{T}, s_offset::Integer, amount::Integer;
@@ -145,11 +138,11 @@ function Base.copyto!(
145138
amount == 0 && return dest
146139
@boundscheck checkbounds(dest, d_offset + amount - 1)
147140
@boundscheck checkbounds(source, s_offset + amount - 1)
148-
strm = stream()
149141
Mem.download!(
150142
pointer(dest, d_offset),
151-
Mem.view(source.buf[], (source.offset + s_offset - 1) * sizeof(T)),
152-
amount * sizeof(T); stream=strm, async)
143+
Mem.view(convert(Mem.AbstractAMDBuffer, source.buf[]),
144+
(source.offset + s_offset - 1) * sizeof(T)),
145+
amount * sizeof(T); stream=stream(), async)
153146
dest
154147
end
155148

@@ -161,7 +154,8 @@ function Base.copyto!(
161154
@boundscheck checkbounds(dest, d_offset + amount - 1)
162155
@boundscheck checkbounds(source, s_offset + amount - 1)
163156
Mem.upload!(
164-
Mem.view(dest.buf[], (dest.offset + d_offset - 1) * sizeof(T)),
157+
Mem.view(convert(Mem.AbstractAMDBuffer, dest.buf[]),
158+
(dest.offset + d_offset - 1) * sizeof(T)),
165159
pointer(source, s_offset), amount * sizeof(T); stream=stream())
166160
dest
167161
end
@@ -174,8 +168,10 @@ function Base.copyto!(
174168
@boundscheck checkbounds(dest, d_offset + amount - 1)
175169
@boundscheck checkbounds(source, s_offset + amount - 1)
176170
Mem.transfer!(
177-
Mem.view(dest.buf[], (dest.offset + d_offset - 1) * sizeof(T)),
178-
Mem.view(source.buf[], (source.offset + s_offset - 1) * sizeof(T)),
171+
Mem.view(convert(Mem.AbstractAMDBuffer, dest.buf[]),
172+
(dest.offset + d_offset - 1) * sizeof(T)),
173+
Mem.view(convert(Mem.AbstractAMDBuffer, source.buf[]),
174+
(source.offset + s_offset - 1) * sizeof(T)),
179175
amount * sizeof(T); stream=stream())
180176
dest
181177
end
@@ -197,7 +193,7 @@ function Base.unsafe_wrap(
197193
buf = lock ?
198194
Mem.HostBuffer(Ptr{Cvoid}(ptr), sz) :
199195
Mem.HIPBuffer(Ptr{Cvoid}(ptr), sz)
200-
ROCArray{T, N}(DataRef(_free_buf, buf), dims)
196+
ROCArray{T, N}(DataRef(pool_free, Managed(buf)), dims)
201197
end
202198

203199
Base.unsafe_wrap(::Type{ROCArray{T}}, ptr::Ptr, dims; kwargs...) where T =
@@ -237,12 +233,8 @@ Adapt.adapt_storage(::Float32Adaptor, xs::AbstractArray{Float16}) =
237233

238234
roc(xs) = adapt(Float32Adaptor(), xs)
239235

240-
function Base.unsafe_convert(::Type{Ptr{T}}, x::ROCArray{T}) where T
241-
# TODO have specialized convert function for buffers:
242-
# convert(hipPtr, buf) -> dev ptr
243-
tmp = typeof(x.buf[]) <: Mem.HIPBuffer ? x.buf[] : x.buf[].dev_ptr
244-
Base.unsafe_convert(Ptr{T}, tmp) + x.offset * sizeof(T)
245-
end
236+
Base.unsafe_convert(typ::Type{Ptr{T}}, x::ROCArray{T}) where T =
237+
convert(typ, x.buf[]) + x.offset * sizeof(T)
246238

247239
# some nice utilities
248240

@@ -280,13 +272,14 @@ function Base.resize!(A::ROCVector{T}, n::Integer) where T
280272

281273
copy_size = min(length(A), n) * sizeof(T)
282274
if copy_size > 0
283-
Mem.transfer!(new_buf, A.buf[], copy_size; stream=stream())
275+
Mem.transfer!(new_buf, convert(Mem.AbstractAMDBuffer, A.buf[]),
276+
copy_size; stream=stream())
284277
end
285278

286279
# Free old buffer.
287280
unsafe_free!(A)
288281

289-
A.buf = DataRef(_free_buf, new_buf)
282+
A.buf = DataRef(pool_free, Managed(new_buf))
290283
A.dims = (n,)
291284
A.offset = 0
292285
return A

src/blas/wrappers.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -580,18 +580,19 @@ end
580580
# helper function to get a device array of device pointers
581581
function device_batch(batch::Array{T}) where T <: ROCArray
582582
E = eltype(T)
583-
ROCArray([Base.unsafe_convert(Ptr{E}, arr.buf[]) for arr in batch])
583+
ROCArray([convert(Ptr{E}, arr.buf[]) for arr in batch])
584584
end
585585

586586
function device_batch(x::AnyROCArray{T, 3}) where T
587587
shift = size(x, 1) * size(x, 2) * sizeof(T)
588+
buf = convert(AMDGPU.Mem.AbstractAMDBuffer, x.buf[])
588589
ROCArray([
589-
Base.unsafe_convert(Ptr{T}, AMDGPU.Mem.view(x.buf[], shift * (i - 1)))
590+
convert(Ptr{T}, AMDGPU.Mem.view(buf, shift * (i - 1)))
590591
for i in 1:size(x, 3)])
591592
end
592593

593594
function device_batch(x::AnyROCArray{T, 3}, batch_count::Int) where T
594-
ptr = Base.unsafe_convert(Ptr{T}, x.buf[])
595+
ptr = convert(Ptr{T}, x.buf[])
595596
ROCArray([ptr for i in 1:batch_count])
596597
end
597598

src/compiler/dynamic_memory.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,7 @@ function create_malloc_hostcall!()
1313
# Create host pinned memory and store HostCall in it.
1414
# It will be then accessed by kernels from kernel state.
1515
buf = Mem.HostBuffer(sizeof(holder.hc), HIP.hipHostAllocDefault)
16-
ptr = Base.unsafe_convert(
17-
Ptr{Device.HostCall{Ptr{Cvoid}, Tuple{Csize_t}}}, buf)
16+
ptr = convert(Ptr{Device.HostCall{Ptr{Cvoid}, Tuple{Csize_t}}}, buf)
1817
Base.unsafe_store!(ptr, holder.hc)
1918
return holder, buf
2019
end
@@ -34,8 +33,7 @@ function create_free_hostcall!()
3433
end
3534

3635
buf = Mem.HostBuffer(sizeof(holder.hc), HIP.hipHostAllocDefault)
37-
ptr = Base.unsafe_convert(
38-
Ptr{Device.HostCall{Nothing, Tuple{Ptr{Cvoid}}}}, buf)
36+
ptr = convert(Ptr{Device.HostCall{Nothing, Tuple{Ptr{Cvoid}}}}, buf)
3937
Base.unsafe_store!(ptr, holder.hc)
4038
return holder, buf
4139
end

src/compiler/exceptions.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ end
1717
# Check for exceptions on every synchronization.
1818
function check_exceptions()
1919
for (dev, buf) in _exception_flags
20-
ptr = Base.unsafe_convert(Ptr{Int}, buf)
20+
ptr = convert(Ptr{Int}, buf)
2121
flag = unsafe_load(ptr)
2222
if flag != 0
2323
unsafe_store!(ptr, 0)

src/compiler/output_context.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ function create_output_context!(#= TODO mod::HIP.HIPModule =#)
1515

1616
# Pointer to HostCall to be read from device.
1717
buf = Mem.HostBuffer(sizeof(holder.hc), HIP.hipHostAllocDefault)
18-
ptr = Base.unsafe_convert(Ptr{Device.OUTPUT_CONTEXT_TYPE}, buf)
18+
ptr = convert(Ptr{Device.OUTPUT_CONTEXT_TYPE}, buf)
1919
Base.unsafe_store!(ptr, holder.hc)
2020
return holder, buf
2121
end
@@ -48,8 +48,7 @@ function create_printf_output_context!()
4848
end
4949
# Pointer to HostCall to be read from device.
5050
buf = Mem.HostBuffer(sizeof(holder.hc), HIP.hipHostAllocDefault)
51-
ptr = Base.unsafe_convert(
52-
Ptr{Device.PRINTF_OUTPUT_CONTEXT_TYPE}, buf)
51+
ptr = convert(Ptr{Device.PRINTF_OUTPUT_CONTEXT_TYPE}, buf)
5352
Base.unsafe_store!(ptr, holder.hc)
5453
return holder, buf
5554
end

src/device/gcn/hostcall.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ function HostCall(
5454
buf_len = max(sizeof(UInt64), buf_len) # make room for return buffer pointer
5555
buf = Mem.HostBuffer(buf_len, AMDGPU.HIP.hipHostAllocDefault)
5656

57-
buf_ptr = LLVMPtr{UInt8, AS.Global}(Base.unsafe_convert(Ptr{UInt8}, buf))
57+
buf_ptr = LLVMPtr{UInt8, AS.Global}(convert(Ptr{UInt8}, buf))
5858
host_signal_store!(HSA.Signal(signal_handle), READY_SENTINEL)
5959
HostCall{RT, AT}(signal_handle, buf_ptr, buf_len)
6060
end
@@ -144,7 +144,7 @@ function HostCallHolder(
144144

145145
ret_ref = Ref{rettype}(ret)
146146
GC.@preserve ret_ref begin
147-
ret_ptr = Base.unsafe_convert(Ptr{Cvoid}, ret_buf)
147+
ret_ptr = convert(Ptr{Cvoid}, ret_buf)
148148
if sizeof(ret) > 0
149149
src_ptr = reinterpret(Ptr{Cvoid},
150150
Base.unsafe_convert(Ptr{rettype}, ret_ref))

src/exception_handler.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,13 +71,13 @@ end
7171

7272
function has_exception(dev::HIPDevice)::Bool
7373
ex = exception_holder(dev)
74-
ptr = Base.unsafe_convert(Ptr{Int}, ex.exception_flag)
74+
ptr = convert(Ptr{Int}, ex.exception_flag)
7575
unsafe_load(ptr) != 0
7676
end
7777

7878
function reset_exception_holder!(dev::HIPDevice)
7979
ex = exception_holder(dev)
80-
ptr = Base.unsafe_convert(Ptr{Int}, ex.exception_flag)
80+
ptr = convert(Ptr{Int}, ex.exception_flag)
8181
unsafe_store!(ptr, 0)
8282

8383
fill!(ex.buffers_counter, 0)

src/gpuarrays.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,9 @@ function Base.convert(
5858
::Type{ROCDeviceArray{T, N, AS.Global}}, a::ROCArray{T, N},
5959
) where {T, N}
6060
# If HostBuffer, use device pointer.
61-
ptr = Base.unsafe_convert(Ptr{T},
62-
typeof(a.buf[]) <: Mem.HIPBuffer ? a.buf[] : a.buf[].dev_ptr)
61+
buf = convert(Mem.AbstractAMDBuffer, a.buf[])
62+
ptr = convert(Ptr{T}, typeof(buf) <: Mem.HIPBuffer ?
63+
buf : buf.dev_ptr)
6364
llvm_ptr = AMDGPU.LLVMPtr{T,AS.Global}(ptr + a.offset * sizeof(T))
6465
ROCDeviceArray{T, N, AS.Global}(a.dims, llvm_ptr)
6566
end

src/hip/stream.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ mutable struct HIPStream
66
priority::Symbol
77
device::HIPDevice
88
ctx::HIPContext
9+
10+
Base.@atomic valid::Bool
911
end
1012

1113
"""
@@ -24,16 +26,19 @@ function HIPStream(priority::Symbol = :normal)
2426
stream_ref = Ref{hipStream_t}()
2527
hipStreamCreateWithPriority(stream_ref, 0, priority_int) |> check
2628
d = device()
27-
stream = HIPStream(stream_ref[], priority, d, HIPContext(d))
29+
stream = HIPStream(stream_ref[], priority, d, HIPContext(d), true)
2830
finalizer(stream) do s
2931
AMDGPU.context!(s.ctx) do
3032
hipStreamDestroy(s.stream) |> check
3133
end
34+
Base.@atomic s.valid = false
3235
end
3336
return stream
3437
end
3538

36-
default_stream() = HIPStream(C_NULL, :normal, device(), HIPContext())
39+
isvalid(s::HIPStream) = s.valid
40+
41+
default_stream() = HIPStream(C_NULL, :normal, device(), HIPContext(), true)
3742

3843
"""
3944
HIPStream(stream::hipStream_t)
@@ -43,7 +48,7 @@ Device is the default device that's currently in use.
4348
"""
4449
function HIPStream(stream::hipStream_t)
4550
d = device()
46-
HIPStream(stream, priority(stream), d, HIPContext(d))
51+
HIPStream(stream, priority(stream), d, HIPContext(d), true)
4752
end
4853

4954
function isdone(stream::HIPStream)

0 commit comments

Comments
 (0)