diff --git a/Project.toml b/Project.toml index ee6a92b..e1f69a7 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesTestUtils" uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a" -version = "1.7.0" +version = "1.7.1" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/testers.jl b/src/testers.jl index 31168fc..0a28ae6 100644 --- a/src/testers.jl +++ b/src/testers.jl @@ -227,7 +227,7 @@ function test_rrule( end if check_thunked_output_tangent - test_approx(ad_cotangents, pullback(@thunk(ȳ)), "pulling back a thunk:") + test_approx(ad_cotangents, pullback(@thunk(ȳ)), "pulling back a thunk:"; isapprox_kwargs...) check_inferred && _test_inferred(pullback, @thunk(ȳ)) end end # top-level testset diff --git a/test/testers.jl b/test/testers.jl index 00d542a..16230bd 100644 --- a/test/testers.jl +++ b/test/testers.jl @@ -683,7 +683,10 @@ end function ChainRulesCore.rrule(::typeof(my_id), x) my_id_pb(ȳ) = (NoTangent(), ȳ) function my_id_pb(ȳ::AbstractThunk) - precision = rand() > 0.5 ? Float64 : Float32 + # We use a condition that always evaluates to true to avoid issues with tolerances + # (see https://github.com/JuliaDiff/ChainRulesTestUtils.jl/pull/247) + # The function is type unstable for `Float64` inputs nevertheless + precision = rand() >= 0.0 ? Float64 : Float32 return (NoTangent(), precision(unthunk(ȳ))) end return x, my_id_pb