Skip to content

Commit 698fd2a

Browse files
authored
Fix strided matmul & minor cleanup (#677)
1 parent 3d17779 commit 698fd2a

File tree

17 files changed

+67 -105 lines changed

src/AMDGPU.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,9 @@ include("utils.jl")
7373

7474
include(joinpath("hsa", "HSA.jl"))
7575
include(joinpath("hip", "HIP.jl"))
76-
import .HIP: HIPContext, HIPDevice, HIPStream
76+
77+
using .HIP
78+
using .HIP: HIPContext, HIPDevice, HIPStream
7779
export HIPContext, HIPDevice, HIPStream
7880

7981
include("cache.jl")

src/array.jl

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,19 @@ mutable struct ROCArray{T, N, B} <: AbstractGPUArray{T, N}
55

66
function ROCArray{T, N, B}(
77
::UndefInitializer, dims::Dims{N},
8-
) where {T, N, B <: Union{Mem.HIPBuffer, Mem.HostBuffer}}
8+
) where {T, N, B <: Mem.AbstractAMDBuffer}
99
@assert isbitstype(T) "ROCArray only supports bits types"
1010
data = DataRef(pool_free, pool_alloc(B, prod(dims) * sizeof(T)))
1111
xs = new{T, N, B}(data, dims, 0)
12-
finalizer(unsafe_free!, xs)
13-
return xs
12+
return finalizer(unsafe_free!, xs)
1413
end
1514

1615
function ROCArray{T, N}(
1716
buf::DataRef{Managed{B}}, dims::Dims{N}; offset::Integer = 0,
18-
) where {T, N, B <: Union{Mem.HIPBuffer, Mem.HostBuffer}}
17+
) where {T, N, B <: Mem.AbstractAMDBuffer}
1918
@assert isbitstype(T) "ROCArray only supports bits types"
2019
xs = new{T, N, B}(buf, dims, offset)
21-
finalizer(unsafe_free!, xs)
22-
return xs
20+
return finalizer(unsafe_free!, xs)
2321
end
2422
end
2523

@@ -117,11 +115,8 @@ ROCArray(A::AbstractArray{T,N}) where {T,N} = ROCArray{T,N}(A)
117115
ROCArray{T}(xs::AbstractArray{S,N}) where {T,N,S} = ROCArray{T,N}(xs)
118116
(::Type{ROCArray{T,N} where T})(x::AbstractArray{S,N}) where {S,N} = ROCArray{S,N}(x)
119117

120-
# idempotency
121118
ROCArray{T,N}(xs::ROCArray{T,N}) where {T,N} = copy(xs)
122119

123-
## conversions
124-
125120
Base.convert(::Type{T}, x::T) where T <: ROCArray = x
126121

