Fast path onehotbatch(::Vector{Int}, ::UnitRange) #27

Merged: 4 commits, Dec 27, 2022
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,6 +1,6 @@
name = "OneHotArrays"
uuid = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
version = "0.2.1"
version = "0.2.2"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
10 changes: 10 additions & 0 deletions src/onehot.jl
@@ -100,6 +100,16 @@ function _onehotbatch(data, labels, default)
return OneHotArray(indices, length(labels))
end

function onehotbatch(data::AbstractArray{<:Integer}, labels::AbstractUnitRange{<:Integer})
# lo, hi = extrema(data) # fails on Julia 1.6
lo, hi = minimum(data), maximum(data)
lo < first(labels) && error("Value $lo not found in labels")
hi > last(labels) && error("Value $hi not found in labels")
offset = 1 - first(labels)  # shift so that first(labels) maps to index 1
indices = UInt32.(data .+ offset)
return OneHotArray(indices, length(labels))
Comment on lines +103 to +110
@mcabbott (Member, Author) commented on Dec 28, 2022:

Unfortunately the bounds checking here is quite expensive, especially on GPU arrays where I think each of minimum & maximum forces synchronisation:

julia> let ci = cu(rand(1:99, 100))
         @btime CUDA.@sync onehotbatch($ci, 1:99)
         @btime CUDA.@sync OneHotMatrix($ci, 99)
       end;
  100.993 μs (86 allocations: 4.02 KiB)
  2.803 μs (0 allocations: 0 bytes)

julia> let ci = cu(rand(1:99, 100))
         @btime CUDA.@sync maximum($ci), minimum($ci)
         @btime CUDA.@sync extrema($ci)
         @btime CUDA.@sync map($ci) do i
            0<i<100 || error("bad index")
            UInt32(i+0)
           end
        end;
  71.448 μs (58 allocations: 2.91 KiB)
  38.094 μs (29 allocations: 1.47 KiB)
  18.543 μs (30 allocations: 1.14 KiB)
  
julia> let ci = cu(rand(1:99, 100))  # without explicit CUDA.@sync 
         @btime extrema($ci)
         @btime OneHotMatrix($ci, 99)       # async, which is good
         @btime OneHotMatrix(map($ci) do i  # unfortunately not?
            0<i<100 || error("bad index")
            UInt32(i+0)
           end, 99)
        end;
  35.544 μs (29 allocations: 1.47 KiB)
  6.527 ns (0 allocations: 0 bytes)
  10.619 μs (30 allocations: 1.14 KiB)

Moving the check inside the broadcast is faster, at the cost of more obscure errors. Maybe that's ok? Still not fully async.

julia> map(cu(rand(1:199, 100))) do i
                   0<i<100 || error("bad index")
                   UInt32(i+0)
                  end
ERROR: a exception was thrown during kernel execution.
       Run Julia on debug level 2 for device stack traces.
ERROR: a exception was thrown during kernel execution.
       Run Julia on debug level 2 for device stack traces.

Reply from a project member:

It is rather obscure indeed. What if we wrap the map inside a try-catch and raise a proper error?
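
One possible shape for that suggestion (hypothetical and untested; `data`, `labels`, and `offset` as in the new method, and on the GPU the kernel error may only surface at the next synchronisation, so the catch may not fire where expected):

indices = try
    map(data) do i
        first(labels) <= i <= last(labels) || error("bad index")
        UInt32(i + offset)
    end
catch
    error("Some value in data is outside the label range $labels")
end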

Reply from @mcabbott (Member, Author):

I wonder if we can tell the GPU not to wait? This doesn't work but perhaps something similar does:

julia> let i = rand(1:99, 100)
         @btime maximum($i)<100 || error("outside")
         @btime @async maximum($i)<100 || error("outside")
       end;
  58.407 ns (0 allocations: 0 bytes)
  759.747 ns (5 allocations: 496 bytes)

julia> let ci = cu(rand(1:99, 100))
         @btime maximum($ci)<100 || error("outside")
         @btime @async maximum($ci)<100 || error("outside")
       end;
  35.134 μs (29 allocations: 1.45 KiB)
  # hangs?

Reply from the project member:

I think the ideal solution would be something like JuliaGPU/CUDA.jl#1140. If we had a way to write kernels, another idea would be to create an ad-hoc kernel which flips a one-element Bool array to true if it finds an out-of-range element.
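
A rough sketch of such a kernel (hypothetical; the names and launch configuration are invented, and reading the flag back to the host is itself a synchronisation):

using CUDA

function flag_out_of_range!(flag, data, lo, hi)
    i = (blockIdx().x - 1) * blockDim().x + threadIdx().x
    if i <= length(data)
        # flip the shared flag if any value falls outside lo:hi
        lo <= data[i] <= hi || (flag[1] = true)
    end
    return nothing
end

data = cu(rand(1:99, 100))
flag = CUDA.zeros(Bool, 1)
@cuda threads=256 blocks=cld(length(data), 256) flag_out_of_range!(flag, data, 1, 99)
Array(flag)[1] && error("Value not found in labels")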

Reply from @mcabbott (Member, Author):

That sounds like the right thing. Perhaps rather than owning a kernel, this package could call checkbounds(out, inds, 1) or whatever -- that's essentially the same operation.
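
A sketch of that idea (untested; `indices` and `labels` as in the method above):

out = OneHotArray(indices, length(labels))
checkbounds(out, indices, 1)   # BoundsError if any index falls outside 1:length(labels)

This delegates the check to whatever reduction the index array's type provides.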

I wondered what gather did, and it turns out there is no check:

julia> NNlib.gather([1,20,300,4000] |> cu, [2,4,2,99] |> cu)
4-element CuArray{Int64, 1, CUDA.Mem.DeviceBuffer}:
   20
 4000
   20
    0

julia> NNlib.gather([1,20,300,4000], [2,4,2,99])
ERROR: BoundsError: attempt to access 4-element Vector{Int64} at index [99]

The PR to add one, FluxML/NNlibCUDA.jl#51, has many benchmarks... perhaps also 10s of μs.

end
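
For reference, a quick sketch of what the fast path computes (mirroring the tests added below):

onehotbatch([1, 3, 0, 4], 0:4)   # indices become data .+ 1, with length(0:4) = 5 classes
onehotbatch([1, 3, 0, 2], -2:2)  # throws, since 3 > last(-2:2)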

"""
onecold(y::AbstractArray, labels = 1:size(y,1))

10 changes: 10 additions & 0 deletions test/gpu.jl
@@ -26,6 +26,16 @@ end
@test_broken gradient(A -> sum(A * y), gA)[1] isa CuArray # fails with JLArray, bug in Zygote?
end

@testset "onehotbatch(::CuArray, ::UnitRange)" begin
y1 = onehotbatch([1, 3, 0, 2], 0:9) |> cu
y2 = onehotbatch([1, 3, 0, 2] |> cu, 0:9)
@test y1.indices == y2.indices
@test_broken y1 == y2

@test_throws Exception onehotbatch([1, 3, 0, 2] |> cu, 1:10)
@test_throws Exception onehotbatch([1, 3, 0, 2] |> cu, -2:2)
end

@testset "onecold gpu" begin
y = onehotbatch(ones(3), 1:10) |> cu;
l = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j']
6 changes: 6 additions & 0 deletions test/onehot.jl
@@ -27,6 +27,12 @@
@test onecold(onehot(-0.0, floats)) == 2 # as it uses isequal
@test onecold(onehot(Inf, floats)) == 5

# UnitRange fast path
@test onehotbatch([1,3,0,4], 0:4) == onehotbatch([1,3,0,4], Tuple(0:4))
@test onehotbatch([2 3 7 4], 2:7) == onehotbatch([2 3 7 4], Tuple(2:7))
@test_throws Exception onehotbatch([2, -1], 0:4)
@test_throws Exception onehotbatch([2, 5], 0:4)

# inferrability tests
@test @inferred(onehot(20, 10:10:30)) == [false, true, false]
@test @inferred(onehot(40, (10,20,30), 20)) == [false, true, false]