Skip to content
This repository was archived by the owner on Nov 4, 2024. It is now read-only.

Commit 9c716b5

Browse files
committed
fix: check was accidentally broken
1 parent 10488cd commit 9c716b5

File tree

4 files changed

+13
-4
lines changed

4 files changed

+13
-4
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
name = "LuxTestUtils"
22
uuid = "ac9de150-d08f-4546-94fb-7472b5760531"
33
authors = ["Avik Pal <avikpal@mit.edu>"]
4-
version = "1.3.0"
4+
version = "1.3.1"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
8+
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
89
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
910
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
1011
DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8"
@@ -21,6 +22,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2122

2223
[compat]
2324
ADTypes = "1.8.1"
25+
ArrayInterface = "7.9"
2426
ChainRulesCore = "1.24.0"
2527
ComponentArrays = "0.15.14"
2628
DispatchDoctor = "0.4.12"

src/LuxTestUtils.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
module LuxTestUtils
22

3+
using ArrayInterface: ArrayInterface
34
using ComponentArrays: ComponentArray, getdata, getaxes
45
using DispatchDoctor: allow_unstable
56
using Functors: Functors

src/autodiff.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -172,10 +172,10 @@ function test_gradients(f, args...; skip_backends=[], broken_backends=[],
172172
@testset "$(nameof(typeof(backends[1])))() vs $(nameof(typeof(backend)))()" for backend in backends[2:end]
173173
local_test_expr = :([$(nameof(typeof(backend)))] - $(test_expr))
174174

175-
result = if backend in skip_backends
175+
result = if check_ad_backend_in(backend, skip_backends)
176176
Broken(:skipped, local_test_expr)
177177
elseif (soft_fail isa Bool && soft_fail) ||
178-
(soft_fail isa Vector && backend in soft_fail)
178+
(soft_fail isa Vector && check_ad_backend_in(backend, soft_fail))
179179
try
180180
∂args = allow_unstable() do
181181
return gradient(f, backend, args...)
@@ -189,7 +189,7 @@ function test_gradients(f, args...; skip_backends=[], broken_backends=[],
189189
catch
190190
Broken(:test, local_test_expr)
191191
end
192-
elseif backend in broken_backends
192+
elseif check_ad_backend_in(backend, broken_backends)
193193
try
194194
∂args = allow_unstable() do
195195
return gradient(f, backend, args...)

src/utils.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,3 +123,9 @@ function reorder_macro_kw_params(exs)
123123
end
124124
return Tuple(exs)
125125
end
126+
127+
function check_ad_backend_in(backend, backends)
128+
backends_type = map(ArrayInterface.parameterless_type typeof, backends)
129+
backend_type = ArrayInterface.parameterless_type(typeof(backend))
130+
return backend_type in backends_type
131+
end

0 commit comments

Comments
 (0)