127122
## memory operations
@@ -176,7 +171,6 @@ function Base.copyto!(
176171
dest
177172
end
178173

179-
# TODO: Workaround for hanging copy() broadcast kernel
180174
function Base.copy(X::ROCArray{T}) where T
181175
Xnew = ROCArray{T}(undef, size(X))
182176
copyto!(Xnew, 1, X, 1, length(X))
@@ -220,10 +214,8 @@ struct Float32Adaptor end
220214

221215
Adapt.adapt_storage(::Float32Adaptor, xs::AbstractArray) =
222216
isbits(xs) ? xs : convert(ROCArray, xs)
223-
224217
Adapt.adapt_storage(::Float32Adaptor, xs::AbstractArray{<:AbstractFloat}) =
225218
isbits(xs) ? xs : convert(ROCArray{Float32}, xs)
226-
227219
Adapt.adapt_storage(::Float32Adaptor, xs::AbstractArray{<:Complex{<:AbstractFloat}}) =
228220
isbits(xs) ? xs : convert(ROCArray{ComplexF32}, xs)
229221

src/blas/wrappers.jl

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -536,17 +536,18 @@ end
536536

537537
# Level 3
538538
## (GE) general matrix-matrix multiplication
539-
for (fname, elty) in
540-
((:rocblas_dgemm,:Float64),
541-
(:rocblas_sgemm,:Float32),
542-
(:rocblas_hgemm,:Float16),
543-
(:rocblas_zgemm,:ComplexF64),
544-
(:rocblas_cgemm,:ComplexF32))
539+
for (fname, elty) in (
540+
(:rocblas_dgemm,:Float64),
541+
(:rocblas_sgemm,:Float32),
542+
(:rocblas_hgemm,:Float16),
543+
(:rocblas_zgemm,:ComplexF64),
544+
(:rocblas_cgemm,:ComplexF32),
545+
)
545546
@eval begin
546547
function gemm!(
547548
transA::Char, transB::Char, alpha::($elty),
548-
A::ROCVecOrMat{$elty}, B::ROCVecOrMat{$elty}, beta::($elty),
549-
C::ROCVecOrMat{$elty},
549+
A::StridedROCVecOrMat{$elty}, B::StridedROCVecOrMat{$elty}, beta::($elty),
550+
C::StridedROCVecOrMat{$elty},
550551
)
551552
m = size(A, transA == 'N' ? 1 : 2)
552553
k = size(A, transA == 'N' ? 2 : 1)

src/dnn/batchnorm.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,10 +104,9 @@ function derive_beta_gamma_descriptors(
104104
dtype, dims, stride = unpack(handle, ndims(handle))
105105

106106
bndesc = TensorDescriptor(handle, dtype)
107-
finalizer(bndesc) do d_
108-
miopenDestroyTensorDescriptor(d_.handle)
107+
return finalizer(bndesc) do d
108+
miopenDestroyTensorDescriptor(d.handle)
109109
end
110-
bndesc
111110
end
112111

113112
# Unsqueeze dimensions at the beginning:

src/dnn/descriptors.jl

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,9 @@ function TensorDescriptor(
3535
handle = handle_ref[]
3636
miopenSetTensorDescriptor(handle, dtype, dims, sizes, strides)
3737
d = TensorDescriptor(handle, dtype)
38-
39-
finalizer(d) do d_
40-
miopenDestroyTensorDescriptor(d_.handle)
38+
return finalizer(d) do d
39+
miopenDestroyTensorDescriptor(d.handle)
4140
end
42-
d
4341
end
4442

4543
function TensorDescriptor(x::ROCArray{T}) where T
@@ -96,11 +94,9 @@ function ConvolutionDescriptor(
9694
handle, n_dims, padding, stride, dilation, miopenConvolution)
9795
miopenSetConvolutionGroupCount(handle, groups)
9896
d = ConvolutionDescriptor(handle)
99-
100-
finalizer(d) do d_
101-
miopenDestroyConvolutionDescriptor(d_.handle)
97+
return finalizer(d) do d
98+
miopenDestroyConvolutionDescriptor(d.handle)
10299
end
103-
d
104100
end
105101

106102
"""
@@ -150,11 +146,9 @@ function PoolingDescriptor(
150146
handle, mode, n_dims, dims, padding, stride)
151147
miopenSetPoolingIndexType(handle, miopenIndexUint32)
152148
d = PoolingDescriptor(handle)
153-
154-
finalizer(d) do d_
155-
miopenDestroyPoolingDescriptor(d_.handle)
149+
return finalizer(d) do d
150+
miopenDestroyPoolingDescriptor(d.handle)
156151
end
157-
d
158152
end
159153

160154
"""
@@ -192,10 +186,9 @@ function ActivationDescriptor()
192186
miopenCreateActivationDescriptor(handle_ref)
193187
handle = handle_ref[]
194188
d = ActivationDescriptor(handle)
195-
finalizer(d) do d_
196-
miopenDestroyActivationDescriptor(d_.handle)
189+
return finalizer(d) do d
190+
miopenDestroyActivationDescriptor(d.handle)
197191
end
198-
d
199192
end
200193

201194
function set!(

src/fft/fft.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,7 @@ mutable struct cROCFFTPlan{T, K, inplace, N} <: ROCFFTPlan{T, K, inplace}
5151
rocfft_execution_info_set_work_buffer(info, workarea, length(workarea))
5252
end
5353
p = new(handle, stream, workarea, info, size(X), sizey, xtype, region)
54-
finalizer(AMDGPU.unsafe_free!, p)
55-
p
54+
return finalizer(AMDGPU.unsafe_free!, p)
5655
end
5756
end
5857

@@ -81,8 +80,7 @@ mutable struct rROCFFTPlan{T,K,inplace,N} <: ROCFFTPlan{T,K,inplace}
8180
rocfft_execution_info_set_work_buffer(info, workarea, length(workarea))
8281
end
8382
p = new(handle, stream, workarea, info, size(X), sizey, xtype, region)
84-
finalizer(unsafe_free!, p)
85-
p
83+
return finalizer(unsafe_free!, p)
8684
end
8785
end
8886

src/highlevel.jl

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,3 @@
1-
"""
2-
devices()
3-
4-
Get list of all devices.
5-
"""
6-
devices() = HIP.devices()
7-
81
"""
92
device_id() -> Int
103
device_id(device::HIPDevice) -> Int
@@ -22,18 +15,12 @@ Sets the current device to `AMDGPU.devices()[idx]`. See
2215
"""
2316
device_id!(idx::Integer) = device!(devices()[idx])
2417

25-
# Contexts
26-
2718
function device(context::HIPContext)
2819
return HIP.context!(context) do
2920
HIP.device()
3021
end
3122
end
3223

33-
# Streams.
34-
35-
default_stream() = HIP.default_stream()
36-
3724
device(stream::HIPStream) = stream.device
3825
device(idx::Integer) = devices()[idx]
3926

@@ -91,12 +78,6 @@ macro sync(ex)
9178
end
9279
end
9380

94-
"""
95-
Blocks until all kernels on all streams have completed.
96-
Uses currently active device.
97-
"""
98-
device_synchronize() = HIP.device_synchronize()
99-
10081
"""
10182
rocconvert(x)
10283
@@ -184,11 +165,8 @@ macro roc(ex...)
184165
end)
185166
end
186167

187-
function launch_configuration(
188-
kern::Runtime.HIPKernel; shmem::Integer = 0, max_block_size::Integer = 0,
189-
)
168+
launch_configuration(kern::Runtime.HIPKernel; shmem::Integer = 0, max_block_size::Integer = 0) =
190169
HIP.launch_configuration(kern.fun; shmem, max_block_size)
191-
end
192170

193171
"""
194172
@elapsed ex

src/hip/HIP.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
module HIP
22
export HIPError
3+
export devices, device_synchronize, default_stream
34

45
using CEnum
56

@@ -74,6 +75,10 @@ include("event.jl")
7475
include("pool.jl")
7576
include("module.jl")
7677

78+
"""
79+
Blocks until all kernels on all streams have completed.
80+
Uses currently active device.
81+
"""
7782
function device_synchronize()
7883
AMDGPU.maybe_collect(; blocking=true)
7984
hipDeviceSynchronize()

src/hip/device.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,11 @@ function ndevices()
104104
count_ref[]
105105
end
106106

107+
"""
108+
devices()
109+
110+
Get list of all devices.
111+
"""
107112
function devices()
108113
isempty(ALL_DEVICES) || return copy(ALL_DEVICES)
109114

src/hip/event.jl

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -44,25 +44,20 @@ end
4444
wait(event::HIPEvent) = hipEventSynchronize(event)
4545

4646
function synchronize(event::HIPEvent)
47-
if !non_blocking_synchronize(event)
48-
AMDGPU.maybe_collect(; blocking=true)
49-
end
47+
non_blocking_synchronize(event) || AMDGPU.maybe_collect(; blocking=true)
5048
wait(event)
5149
return
5250
end
5351

5452
function HIPEvent(stream::hipStream_t; do_record::Bool = true, timing=false)
5553
event_ref = Ref{hipEvent_t}()
56-
if !timing
54+
timing ?
55+
hipEventCreate(event_ref) :
5756
hipEventCreateWithFlags(event_ref, hipEventDisableTiming)
58-
else
59-
hipEventCreate(event_ref)
60-
end
6157
event = HIPEvent(event_ref[], stream)
6258
do_record && record(event)
6359

64-
finalizer(hipEventDestroy, event)
65-
return event
60+
return finalizer(hipEventDestroy, event)
6661
end
6762
HIPEvent(stream::HIPStream; kwargs...) = HIPEvent(stream.stream; kwargs...)
6863

@@ -76,5 +71,5 @@ See also [`@elapsed`](@ref).
7671
function elapsed(start::HIPEvent, stop::HIPEvent)
7772
time_ref = Ref{Cfloat}()
7873
hipEventElapsedTime(time_ref, start, stop)
79-
return time_ref[]/1000
74+
return time_ref[] / 1000
8075
end

src/hip/module.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,7 @@ mutable struct HIPModule
88
mod_ref = Ref{hipModule_t}()
99
hipModuleLoadData(mod_ref, data)
1010
mod = new(mod_ref[])
11-
12-
finalizer(hipModuleUnload, mod)
13-
mod
11+
return finalizer(hipModuleUnload, mod)
1412
end
1513
end
1614

src/hip/stream.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,12 @@ function HIPStream(priority::Symbol = :normal)
2727
hipStreamCreateWithPriority(stream_ref, 0, priority_int)
2828
d = device()
2929
stream = HIPStream(stream_ref[], priority, d, HIPContext(d), true)
30-
finalizer(stream) do s
30+
return finalizer(stream) do s
3131
AMDGPU.context!(s.ctx) do
3232
hipStreamDestroy(s.stream)
3333
end
3434
Base.@atomic s.valid = false
3535
end
36-
return stream
3736
end
3837

3938
isvalid(s::HIPStream) = s.valid

src/rand/random.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,7 @@ mutable struct RNG <: Random.AbstractRNG
88
handle = Ref{rocrand_generator}()
99
rocrand_create_generator(handle, typ)
1010
obj = new(handle[], typ)
11-
finalizer(unsafe_destroy!, obj)
12-
return obj
11+
return finalizer(unsafe_destroy!, obj)
1312
end
1413
end
1514

src/runtime/Runtime.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ module Mem
2424
using Preferences
2525

2626
import AMDGPU
27-
import AMDGPU: HIP, HSA, Runtime
27+
import AMDGPU: HIP, Runtime
2828
import .HIP: HIPDevice, HIPContext
2929
import .Runtime: ROCDim, ROCDim3
3030

src/runtime/error.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,4 +35,3 @@ function check(result::HSA.Status)
3535
throw(HSAError(result))
3636
end
3737
end
38-

0 commit comments

Comments (0)