From e007533d568125918198e14cc4ad1018149c3995 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 26 Dec 2022 15:07:33 -0500 Subject: [PATCH 1/4] add a fast path --- src/onehot.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/onehot.jl b/src/onehot.jl index c225fc4..2bda431 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -100,6 +100,12 @@ function _onehotbatch(data, labels, default) return OneHotArray(indices, length(labels)) end +function onehotbatch(data::AbstractArray{<:Integer}, labels::AbstractUnitRange{<:Integer}) + offset = 1 - first(labels) + indices = UInt32.(data .+ offset) + return OneHotArray(indices, length(labels)) +end + """ onecold(y::AbstractArray, labels = 1:size(y,1)) From 6c432cc66ebf49c18df8c11d42915b59197d40d1 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 26 Dec 2022 15:25:52 -0500 Subject: [PATCH 2/4] add an error check --- src/onehot.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/onehot.jl b/src/onehot.jl index 2bda431..b654504 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -103,6 +103,7 @@ end function onehotbatch(data::AbstractArray{<:Integer}, labels::AbstractUnitRange{<:Integer}) offset = 1 - first(labels) indices = UInt32.(data .+ offset) + maximum(indices) > last(labels) + offset && error("Largest value not found in labels") return OneHotArray(indices, length(labels)) end From 6809fd9ad3a85c846d8be07588428c9113b3d920 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 27 Dec 2022 12:34:48 -0500 Subject: [PATCH 3/4] fixup, add tests --- Project.toml | 2 +- src/onehot.jl | 4 +++- test/gpu.jl | 10 ++++++++++ test/onehot.jl | 6 ++++++ 4 files changed, 20 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 29a5818..94da4fc 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "OneHotArrays" uuid = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" -version = "0.2.1" +version = "0.2.2" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/onehot.jl b/src/onehot.jl index b654504..7464b08 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -101,9 +101,11 @@ function _onehotbatch(data, labels, default) end function onehotbatch(data::AbstractArray{<:Integer}, labels::AbstractUnitRange{<:Integer}) + lo, hi = extrema(data) + lo < first(labels) && error("Value $lo not found in labels") + hi > last(labels) && error("Value $hi not found in labels") offset = 1 - first(labels) indices = UInt32.(data .+ offset) - maximum(indices) > last(labels) + offset && error("Largest value not found in labels") return OneHotArray(indices, length(labels)) end diff --git a/test/gpu.jl b/test/gpu.jl index 13c208c..cd04815 100644 --- a/test/gpu.jl +++ b/test/gpu.jl @@ -26,6 +26,16 @@ end @test_broken gradient(A -> sum(A * y), gA)[1] isa CuArray # fails with JLArray, bug in Zygote? end +@testset "onehotbatch(::CuArray, ::UnitRange)" begin + y1 = onehotbatch([1, 3, 0, 2], 0:9) |> cu + y2 = onehotbatch([1, 3, 0, 2] |> cu, 0:9) + @test y1.indices == y2.indices + @test_broken y1 == y2 + + @test_throws Exception onehotbatch([1, 3, 0, 2] |> cu, 1:10) + @test_throws Exception onehotbatch([1, 3, 0, 2] |> cu, -2:2) +end + @testset "onecold gpu" begin y = onehotbatch(ones(3), 1:10) |> cu; l = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j'] diff --git a/test/onehot.jl b/test/onehot.jl index 0628230..fffac19 100644 --- a/test/onehot.jl +++ b/test/onehot.jl @@ -27,6 +27,12 @@ @test onecold(onehot(-0.0, floats)) == 2 # as it uses isequal @test onecold(onehot(Inf, floats)) == 5 + # UnitRange fast path + @test onehotbatch([1,3,0,4], 0:4) == onehotbatch([1,3,0,4], Tuple(0:4)) + @test onehotbatch([2 3 7 4], 2:7) == onehotbatch([2 3 7 4], Tuple(2:7)) + @test_throws Exception onehotbatch([2, -1], 0:4) + @test_throws Exception onehotbatch([2, 5], 0:4) + # inferrabiltiy tests @test @inferred(onehot(20, 10:10:30)) == [false, true, false] @test @inferred(onehot(40, (10,20,30), 20)) == [false, true, false] From 7c1238f5bf074f00b39e5306afc45b9c45d2dc4a Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 27 Dec 2022 12:51:54 -0500 Subject: [PATCH 4/4] fix 1.6 --- src/onehot.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/onehot.jl b/src/onehot.jl index 7464b08..ca2efa5 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -101,7 +101,8 @@ function _onehotbatch(data, labels, default) end function onehotbatch(data::AbstractArray{<:Integer}, labels::AbstractUnitRange{<:Integer}) - lo, hi = extrema(data) + # lo, hi = extrema(data) # fails on Julia 1.6 + lo, hi = minimum(data), maximum(data) lo < first(labels) && error("Value $lo not found in labels") hi > last(labels) && error("Value $hi not found in labels") offset = 1 - first(labels)